Make limitreader more compatible with general io classes.
authorFredrik Tolf <fredrik@dolda2000.com>
Sun, 11 Aug 2024 02:17:31 +0000 (04:17 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Sun, 11 Aug 2024 02:17:31 +0000 (04:17 +0200)
wrw/req.py

index 010e907..d5e22a3 100644 (file)
@@ -56,9 +56,10 @@ class shortinput(IOError, EOFError):
         super().__init__("Unexpected EOF")
 
 class limitreader(object):
-    def __init__(self, back, limit):
+    def __init__(self, back, limit, short=False):
         self.bk = back
         self.limit = limit
+        self.short = short
         self.rb = 0
         self.buf = bytearray()
 
@@ -72,11 +73,15 @@ class limitreader(object):
         while len(self.buf) < ra:
             ret = self.bk.read(ra - len(self.buf))
             if ret == b"":
+                if self.short:
+                    ret = bytes(self.buf)
+                    self.buf[:] = b""
+                    return ret
                 raise shortinput()
             self.buf.extend(ret)
             self.rb += len(ret)
         ret = bytes(self.buf[:ra])
-        self.buf = self.buf[ra:]
+        self.buf[:ra] = b""
         return ret
 
     def readline(self, size=-1):
@@ -85,16 +90,16 @@ class limitreader(object):
             p = self.buf.find(b'\n', off)
             if p >= 0:
                 ret = bytes(self.buf[:p + 1])
-                self.buf = self.buf[p + 1:]
+                self.buf[:p + 1] = b""
                 return ret
             off = len(self.buf)
             if size >= 0 and len(self.buf) >= size:
                 ret = bytes(self.buf[:size])
-                self.buf = self.buf[size:]
+                self.buf[:size] = b""
                 return ret
             if self.rb == self.limit:
                 ret = bytes(self.buf)
-                self.buf = bytearray()
+                self.buf[:] = b""
                 return ret
             ra = self.limit - self.rb
             if size >= 0:
@@ -102,6 +107,10 @@ class limitreader(object):
             ra = min(ra, 1024)
             ret = self.bk.read(ra)
             if ret == b"":
+                if self.short:
+                    ret = bytes(self.buf)
+                    self.buf[:] = b""
+                    return ret
                 raise shortinput()
             self.buf.extend(ret)
             self.rb += len(ret)
@@ -120,6 +129,16 @@ class limitreader(object):
                 return ret
         return lineiter()
 
+    def readable(self):
+        return True
+    def writable(self):
+        return False
+    def seekable(self):
+        return False
+    @property
+    def closed(self):
+        return self.bk.closed
+
 class request(object):
     def copy(self):
         return copyrequest(self)