d5e22a3f9175010ff9bc30419be77735ab90011c
[wrw.git] / wrw / req.py
1 import io
2
# Names exported by ``from ... import *``: only the ``request`` base class.
__all__ = [
    "request",
]
4
class headdict(object):
    """Case-insensitive, multi-valued mapping of header names to values.

    Lookups lower-case the key.  Each entry is stored as
    ``lowercased-name -> [name-as-given, value1, value2, ...]``; iteration
    yields the name as given when the entry was created (``__setitem__``
    always re-creates the entry, ``add`` creates it only if missing).
    """

    def __init__(self):
        # lowercased name -> [original-case name, value, value, ...]
        self.dict = {}

    def __getitem__(self, key):
        """Return the first value stored for *key*; KeyError if absent."""
        return self.dict[key.lower()][1]

    def __setitem__(self, key, val):
        """Replace every value for *key* with the single value *val*."""
        self.dict[key.lower()] = [key, val]

    def __contains__(self, key):
        return key.lower() in self.dict

    def __delitem__(self, key):
        del self.dict[key.lower()]

    def __iter__(self):
        """Iterate over the stored header names (original case)."""
        return (entry[0] for entry in self.dict.values())

    def get(self, key, default=""):
        """Return the first value for *key*, or *default* if it is absent."""
        entry = self.dict.get(key.lower())
        if entry is None:
            return default
        return entry[1]

    def getlist(self, key):
        """Return a new list of all values for *key* (``[]`` if absent).

        A missing key is NOT inserted: a value-less placeholder entry would
        make ``key in h`` true and ``h[key]`` / ``h.get(key)`` raise
        IndexError afterwards.
        """
        return self.dict.get(key.lower(), [key])[1:]

    def add(self, key, val):
        """Append *val* to the values for *key*, creating the entry if needed."""
        self.dict.setdefault(key.lower(), [key]).append(val)

    def __repr__(self):
        return repr(self.dict)

    def __str__(self):
        return str(self.dict)
40
def fixcase(str):
    """Normalize a header name to canonical capitalization.

    The name is lower-cased, then its first character and every character
    that directly follows a ``-`` are upper-cased, e.g. ``"CONTENT-TYPE"``
    -> ``"Content-Type"``.

    The parameter keeps its historical name ``str`` (shadowing the builtin)
    so keyword callers keep working; the body does not need the builtin.
    """
    # Upper-case the head of each '-'-separated part in one linear pass.
    # str.capitalize() is avoided: it title-cases the first character.
    return "-".join(part[:1].upper() + part[1:] for part in str.lower().split("-"))
53
class shortinput(IOError, EOFError):
    """Raised when an input stream ends before the expected number of bytes.

    Derives from both IOError and EOFError so handlers for either catch it.
    """

    def __init__(self):
        message = "Unexpected EOF"
        super(shortinput, self).__init__(message)
57
class limitreader(object):
    """Read-only file-like wrapper exposing at most *limit* bytes of *back*.

    If *back* reaches EOF before *limit* bytes have been read,
    ``shortinput`` is raised, unless *short* is true, in which case the
    bytes buffered so far are returned instead.
    """

    def __init__(self, back, limit, short=False):
        self.bk = back          # underlying stream
        self.limit = limit      # total bytes that may be taken from bk
        self.short = short      # tolerate premature EOF instead of raising
        self.rb = 0             # bytes taken from bk so far (returned or buffered)
        self.buf = bytearray()  # bytes read from bk but not yet returned

    def close(self):
        # Does nothing; the backend stream is left open.
        pass

    def _take(self, n):
        """Remove and return the first *n* buffered bytes."""
        ret = bytes(self.buf[:n])
        del self.buf[:n]
        return ret

    def _eof(self):
        """Handle EOF on the backend before the limit was reached."""
        if self.short:
            return self._take(len(self.buf))
        raise shortinput()

    def read(self, size=-1):
        """Read up to *size* bytes; all remaining bytes if *size* is < 0 or None."""
        if size is None:
            size = -1
        # Bytes still available = already buffered (readline() reads ahead)
        # + what the limit still allows from the backend.
        ra = len(self.buf) + (self.limit - self.rb)
        if size >= 0:
            ra = min(ra, size)
        while len(self.buf) < ra:
            # Never requests past the limit: ra - len(buf) <= limit - rb.
            ret = self.bk.read(ra - len(self.buf))
            if ret == b"":
                return self._eof()
            self.buf.extend(ret)
            self.rb += len(ret)
        return self._take(ra)

    def readline(self, size=-1):
        """Read through the next newline, returning at most *size* bytes when
        *size* is non-negative.  Returns b"" once the limit is exhausted."""
        if size is None:
            size = -1
        off = 0  # length of the buffer prefix already known to hold no newline
        while True:
            # Only look for a newline inside the first *size* bytes.
            end = len(self.buf) if size < 0 else min(len(self.buf), size)
            p = self.buf.find(b"\n", off, end)
            if p >= 0:
                return self._take(p + 1)
            off = end
            if size >= 0 and len(self.buf) >= size:
                return self._take(size)
            if self.rb == self.limit:
                return self._take(len(self.buf))
            # Read ahead in chunks of at most 1024 bytes (and at most size).
            ra = min(self.limit - self.rb, 1024)
            if size >= 0:
                ra = min(ra, size)
            ret = self.bk.read(ra)
            if ret == b"":
                return self._eof()
            self.buf.extend(ret)
            self.rb += len(ret)

    def readlines(self, hint=None):
        """Return all remaining lines as a list (*hint* is accepted but ignored)."""
        return list(self)

    def __iter__(self):
        """Yield lines via readline() until it returns b""."""
        while True:
            line = self.readline()
            if line == b"":
                return
            yield line

    def readable(self):
        return True
    def writable(self):
        return False
    def seekable(self):
        return False
    @property
    def closed(self):
        return self.bk.closed
141
class request(object):
    """Common base class for request objects.

    Subclasses provide the ``uriname`` and ``pathinfo`` attributes used by
    ``shift``.
    """

    def copy(self):
        """Return a ``copyrequest`` wrapping this request."""
        return copyrequest(self)

    def shift(self, n):
        """Return a copy with the first *n* characters of ``pathinfo``
        appended to ``uriname`` and removed from ``pathinfo``.

        Only the copy's attributes are assigned.
        """
        shifted = self.copy()
        consumed, remaining = self.pathinfo[:n], self.pathinfo[n:]
        shifted.uriname = self.uriname + consumed
        shifted.pathinfo = remaining
        return shifted
151
class origrequest(request):
    """Top-level request object built from a WSGI environ dict *env*.

    Holds the request data read from *env* (method, paths, query, client and
    server addresses, input headers, body stream) together with the response
    state (status, output headers) and per-request bookkeeping for items,
    context-managed resources and commit hooks.
    """

    def __init__(self, env):
        self.env = env
        self.method = env["REQUEST_METHOD"].upper()
        self.uriname = env["SCRIPT_NAME"]
        self.filename = env.get("SCRIPT_FILENAME")
        self.uri = env["REQUEST_URI"]
        self.pathinfo = env["PATH_INFO"]
        self.query = env["QUERY_STRING"]
        self.remoteaddr = env["REMOTE_ADDR"]
        self.serverport = env["SERVER_PORT"]
        self.servername = env["SERVER_NAME"]
        self.https = "HTTPS" in env
        self.ihead = headdict()
        if "CONTENT_TYPE" in env:
            self.ihead["Content-Type"] = env["CONTENT_TYPE"]
            if "CONTENT_LENGTH" in env:
                clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
                try:
                    length = int(clen) if clen.isdigit() else None
                except ValueError:
                    # str.isdigit() also accepts characters such as "²" that
                    # int() rejects; treat them like any other bad length
                    # instead of crashing on a client-supplied header.
                    length = None
                if length is not None:
                    # Bound the body to the declared length.
                    self.input = limitreader(env["wsgi.input"], length)
                else:
                    # XXX: What to do?  For now an unparsable length gives an
                    # empty body.
                    self.input = io.BytesIO(b"")
            else:
                # Assume input is chunked and read until ordinary EOF.
                self.input = env["wsgi.input"]
        else:
            # NOTE(review): a body stream is only exposed when CONTENT_TYPE
            # is present.
            self.input = None
        self.ohead = headdict()
        # HTTP_* environ variables become input headers, e.g.
        # HTTP_X_FOO -> "X-Foo" (underscores to dashes, then fixcase()).
        for k, v in env.items():
            if k[:5] == "HTTP_":
                self.ihead.add(fixcase(k[5:].replace("_", "-")), v)
        self.items = {}                  # factory -> per-request instance
        self.statuscode = (200, "OK")
        self.ohead["Content-Type"] = "text/html"
        self.resources = set()           # entered context managers
        self.clean = set()               # their bound __exit__ methods
        self.commitfuns = []             # hooks run by commit()

    def status(self, code, msg):
        """Set the response status to *code* and reason phrase *msg*."""
        self.statuscode = code, msg

    def item(self, id):
        """Return the per-request object produced by the factory *id*.

        The first call evaluates ``id(self)`` and caches the result in
        ``self.items`` keyed by the factory; later calls return the cached
        object.  A result with ``__enter__`` and ``__exit__`` is entered and
        registered for cleanup via ``withres``.  (The parameter shadows the
        builtin ``id``; its name is kept for keyword callers.)
        """
        if id in self.items:
            return self.items[id]
        self.items[id] = new = id(self)
        if hasattr(new, "__enter__") and hasattr(new, "__exit__"):
            self.withres(new)
        return new

    def withres(self, res):
        """Enter context manager *res* once and register its ``__exit__``
        to be called by ``cleanup``.

        If registration fails after ``__enter__``, *res* is exited
        immediately and forgotten before the error propagates.
        """
        if res not in self.resources:
            done = False
            res.__enter__()
            try:
                self.resources.add(res)
                self.clean.add(res.__exit__)
                done = True
            finally:
                if not done:
                    res.__exit__(None, None, None)
                    self.resources.discard(res)

    def cleanup(self):
        """Call every registered ``__exit__`` method.

        Every callback runs even if earlier ones raise; the exception from
        the last failing callback propagates, chained to earlier ones.  The
        order is unspecified because ``self.clean`` is a set.
        """
        fns = list(self.clean)

        def clean1(i):
            if i < len(fns):
                try:
                    # Context-manager protocol: __exit__(exc_type, exc, tb),
                    # the same arguments withres() uses on failure.
                    fns[i](None, None, None)
                finally:
                    clean1(i + 1)

        clean1(0)

    def oncommit(self, fn):
        """Register *fn* (once) to be called as ``fn(self)`` by ``commit``."""
        if fn not in self.commitfuns:
            self.commitfuns.append(fn)

    def commit(self, startreq):
        """Run commit hooks (most recently registered first), then call
        *startreq* with the status line and the list of output headers."""
        for fun in reversed(self.commitfuns):
            fun(self)
        hdrs = []
        for nm in self.ohead:
            for val in self.ohead.getlist(nm):
                hdrs.append((nm, val))
        startreq("%s %s" % self.statuscode, hdrs)

    def topreq(self):
        """This is the top-level request, so return itself."""
        return self
239
class copyrequest(request):
    """A derived view of another request *p*.

    Request fields are copied as attributes, so path fields can be
    reassigned on the copy (as ``shift`` does) without touching *p*.
    ``ihead``, ``ohead`` and ``input`` are the very same objects as *p*'s.
    Status changes go to the parent; items, resources and commit hooks go to
    the top-level request.
    """

    def __init__(self, p):
        self.parent = p
        self.top = p.topreq()
        self.env = p.env
        self.method = p.method
        self.uriname = p.uriname
        self.filename = p.filename
        self.uri = p.uri
        self.pathinfo = p.pathinfo
        self.query = p.query
        self.remoteaddr = p.remoteaddr
        self.serverport = p.serverport
        # origrequest defines servername; without copying it, copies (e.g.
        # from shift()) raise AttributeError on .servername.  getattr keeps
        # parents that lack the attribute working as before.
        self.servername = getattr(p, "servername", None)
        self.https = p.https
        self.ihead = p.ihead
        self.ohead = p.ohead
        self.input = p.input

    def status(self, code, msg):
        """Set the response status on the parent request."""
        return self.parent.status(code, msg)

    def item(self, id):
        """Delegate to the top-level request's ``item``."""
        return self.top.item(id)

    def withres(self, res):
        """Delegate to the top-level request's ``withres``."""
        return self.top.withres(res)

    def oncommit(self, fn):
        """Delegate to the top-level request's ``oncommit``."""
        return self.top.oncommit(fn)

    def topreq(self):
        """Return the top-level request of the parent chain."""
        return self.parent.topreq()