55b4ee276b6d6c98b211ff11e4d62b9da667a6df
[wrw.git] / wrw / util.py
1 import inspect
2 import req, dispatch, session, form
3
def wsgiwrap(callable):
    """Adapt a request handler into a WSGI-style callable.

    The returned function accepts ``(env, startreq)`` -- presumably the WSGI
    environ dict and start_response callable -- and forwards both, together
    with ``callable``, to ``dispatch.handleenv``, returning its result.
    """
    handler = callable

    def wrapper(env, startreq):
        return dispatch.handleenv(env, startreq, handler)

    return wrapper
8
def formparams(callable):
    """Decorator: call ``callable`` with the request's form fields as keywords.

    The returned wrapper takes a single ``req``.  Every field produced by
    ``form.formdata(req)`` becomes a keyword argument, and ``req`` itself is
    always passed as ``req`` (overriding a form field of the same name).
    Unless ``callable`` accepts ``**kwargs``, arguments it does not name are
    dropped before the call.
    """
    # Introspect once, at decoration time, rather than on every request.
    # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
    # when available and fall back to getargspec on Python 2.
    if hasattr(inspect, "getfullargspec"):
        spec = inspect.getfullargspec(callable)
        takesany = spec.varkw is not None
        # Keyword-only parameters can also be filled from form fields.
        accepted = set(spec.args) | set(spec.kwonlyargs)
    else:
        spec = inspect.getargspec(callable)
        takesany = spec.keywords is not None
        accepted = set(spec.args)

    def wrapper(req):
        args = dict(form.formdata(req).items())
        args["req"] = req
        if not takesany:
            args = dict((k, v) for k, v in args.items() if k in accepted)
        return callable(**args)
    return wrapper
21
def persession(data = None):
    """Decorator factory: keep one handler object per session.

    The decorated ``callable`` is instantiated at most once per session and
    stored in the session under ``callable`` itself; every request is then
    forwarded to that object's ``handle(req)``.  When ``data`` is given, a
    ``data()`` instance is also created and stored under ``data`` if the
    session does not already hold one.
    """
    def dec(callable):
        def wrapper(req):
            state = session.get(req)
            if callable not in state:
                if data is not None:
                    if data not in state:
                        state[data] = data()
                    # NOTE(review): the handler is constructed with the
                    # ``data`` factory/key itself, not the state[data]
                    # instance created above -- confirm this is intended.
                    state[callable] = callable(data)
                else:
                    state[callable] = callable()
            return state[callable].handle(req)
        return wrapper
    return dec
36
class preiter(object):
    """Iterator wrapper that fetches one item ahead, starting at construction.

    The first item of ``real`` is pulled inside ``__init__``, so wrapping a
    generator runs it up to its first ``yield`` immediately.  Items are then
    handed out one call later than they were fetched.  ``close()`` is
    forwarded to the wrapped object when it has one.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the wrapped iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead; the placeholder None returned here is discarded.
        self.next()

    def __iter__(self):
        return self

    def next(self):
        """Return the buffered item and fetch the following one.

        Raises StopIteration once the wrapped iterator is exhausted.
        """
        if self._next is self.end:
            raise StopIteration()
        ret = self._next
        try:
            self._next = next(self.bki)
        except StopIteration:
            self._next = self.end
        return ret

    # Python 3 iterator protocol name; without it Python 3 rejects this
    # object as a non-iterator (Python 2 uses next()).
    __next__ = next

    def close(self):
        """Close the wrapped object if it supports close()."""
        if hasattr(self.bk, "close"):
            self.bk.close()
62
def pregen(callable):
    """Decorator: wrap whatever ``callable`` returns in a ``preiter``.

    Since ``preiter`` fetches the first item in its constructor, the returned
    iterable starts being consumed as soon as the decorated function is called.
    """
    def wrapper(*args, **kwargs):
        result = callable(*args, **kwargs)
        return preiter(result)
    return wrapper
67
class sessiondata(object):
    """Base for per-session objects stored in the session under their class.

    ``get`` returns the instance of ``cls`` held by the request's session,
    constructing it as ``cls(req, sess)`` the first time it is needed.
    """

    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance from the session belonging to ``req``.

        If the session holds none yet, return None when ``create`` is false;
        otherwise build one with ``cls(req, sess)`` and store it under ``cls``.
        Lookup and creation both run while holding ``sess.lock``.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                if not create:
                    return None
                found = sess[cls] = cls(req, sess)
            return found

    @classmethod
    def sessdb(cls):
        """Return the session database used by ``get`` (``session.default.val``)."""
        db = session.default.val
        return db
85
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute set or delete.

    The flag lives in the instance ``__dict__`` as ``_is_dirty`` and is
    written there directly, so updating it does not itself count as a change.
    Tracking starts only once ``get`` has installed the flag; attributes set
    while the instance is being constructed therefore do not mark it dirty.
    """

    @classmethod
    def get(cls, req, create = True):
        """Fetch (or create) the session instance and ensure the flag exists.

        ``create`` has the same meaning as in ``sessiondata.get``; when it is
        false and no instance exists, None is returned.
        """
        ret = super(autodirty, cls).get(req, create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
109
class manudirty(object):
    """Mixin for session data whose modifications are flagged explicitly.

    Call ``dirty()`` after changing the object.  ``sessdirty()`` reports
    whether that has happened since construction or the last
    ``sessfrozen()``, and ``sessfrozen()`` clears the flag.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: pass all arguments along the MRO first.
        super(manudirty, self).__init__(*args, **kwargs)
        self.__dirty = False

    def dirty(self):
        """Flag the object as modified."""
        self.__dirty = True

    def sessdirty(self):
        """Return True if the object is currently flagged as modified."""
        return self.__dirty

    def sessfrozen(self):
        """Clear the modified flag."""
        self.__dirty = False
123
class specslot(object):
    """Data descriptor for one saved attribute of a ``specdirty`` subclass.

    Values are not kept in an instance ``__dict__``. They live positionally
    in the instance's ``__sslots__`` list, at index ``idx``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the slot list while the attribute has no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        # nm: attribute name (used in error messages); idx: position in the
        # instance's __sslots__ list; dirty: whether assignment marks the
        # owning instance dirty.
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        """Return the backing ``__sslots__`` list of ``ins``."""
        # Avoid calling __getattribute__: use the slot descriptor defined on
        # specdirty directly instead of ordinary attribute lookup.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        """Return the slot's value, or this descriptor on class-level access.

        Raises AttributeError if the slot is unbound.
        """
        if ins is None:
            # Accessed on the class itself (getattr, inspect, help, ...):
            # there is no per-instance list to index.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        """Store ``val``; for dirty-tracking slots also call ``ins.dirty()``."""
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        """Unbind the slot and call ``ins.dirty()``."""
        # NOTE(review): unlike __set__, this marks the instance dirty even
        # when self.dirty is False -- confirm the asymmetry is intended.
        self.slist(ins)[self.idx] = specslot.unbound
        ins.dirty()
152
class specclass(type):
    """Metaclass that turns declared save-slots into ``specslot`` descriptors.

    It walks the MRO and collects names from each class's own
    ``__saveslots__``. Dirty-tracking names come from ``__dirtyslots__``, which
    defaults to that class's ``__saveslots__``. It records
    ``__sslots_l__`` (saved names) and ``__sslots_a__`` (saved plus dirty
    names). Each name in ``__sslots_a__`` gets a ``specslot`` indexed by its
    position in that list.
    """
    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = klass.__dict__.get("__saveslots__", ())
            saved.update(own)
            # Without an explicit __dirtyslots__, all of this class's saved
            # slots are dirty-tracking.
            dirtying.update(klass.__dict__.get("__dirtyslots__", own))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for idx, slotname in enumerate(self.__sslots_a__):
            tracked = slotname in dirtying
            setattr(self, slotname, specslot(slotname, idx, tracked))
166
class specdirty(sessiondata):
    """Session data whose persisted attributes are declared per class.

    Subclasses list attribute names in ``__saveslots__`` and optionally
    ``__dirtyslots__``. The ``specclass`` metaclass replaces each with a
    ``specslot`` descriptor whose value lives in the per-instance
    ``__sslots__`` list. ``__getstate__`` saves only those declared slots;
    the session object travels via ``__getnewargs__``.
    """
    # Python 2 metaclass hook (Python 3 ignores this class attribute).
    __metaclass__ = specclass
    # "__sslots__" ends in a double underscore, so it is exempt from
    # private-name mangling and can be reached as specdirty.__sslots__.
    __slots__ = ["session", "__sslots__", "_is_dirty"]
    
    def __specinit__(self):
        """Hook for subclasses to initialise slots on a fresh instance.

        Called from ``__new__`` before the dirty flag is cleared. Values
        assigned here therefore do not leave the instance marked dirty.
        """
        pass

    @staticmethod
    def __new__(cls, req, sess):
        # Signature matches the cls(req, sess) call made by sessiondata.get;
        # req itself is not used here.
        self = super(specdirty, cls).__new__(cls)
        self.session = sess
        # One entry per declared slot (see specclass), all initially unbound.
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        # Reset after __specinit__ so initial values do not count as changes.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # With pickle protocol 2+, unpickling calls __new__(cls, None, session).
        # That allocates the slot list before __setstate__ runs.
        return (None, self.session)

    def dirty(self):
        """Mark the instance as modified (called by dirty-tracking specslots)."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the modified flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return the modified flag."""
        return self._is_dirty

    def __getstate__(self):
        """Return ``{slot name: (bound, value)}`` for every declared slot.

        Unbound slots are recorded as ``(False, None)``, so the
        ``specslot.unbound`` sentinel itself is never pickled.
        """
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        """Restore slot values from a ``__getstate__`` mapping.

        Declared slots missing from ``st`` become unbound. Keys for names
        that are no longer declared are ignored. ``st`` is consumed (entries
        are popped) as a side effect.
        """
        # NOTE(review): relies on __new__ having already allocated __sslots__
        # (see __getnewargs__); an unpickling path that skips specdirty.__new__
        # would fail in slist -- confirm the pickle protocol in use.
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            if not bound:
                ss[i] = specslot.unbound
            else:
                ss[i] = val