Mangle parameter names in formparams.
[wrw.git] / wrw / util.py
CommitLineData
1864be32 1import inspect, math
2a5a8ce7 2from . import req, dispatch, session, form, resp, proto
b409a338
FT
3
def wsgiwrap(callable):
    """Expose a request handler as a function taking ``(env, startreq)``.

    The returned function hands the environment, the start callback and
    *callable* to ``dispatch.handleenv``.  The handler is recorded as
    ``__wrapped__`` so ``funplex.unwrap`` can recover it.
    """
    handler = callable

    def wrapper(env, startreq):
        # Request construction and response handling live in dispatch.
        return dispatch.handleenv(env, startreq, handler)

    setattr(wrapper, "__wrapped__", handler)
    return wrapper
d9979128 9
def formparams(callable):
    """Decorate a handler so that form fields arrive as keyword arguments.

    The wrapped handler is invoked as ``callable(**args)``, where the
    arguments come from ``form.formdata(req)``.  In those field names every
    "-" is replaced by "_", so a ``user-name`` field fills a ``user_name``
    parameter.  The request itself is always added under the name ``req``.

    If *callable* accepts ``**kwargs``, every field is passed through.
    Otherwise only fields that name one of its parameters are passed.
    ``*args`` and ``**kwargs`` parameters are never filled by name and
    never reported as missing.

    Raises ``resp.httperror(400, ...)`` if reading the form data fails
    with ``IOError``, or if a parameter without a default is not supplied.
    """
    sig = inspect.signature(callable)
    kinds = inspect.Parameter
    params = list(sig.parameters.values())
    haskw = any(par.kind is kinds.VAR_KEYWORD for par in params)
    # Parameters that can be supplied by name; *args/**kwargs excluded so a
    # handler with *args is not rejected as "missing" that parameter.
    named = [par for par in params
             if par.kind not in (kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD)]

    def mangle(nm):
        # Form field names may contain "-", which is invalid in identifiers.
        return nm.replace("-", "_")

    def wrapper(req):
        try:
            data = {mangle(key): val for key, val in form.formdata(req).items()}
        except IOError:
            raise resp.httperror(400, "Invalid request", "Form data was incomplete")

        data["req"] = req
        if haskw:
            args = data
        else:
            args = {par.name: data[par.name] for par in named if par.name in data}
        for par in named:
            if par.default is inspect.Parameter.empty and par.name not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(par.name), "' is required but not supplied."))
        return callable(**args)

    wrapper.__wrapped__ = callable
    return wrapper
41
612eb9f5
FT
class funplex(object):
    """Route a request to one of several handlers by its first path segment.

    Handlers passed positionally are registered under the ``__name__`` of
    the innermost function reached by following ``__wrapped__`` links.
    Keyword handlers are registered under their keyword.  Two names are
    special: ``__root__`` serves an empty path, and ``__index__`` serves a
    bare "/".
    """
    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow ``__wrapped__`` links down to the original function."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        path = req.pathinfo
        if path == "":
            if "__root__" in self.dir:
                return self.dir["__root__"](req)
            # No root handler: send the client to the trailing-slash URL.
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()

        rest = path[1:]
        if rest == "":
            name, consumed = "__index__", 1
        else:
            name = rest.partition("/")[0]
            # Consume the leading "/" plus the segment itself.
            consumed = len(name) + 1

        if name not in self.dir:
            raise resp.notfound()
        sub = req.shift(consumed)
        # Read from req after shifting, exactly as before, in case shift
        # affects the original request.
        sub.selfpath = req.pathinfo[1:]
        return self.dir[name](sub)

    def add(self, fun):
        """Register *fun* under its unwrapped name; usable as a decorator."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator registering a function under *name*."""
        def dec(fun):
            self.dir[name] = fun
            return fun
        return dec
525d7938 83
def persession(data=None):
    """Return a decorator that keeps one handler object per session.

    The decorated *callable* is a factory, e.g. a class.  On the first
    request of a session it is instantiated and stored in the session,
    keyed by the factory itself.  It is called with no arguments, or, when
    *data* is given, with the session's shared instance of *data*.  Every
    request is then answered by that object's ``handle(req)`` method.

    *data*, if given, is itself a factory.  Its instance is created on
    demand and stored in the session under *data*, so several handlers
    decorated with the same *data* share one object.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Hand over the shared per-session instance, not the
                    # factory, so handlers actually share state.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
d1f70c6c 99
77dd732a
FT
class preiter(object):
    """Iterator wrapper that always holds one item fetched in advance.

    Creating the wrapper immediately pulls the first item from *real*.  For
    a generator, its code up to the first ``yield`` therefore runs at
    construction time.  ``close()`` is forwarded to the wrapped object when
    it has a ``close`` attribute.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real          # original object, kept for close()
        self.bki = iter(real)   # iterator actually being consumed
        self._next = None
        # Prime the look-ahead; the placeholder None it returns is dropped.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        if not hasattr(self.bk, "close"):
            return
        self.bk.close()
125
def pregen(callable):
    """Wrap an iterable-returning function so its result is a `preiter`.

    The returned function calls *callable* and wraps the result in a
    `preiter`, which fetches the first item immediately.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)

    wrapper.__wrapped__ = callable
    return wrapper
131
62113fc6
FT
def stringwrap(charset):
    """Return a decorator that encodes each string a function yields.

    Every item produced by the decorated function is ``.encode(charset)``-ed.
    The resulting generator is wrapped by `pregen`, so the first item is
    produced as soon as the decorated function is called.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            for chunk in callable(*args, **kwargs):
                yield chunk.encode(charset)
        wrapper = pregen(wrapper)
        # Point at the undecorated function, not the inner generator.
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
141
d1f70c6c
FT
class sessiondata(object):
    """Base class for objects kept at most once per session.

    The session is looked up through `sessdb`.  The instance is stored in
    the session keyed by the class itself.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return this class's instance from the session of *req*.

        If no instance is stored, create one as ``cls(req, sess)`` when
        *create* is true, or return None otherwise.  The lookup and any
        creation happen while holding ``sess.lock``.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                existing = sess[cls]
            except KeyError:
                pass
            else:
                return existing
            if not create:
                return None
            fresh = sess[cls] = cls(req, sess)
            return fresh

    @classmethod
    def sessdb(cls):
        """Return the session database; overridable hook for subclasses."""
        # presumably a context-local default database — TODO confirm in session
        return session.default.val
d1f70c6c 159
f13b8f5a
FT
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    The flag is stored in the instance ``__dict__`` as ``_is_dirty``.  It
    is installed by `get` only after the object has been created, so
    attribute assignments made during construction do not mark it dirty.
    """
    @classmethod
    def get(cls, req, create=True):
        """Like `sessiondata.get`, also ensuring the dirty flag exists.

        Returns None when *create* is false and nothing is stored yet.
        """
        ret = super().get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        # Write through __dict__ so clearing the flag does not go through
        # __setattr__, which would immediately set it again.
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
3b9bc700
FT
183
class manudirty(object):
    """Mixin whose dirty state changes only via an explicit `dirty()` call."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Name-mangled (stored as _manudirty__dirty) to avoid clashes with
        # attributes of classes that mix this in.
        self.__dirty = False

    def dirty(self):
        """Flag the object as modified."""
        self.__dirty = True

    def sessdirty(self):
        """Return whether the object was flagged since the last freeze."""
        return self.__dirty

    def sessfrozen(self):
        """Clear the modified flag."""
        # presumably called by the session layer after persisting — confirm
        self.__dirty = False
df5f7868
FT
197
class specslot(object):
    """Data descriptor for one saved slot of a `specdirty` instance.

    Values live in the instance's ``__sslots__`` list at position ``idx``.
    A slot with no value holds the `unbound` marker.  Reading such a slot
    raises AttributeError.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Marker stored in the per-instance value list for slots with no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm        # attribute name, used in error messages
        self.idx = idx      # position in the instance's __sslots__ list
        self.dirty = dirty  # whether changes should mark the instance dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class itself (e.g. hasattr(cls, name) or
            # introspection): return the descriptor, per descriptor protocol.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Same rule as __set__: only slots flagged dirty mark the instance.
        if self.dirty:
            ins.dirty()
226
class specclass(type):
    """Metaclass that installs a `specslot` descriptor for each saved slot.

    For every class in the MRO, ``__saveslots__`` names slots that are
    saved.  ``__dirtyslots__`` names slots whose assignment marks the
    instance dirty; it defaults to that class's ``__saveslots__``.

    Two attributes are set on the class:
    - ``__sslots_l__``: the saved slots.
    - ``__sslots_a__``: saved plus dirty slots.  Its order fixes each
      slot's index in the per-instance value list.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved, dirtying = set(), set()
        for klass in self.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for index, slotname in enumerate(self.__sslots_a__):
            setattr(self, slotname, specslot(slotname, index, slotname in dirtying))
240
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose saved attributes live in a private value list.

    Subclasses declare ``__saveslots__`` / ``__dirtyslots__``; `specclass`
    turns them into `specslot` descriptors backed by ``__sslots__``.
    Pickled state maps each slot name to a ``(bound, value)`` pair.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to set initial slot values; no-op here."""

    @staticmethod
    def __new__(cls, req, sess):
        inst = super().__new__(cls)
        inst.session = sess
        inst.__sslots__ = [specslot.unbound for _ in range(len(cls.__sslots_a__))]
        inst.__specinit__()
        # Anything __specinit__ assigned must not count as a modification.
        inst._is_dirty = False
        return inst

    def __getnewargs__(self):
        # pickle recreates the object via __new__(cls, None, session).
        return (None, self.session)

    def dirty(self):
        self._is_dirty = True

    def sessfrozen(self):
        self._is_dirty = False

    def sessdirty(self):
        return self._is_dirty

    def __getstate__(self):
        values = specslot.slist(self)
        return {
            nm: ((False, None) if val is specslot.unbound else (True, val))
            for nm, val in zip(type(self).__sslots_a__, values)
        }

    def __setstate__(self, st):
        values = specslot.slist(self)
        for pos, nm in enumerate(type(self).__sslots_a__):
            # Slots missing from older pickles come back unbound.
            bound, val = st.pop(nm, (False, None))
            values[pos] = val if bound else specslot.unbound
1864be32
FT
285
def datecheck(req, mtime):
    """Answer a conditional request for a resource modified at *mtime*.

    If the request carries an ``If-Modified-Since`` date no earlier than
    *mtime* (rounded down), raise ``resp.unmodified()``.  Otherwise set the
    ``Last-Modified`` response header from *mtime*.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        since = proto.phttpdate(req.ihead[header])
        # phttpdate may yield None (unparsable date); that is ignored.  The
        # floor presumably compensates for HTTP dates carrying whole seconds.
        if since is not None and since >= math.floor(mtime):
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)