Mangle parameter names in formparams.
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Turn a wrw request handler into a WSGI application callable.

    The returned function forwards the WSGI environment and the
    start_response callback, together with *callable*, to
    dispatch.handleenv().  The handler is recorded in ``__wrapped__`` so
    that code unwrapping decorators (e.g. funplex.unwrap) can reach it.
    """
    handler = callable

    def wrapper(env, startreq):
        # All request/response handling is delegated to the dispatcher.
        return dispatch.handleenv(env, startreq, handler)

    setattr(wrapper, "__wrapped__", handler)
    return wrapper
9
def formparams(callable):
    """Decorate a handler so that its parameters are filled from form data.

    The returned wrapper takes only the request.  It reads
    form.formdata(req), replaces "-" with "_" in every field name (so a
    field ``user-name`` fills the parameter ``user_name``), adds the
    request itself under the name ``req``, and calls *callable* with the
    result as keyword arguments.  If *callable* accepts ``**kwargs`` every
    field is passed; otherwise only fields matching one of its named
    parameters are.

    Raises resp.httperror (status 400) if reading the form data raises
    IOError, or if a parameter without a default value was not supplied.
    """
    sig = inspect.signature(callable)
    params = list(sig.parameters.values())
    haskw = any(par.kind is inspect.Parameter.VAR_KEYWORD for par in params)
    # ``*args`` / ``**kwargs`` collectors cannot be filled by keyword and
    # have no default, so they are excluded both from the fields picked out
    # of the form and from the missing-parameter check (otherwise a handler
    # with ``*args`` would always be rejected as missing that parameter).
    collectors = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    names = [par.name for par in params if par.kind not in collectors]
    required = [par.name for par in params
                if par.kind not in collectors and par.default is inspect.Parameter.empty]

    def mangle(nm):
        # Form field names may contain dashes, which are not valid in
        # Python identifiers.
        return nm.replace("-", "_")

    def wrapper(req):
        try:
            data = {mangle(key): val for key, val in form.formdata(req).items()}
        except IOError:
            raise resp.httperror(400, "Invalid request", "Form data was incomplete")

        data["req"] = req
        if haskw:
            args = data
        else:
            args = {nm: data[nm] for nm in names if nm in data}
        for nm in required:
            if nm not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(nm), "' is required but not supplied."))
        return callable(**args)

    wrapper.__wrapped__ = callable
    return wrapper
41
class funplex(object):
    """Dispatch a request to a handler chosen by the first path segment.

    ``dir`` maps names to handlers.  Positional handlers are registered
    under the ``__name__`` of their innermost ``__wrapped__`` function;
    keyword handlers are registered under the keyword (and win on clashes).
    The special names ``__root__`` (empty path) and ``__index__`` (path
    "/") are also looked up.
    """

    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow the ``__wrapped__`` chain and return the innermost object."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        path = req.pathinfo
        if path == "":
            # No trailing slash: use __root__ if registered, else redirect
            # to the same URI with a slash appended.
            if "__root__" in self.dir:
                return self.dir["__root__"](req)
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()

        rest = path[1:]
        if rest:
            key = rest.partition("/")[0]
            skip = len(key) + 1  # length of "/<segment>"
        else:
            key, skip = "__index__", 1

        if key not in self.dir:
            raise resp.notfound()
        sub = req.shift(skip)
        sub.selfpath = rest
        return self.dir[key](sub)

    def add(self, fun):
        """Register *fun* under its unwrapped ``__name__``; usable as a decorator."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator that registers a handler under *name*."""
        def dec(fun):
            self.dir[name] = fun
            return fun
        return dec
83
def persession(data=None):
    """Decorator factory that keeps one handler object per session.

    The decorated *callable* is used as a factory.  On the first request in
    a session it is called, either with no arguments or, when *data* is
    given, with the session's instance of *data* (itself created on demand
    by calling ``data()`` and stored in the session under *data*).  The
    result is stored in the session under *callable*, and every request is
    delegated to that object's ``handle(req)``.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Hand the per-session data object to the handler, not
                    # the factory; otherwise sess[data] would never be used.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
99
class preiter(object):
    """Iterator wrapper that pulls the first item from *real* immediately.

    Constructing it advances the underlying iterator to its first item, so
    for a generator any code before the first ``yield`` runs right away and
    its exceptions surface at construction time.  Afterwards one item is
    always kept buffered ahead of what has been handed out.  close() is
    forwarded to *real* when it has a ``close`` attribute.
    """
    __slots__ = ["bk", "bki", "_next"]
    end = object()  # sentinel: the underlying iterator is exhausted

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = self._pull()

    def _pull(self):
        # Next item of the underlying iterator, or the end sentinel.
        try:
            return next(self.bki)
        except StopIteration:
            return self.end

    def __iter__(self):
        return self

    def __next__(self):
        item = self._next
        if item is self.end:
            raise StopIteration()
        self._next = self._pull()
        return item

    def close(self):
        if not hasattr(self.bk, "close"):
            return
        self.bk.close()
125
def pregen(callable):
    """Decorate an iterable-returning function so its result is pre-started.

    Each call returns a preiter around what *callable* returned, which
    fetches the first item immediately.  ``__wrapped__`` refers to
    *callable*.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)
    wrapper.__wrapped__ = callable
    return wrapper
131
def stringwrap(charset):
    """Decorator factory: encode every string yielded by a function.

    The decorated function's items are encoded with *charset*, and the
    encoding generator is passed through pregen(), so it is started as soon
    as the function is called.  ``__wrapped__`` points at the undecorated
    function.
    """
    def dec(callable):
        def encoder(*args, **kwargs):
            for chunk in callable(*args, **kwargs):
                yield chunk.encode(charset)
        wrapper = pregen(encoder)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
141
class sessiondata(object):
    """Base class for objects stored once per session, keyed by class."""

    @classmethod
    def get(cls, req, create=True):
        """Return this class's object from the request's session.

        The lookup happens under the session's lock.  A missing object is
        created as ``cls(req, sess)`` and stored, unless *create* is false,
        in which case None is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                if not create:
                    return None
                found = sess[cls] = cls(req, sess)
            return found

    @classmethod
    def sessdb(cls):
        # Session database used by get(); looked up through cls, so a
        # subclass can supply a different one.
        return session.default.val
159
class autodirty(sessiondata):
    """Session data that flags itself dirty whenever an attribute changes.

    The flag lives in ``_is_dirty``, written straight into the instance
    ``__dict__`` so maintaining it does not go through __setattr__.
    Tracking starts only once the flag exists (get() installs it), so
    attribute assignments made before that, e.g. in __init__, are not
    counted.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return this session's instance, installing the dirty flag.

        *create* is forwarded to sessiondata.get(); when it is false and no
        instance exists yet, None is returned.
        """
        ret = super().get(req, create=create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag (presumably called once the state is saved)."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute changed since the last sessfrozen()."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
183
class manudirty(object):
    """Mixin for session data whose dirtiness is flagged by hand.

    Call dirty() after changing state; sessdirty() reports the flag and
    sessfrozen() clears it.  The flag is stored in the name-mangled
    attribute ``_manudirty__dirty``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__setflag(False)

    def __setflag(self, value):
        # Single writer of the private (name-mangled) flag.
        self.__dirty = value

    def dirty(self):
        self.__setflag(True)

    def sessfrozen(self):
        self.__setflag(False)

    def sessdirty(self):
        return self.__dirty
197
class specslot(object):
    """Data descriptor for one declared slot of a specdirty instance.

    The value lives at position *idx* of the instance's ``__sslots__``
    list; an unset position holds the ``specslot.unbound`` sentinel.  *nm*
    is the slot's name (used in error messages) and *dirty* says whether
    assigning through the descriptor calls ``ins.dirty()``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class (e.g. hasattr(cls, name), introspection):
            # return the descriptor itself rather than indexing into the
            # slot list of a nonexistent instance.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # NOTE: unlike __set__, deletion marks the instance dirty regardless
        # of self.dirty.
        ins.dirty()
226
class specclass(type):
    """Metaclass that installs specslot descriptors for declared slots.

    Walking the whole MRO, it gathers every name in each class's own
    ``__saveslots__`` and every name in its ``__dirtyslots__`` (which
    defaults, per class, to that class's ``__saveslots__``).  It stores
    the save slots as ``__sslots_l__`` and the union of save and dirty
    slots as ``__sslots_a__``, then binds each name in ``__sslots_a__`` to
    a specslot whose index is the name's position in that list.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        allslots = list(saved | dirtying)
        self.__sslots_a__ = allslots
        for idx, slotname in enumerate(allslots):
            setattr(self, slotname, specslot(slotname, idx, slotname in dirtying))
240
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose saved attributes are declared as spec slots.

    Subclasses name their attributes in ``__saveslots__`` (and optionally
    ``__dirtyslots__``); the specclass metaclass turns each into a specslot
    descriptor that keeps its value in the per-instance ``__sslots__``
    list.  Only those slots are included in the pickled state
    (__getstate__/__setstate__).
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]
    
    def __specinit__(self):
        """Hook for subclasses to set up initial slot values.

        Called from __new__ after every slot has been reset to unbound and
        before the dirty flag is cleared, so assignments made here do not
        leave the object dirty.
        """
        pass

    @staticmethod
    def __new__(cls, req, sess):
        # All setup happens in __new__ rather than __init__.  With pickle
        # protocol 2+, unpickling calls __new__ with the __getnewargs__
        # arguments (req is None then) but not __init__, so this also
        # prepares unpickled objects before __setstate__ runs.
        self = super().__new__(cls)
        self.session = sess
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        # Reset after __specinit__: its slot assignments may have called dirty().
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # Arguments for __new__ on unpickling; no request exists then.
        return (None, self.session)

    def dirty(self):
        """Mark the object as changed (called by dirtying specslots)."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the dirty flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether dirty() was called since the last sessfrozen()."""
        return self._is_dirty

    def __getstate__(self):
        """Return a dict mapping each slot name to a (bound, value) pair.

        Covers every name in ``__sslots_a__``; unbound slots are recorded
        as (False, None).  Values are read from the raw slot list, not
        through the descriptors.
        """
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        """Restore slot values from a __getstate__ dict.

        Slots missing from *st* are left unbound; keys that match no slot
        are ignored.  Note that *st* is modified: matching entries are
        popped from it.
        """
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            if not bound:
                ss[i] = specslot.unbound
            else:
                ss[i] = val
285
def datecheck(req, mtime):
    """Answer conditional GETs based on a resource modification time.

    If the request carries an If-Modified-Since header that parses to a
    time at or after *mtime* (rounded down with math.floor), raise
    resp.unmodified().  Otherwise set the Last-Modified response header
    from *mtime*.  *mtime* is presumably a POSIX timestamp in seconds;
    confirm against proto.httpdate.
    """
    if "If-Modified-Since" in req.ihead:
        clientstamp = proto.phttpdate(req.ihead["If-Modified-Since"])
        notnewer = clientstamp is not None and clientstamp >= math.floor(mtime)
        if notnewer:
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)