4b9dd7139dd6064d026bdbe2b88c8e0f6092be48
[wrw.git] / wrw / req.py
1 import io
2
3 __all__ = ["request"]
4
class headdict(object):
    """Case-insensitive, multi-valued mapping for HTTP header fields.

    Each field is stored under its lower-cased name as a list
    ``[name_as_given, value1, value2, ...]``.  ``d[key] = val`` replaces all
    values of a field with one value; ``add`` appends another value.
    Iteration yields field names in the case they were stored under:
    the name given to the last ``__setitem__``, or to the first ``add``.
    """

    def __init__(self):
        # lower-cased name -> [name as given, value, value, ...]
        self.dict = {}

    def __getitem__(self, key):
        """Return the first value of field *key*; raise KeyError if absent."""
        return self.dict[key.lower()][1]

    def __setitem__(self, key, val):
        """Replace every value of field *key* with the single value *val*."""
        self.dict[key.lower()] = [key, val]

    def __contains__(self, key):
        return key.lower() in self.dict

    def __delitem__(self, key):
        """Remove field *key* and all its values; raise KeyError if absent."""
        del self.dict[key.lower()]

    def __iter__(self):
        # Snapshot the names (portable across Python 2 and 3, unlike
        # itervalues()) so the mapping may be modified while iterating.
        return iter([entry[0] for entry in self.dict.values()])

    def get(self, key, default=""):
        """Return the first value of field *key*, or *default* if absent."""
        entry = self.dict.get(key.lower())
        if entry is None:
            return default
        return entry[1]

    def getlist(self, key):
        """Return a new list of all values of field *key* ([] if absent).

        A missing field is NOT inserted: an empty placeholder entry would
        make ``key in d`` true and ``d[key]`` raise IndexError.
        """
        entry = self.dict.get(key.lower())
        if entry is None:
            return []
        return entry[1:]

    def add(self, key, val):
        """Append *val* as an additional value of field *key*."""
        self.dict.setdefault(key.lower(), [key]).append(val)

    def __repr__(self):
        return repr(self.dict)

    def __str__(self):
        return str(self.dict)
40
def fixcase(str):
    """Return header name *str* in canonical HTTP capitalization.

    The name is lower-cased, then its first character and every character
    that follows a '-' are upper-cased, e.g. "CONTENT-TYPE" -> "Content-Type".
    Other separators (such as '_') do not start a new word.
    """
    # Parameter name kept for keyword callers; it shadows the builtin only here.
    words = str.lower().split("-")
    return "-".join([word[:1].upper() + word[1:] for word in words])
53
class limitreader(object):
    """File-like reader that exposes at most *limit* bytes of stream *back*.

    Never consumes more than *limit* bytes from *back*; if *back* runs dry
    before *limit* bytes have been consumed, reads raise IOError.
    """

    def __init__(self, back, limit):
        self.bk = back          # underlying stream
        self.limit = limit      # total bytes we may consume from self.bk
        self.rb = 0             # bytes consumed from self.bk so far
        self.buf = bytearray()  # consumed from self.bk, not yet returned

    def close(self):
        # Does nothing; the underlying stream is left open.
        pass

    def _take(self, n):
        """Remove and return the first *n* buffered bytes."""
        ret = bytes(self.buf[:n])
        del self.buf[:n]
        return ret

    def _fill(self, n):
        """Append up to *n* bytes from the backend to the buffer.

        Raises IOError if the backend returns no data.
        """
        ret = self.bk.read(n)
        if not ret:  # works for both str ("") and bytes (b"")
            raise IOError("Unexpected EOF")
        self.buf.extend(ret)
        self.rb += len(ret)

    def read(self, size=-1):
        """Return up to *size* bytes, or everything remaining if *size* < 0.

        Raises IOError on premature EOF of the underlying stream.
        """
        # Bytes already buffered (e.g. by readline()) are part of what is
        # left, in addition to what may still be consumed from the backend.
        avail = len(self.buf) + (self.limit - self.rb)
        ra = avail if size < 0 else min(avail, size)
        while len(self.buf) < ra:
            self._fill(ra - len(self.buf))
        return self._take(ra)

    def readline(self, size=-1):
        """Return one line including its newline, at most *size* bytes if
        *size* >= 0.  Returns an empty string once the limit is exhausted.

        Raises IOError on premature EOF of the underlying stream.
        """
        off = 0  # buffer prefix already known to contain no newline
        while True:
            p = self.buf.find(b"\n", off)
            if p >= 0:
                end = p + 1
                if size >= 0:
                    end = min(end, size)
                return self._take(end)
            off = len(self.buf)
            if size >= 0 and off >= size:
                return self._take(size)
            if self.rb == self.limit:
                # Last, unterminated line (or "" when nothing is left).
                return self._take(off)
            ra = self.limit - self.rb
            if size >= 0:
                ra = min(ra, size - off)
            self._fill(min(ra, 1024))

    def readlines(self, hint=None):
        """Return all remaining lines as a list; *hint* is ignored."""
        return list(self)

    def __iter__(self):
        # Call readline() until it returns the empty string (b"" == "" on Py2).
        return iter(self.readline, b"")
118
class request(object):
    """Common base class for request objects.

    Subclasses provide the request attributes; shift() relies on
    ``uriname`` and ``pathinfo``.
    """

    def copy(self):
        """Return a copyrequest wrapping this request."""
        return copyrequest(self)

    def shift(self, n):
        """Return a copy whose uriname is extended by pathinfo[:n] and
        whose pathinfo is the remaining pathinfo[n:]."""
        shifted = self.copy()
        consumed, remainder = self.pathinfo[:n], self.pathinfo[n:]
        shifted.uriname = self.uriname + consumed
        shifted.pathinfo = remainder
        return shifted
128
class origrequest(request):
    """Top-level request object built from a WSGI-style environ dict.

    Holds the request attributes, input headers (``ihead``), response
    status and headers (``ohead``), per-request items, entered resources
    and their cleanup functions, and commit hooks.
    """

    def __init__(self, env):
        self.env = env
        self.method = env["REQUEST_METHOD"].upper()
        # PEP 3333 allows SCRIPT_NAME, PATH_INFO and QUERY_STRING to be
        # absent, meaning empty, so they are not required here.
        self.uriname = env.get("SCRIPT_NAME", "")
        self.filename = env.get("SCRIPT_FILENAME")
        self.uri = env["REQUEST_URI"]
        self.pathinfo = env.get("PATH_INFO", "")
        self.query = env.get("QUERY_STRING", "")
        self.remoteaddr = env["REMOTE_ADDR"]
        self.serverport = env["SERVER_PORT"]
        self.servername = env["SERVER_NAME"]
        self.https = "HTTPS" in env
        self.ihead = headdict()
        self.input = None
        if "CONTENT_TYPE" in env:
            self.ihead["Content-Type"] = env["CONTENT_TYPE"]
        if "CONTENT_LENGTH" in env:
            clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
            if clen.isdigit():
                # Never let handlers read past the declared body length.
                self.input = limitreader(env["wsgi.input"], int(clen))
        if self.input is None:
            # No usable Content-Length: present an empty body.
            self.input = io.BytesIO()
        self.ohead = headdict()
        for k, v in env.items():
            if k.startswith("HTTP_"):
                # e.g. HTTP_USER_AGENT -> User-Agent
                self.ihead.add(fixcase(k[5:].replace("_", "-")), v)
        self.items = {}
        self.statuscode = (200, "OK")
        self.ohead["Content-Type"] = "text/html"
        self.resources = set()   # context managers entered via withres()
        self.clean = set()       # zero-argument callables run by cleanup()
        self.commitfuns = []     # hooks run by commit()

    def status(self, code, msg):
        """Set the response status code and reason phrase."""
        self.statuscode = code, msg

    def item(self, id):
        """Return the per-request item produced by factory *id*.

        The factory is called with this request on first use and the result
        is cached in self.items under the factory.  If the result has both
        __enter__ and __exit__, it is entered via withres().
        """
        if id in self.items:
            return self.items[id]
        self.items[id] = new = id(self)
        if hasattr(new, "__enter__") and hasattr(new, "__exit__"):
            self.withres(new)
        return new

    def withres(self, res):
        """Enter context manager *res* once and schedule its exit for cleanup()."""
        if res in self.resources:
            return
        done = False
        res.__enter__()
        try:
            self.resources.add(res)
            # __exit__ takes (exc_type, exc_value, traceback); wrap it so
            # cleanup() can call every entry of self.clean without arguments.
            self.clean.add(lambda: res.__exit__(None, None, None))
            done = True
        finally:
            if not done:
                res.__exit__(None, None, None)
                self.resources.discard(res)

    def cleanup(self):
        """Call every function in self.clean.

        Every function is called even if earlier ones raise; the exception
        from the last failing call propagates.  self.clean is a set, so the
        call order is unspecified.
        """
        fns = list(self.clean)

        def clean1(i):
            if i < len(fns):
                try:
                    fns[i]()
                finally:
                    clean1(i + 1)

        clean1(0)

    def oncommit(self, fn):
        """Register *fn* to be called with this request by commit(); a
        function already registered is not added again."""
        if fn not in self.commitfuns:
            self.commitfuns.append(fn)

    def commit(self, startreq):
        """Run the commit hooks (most recently registered first), then call
        *startreq* with the status line and the list of output headers."""
        for fun in reversed(self.commitfuns):
            fun(self)
        hdrs = [(nm, val) for nm in self.ohead for val in self.ohead.getlist(nm)]
        startreq("%s %s" % self.statuscode, hdrs)

    def topreq(self):
        """Return this request; it is the root of any chain of copies."""
        return self
211
class copyrequest(request):
    """A copy of request *p* that can carry its own uriname/pathinfo.

    Shares env, header dicts and the input stream with *p*.  Status changes
    go to the parent; items, resources and commit hooks go to the top-level
    request returned by p.topreq().
    """

    def __init__(self, p):
        self.parent = p
        self.top = p.topreq()
        self.env = p.env
        self.method = p.method
        self.uriname = p.uriname
        self.filename = p.filename
        self.uri = p.uri
        self.pathinfo = p.pathinfo
        self.query = p.query
        self.remoteaddr = p.remoteaddr
        self.serverport = p.serverport
        # Previously omitted, so copies (e.g. from shift()) raised
        # AttributeError for these attributes.
        self.servername = p.servername
        self.https = p.https
        self.ihead = p.ihead
        self.ohead = p.ohead
        self.input = p.input

    def status(self, code, msg):
        """Forward the status change to the parent request."""
        return self.parent.status(code, msg)

    def item(self, id):
        """Look up or create the item on the top-level request."""
        return self.top.item(id)

    def withres(self, res):
        """Register the resource with the top-level request."""
        return self.top.withres(res)

    def oncommit(self, fn):
        """Register the commit hook with the top-level request."""
        return self.top.oncommit(fn)

    def topreq(self):
        """Return the top-level request this copy descends from."""
        return self.parent.topreq()