]> git.dolda2000.com Git - wrw.git/commitdiff
Handle multipart/form-data forms directly in formparse. master
authorFredrik Tolf <fredrik@dolda2000.com>
Tue, 22 Jul 2025 22:18:08 +0000 (00:18 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Tue, 22 Jul 2025 22:18:08 +0000 (00:18 +0200)
wrw/form.py

index 85114680d59d271100994fa6f9d4d97fa05e2cc5..67d766e07ec6918ed68b65744349ac0ebc28081f 100644 (file)
@@ -1,25 +1,29 @@
-import urllib.parse
+import urllib.parse, codecs
 from . import proto
 
 __all__ = ["formdata"]
 
-def formparse(req):
-    buf = {}
-    buf.update(urllib.parse.parse_qsl(req.query, keep_blank_values=True))
-    ctype, ctpars = proto.pmimehead(req.ihead.get("Content-Type", ""))
-    if ctype == "application/x-www-form-urlencoded":
-        try:
-            rbody = req.input.read(2 ** 20)
-        except IOError as exc:
-            return exc
-        if len(rbody) >= 2 ** 20:
-            return ValueError("x-www-form-urlencoded data is absurdly long")
-        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1"), encoding=ctpars.get("charset", "utf-8"), keep_blank_values=True))
-    return buf
-
 class badmultipart(IOError):
     pass
 
+def getcharset(name, default=None):
+    try:
+        return codecs.lookup(name).name
+    except LookupError:
+        return default
+
+def trydecode(raw, charset):
+    if charset:
+        try:
+            return raw.decode(charset)
+        except UnicodeDecodeError as exc:
+            raise IOError("charset error") from exc
+    else:
+        try:
+            return raw.decode("utf-8")
+        except UnicodeDecodeError as exc:
+            return raw.decode("latin1")
+
 class formpart(object):
     def __init__(self, form):
         self.form = form
@@ -102,7 +106,7 @@ class formpart(object):
             if ln[-1] != ord(b'\n'):
                 raise badmultipart("Too long header line in part")
             try:
-                return ln.decode(charset).rstrip()
+                return trydecode(ln, charset).rstrip()
             except UnicodeError:
                 raise badmultipart("Form part header is not in assumed charset")
 
@@ -130,13 +134,13 @@ class formpart(object):
         self.filename = par.get("filename")
         val, par = proto.pmimehead(self.head.get("content-type", ""))
         self.ctype = val
-        self.charset = par.get("charset")
+        self.charset = getcharset(par.get("charset", ""), "utf-8")
         encoding = self.head.get("content-transfer-encoding", "binary")
         if encoding != "binary":
             raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
 
 class multipart(object):
-    def __init__(self, req, charset):
+    def __init__(self, req, charset=None):
         val, par = proto.pmimehead(req.ihead.get("Content-Type", ""))
         if req.method != "POST" or val != "multipart/form-data":
             raise badmultipart("Request is not a multipart form")
@@ -150,6 +154,7 @@ class multipart(object):
         self.buf = b"\r\n"
         self.eof = False
         self.headcs = charset
+        self.bodycs = charset
         self.lastpart = formpart(self)
         self.lastpart.close()
 
@@ -163,8 +168,37 @@ class multipart(object):
             raise StopIteration()
         self.lastpart = formpart(self)
         self.lastpart.parsehead(self.headcs)
+        if self.lastpart.name == "_charset_":
+            with self.lastpart:
+                self.bodycs = getcharset(self.lastpart.read(256).decode("latin1"))
+            return self.__next__()
         return self.lastpart
 
+def formparse(req):
+    buf = {}
+    buf.update(urllib.parse.parse_qsl(req.query, keep_blank_values=True))
+    ctype, ctpars = proto.pmimehead(req.ihead.get("Content-Type", ""))
+    limit = 2 ** 20
+    if ctype == "application/x-www-form-urlencoded":
+        try:
+            rbody = req.input.read(limit)
+        except IOError as exc:
+            return exc
+        if len(rbody) >= limit:
+            return IOError("x-www-form-urlencoded data is absurdly long")
+        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1"), encoding=getcharset(ctpars.get("charset", ""), "utf-8"), keep_blank_values=True))
+    elif ctype == "multipart/form-data":
+        tlen = 0
+        for part in multipart(req):
+            with part:
+                left = limit - tlen
+                pbody = part.read(left)
+                if len(pbody) >= left:
+                    raise IOError("multipart/form-data body is absurdly long")
+                tlen += len(pbody)
+                buf[part.name] = trydecode(pbody, part.form.bodycs)
+    return buf
+
 def formdata(req, onerror=Exception):
     data = req.item(formparse)
     if isinstance(data, Exception):