08120cbda1bb973237c621fb03859c01b2a4e826
[wrw.git] / wrw / session.py
1 import threading, time, pickle, random, os, io
2 from . import cookie, env, proto
3
# Names exported by ``from ... import *``: the session database class and
# the module-level convenience accessor.
__all__ = [
    "db",
    "get",
]
5
def gennonce(length):
    """Return *length* random bytes drawn from ``os.urandom``."""
    nonce = os.urandom(length)
    return nonce
8
class itempickler(pickle.Pickler):
    """Pickler for individual session items.

    A reference to a ``session`` object inside an item is not serialized;
    it is replaced by the persistent id ``("session", <session id>)``.
    """
    def persistent_id(self, obj):
        # Anything that is not a session is pickled normally.
        if not isinstance(obj, session):
            return None
        return ("session", obj.id)
14
class itemunpickler(pickle.Unpickler):
    """Unpickler for individual session items.

    Resolves the persistent id ``("session", <id>)`` to the session passed
    as the keyword-only *session* argument, provided the ids match; any
    other persistent id is rejected with ``pickle.UnpicklingError``.
    """
    def __init__(self, *args, session, **kwargs):
        super().__init__(*args, **kwargs)
        self.session = session

    def persistent_load(self, oid):
        kind = oid[0]
        owner = self.session
        if kind == "session" and oid[1] == owner.id:
            return owner
        raise pickle.UnpicklingError("unexpected persistent id: " + repr(oid))
26
class session(object):
    """A single client session: a dict-like item store plus bookkeeping.

    Values that define ``sessdirty()`` are tracked in ``dctl`` and consulted
    by dirty(); assigning or deleting any other value sets ``dirtyp``.
    """
    def __init__(self, lock, expire=86400 * 7):
        self.id = proto.enhex(gennonce(16))
        self.dict = {}
        self.lock = lock
        self.ctime = self.atime = self.mtime = int(time.time())
        self.expire = expire    # lifetime in seconds (default: one week)
        self.dctl = set()       # values that report their own dirtiness
        self.dirtyp = False     # set when a plain value is assigned/deleted

    def dirty(self):
        """Return True if the session has changes not yet frozen."""
        for d in self.dctl:
            if d.sessdirty():
                return True
        return self.dirtyp

    def frozen(self):
        """Mark the session and every tracked value as saved."""
        for d in self.dctl:
            d.sessfrozen()
        self.dirtyp = False

    def _untrack(self, obj):
        """Stop consulting *obj* for dirtiness once no key still refers to it."""
        # Only values with sessdirty are ever added to dctl; testing set
        # membership of an arbitrary (possibly unhashable) value would raise.
        if not hasattr(obj, "sessdirty"):
            return
        for v in self.dict.values():
            if v is obj:
                return
        self.dctl.discard(obj)

    def __getitem__(self, key):
        return self.dict[key]

    def get(self, key, default=None):
        return self.dict.get(key, default)

    def __setitem__(self, key, value):
        replaced = key in self.dict
        old = self.dict.get(key)
        self.dict[key] = value
        if replaced and old is not value:
            # An overwritten tracked value must not keep affecting dirty().
            self._untrack(old)
        if hasattr(value, "sessdirty"):
            self.dctl.add(value)
        else:
            self.dirtyp = True

    def __delitem__(self, key):
        old = self.dict.pop(key)
        self._untrack(old)
        self.dirtyp = True

    def __contains__(self, key):
        return key in self.dict

    def __getstate__(self):
        """Serialize each item separately so one bad item can't spoil the rest."""
        items = []
        for k, v in self.dict.items():
            buf = io.BytesIO()
            itempickler(buf).dump((k, v))
            items.append(buf.getvalue())
        ret = {"data": items}
        for k in ["id", "ctime", "mtime", "atime", "expire"]:
            ret[k] = getattr(self, k)
        return ret

    def __setstate__(self, st):
        # Defaults first, so every format ends up with a usable object.
        self.dict = {}
        self.dctl = set()
        self.dirtyp = False
        if isinstance(st, list):
            # Only for the old session format; remove me in due time.
            for k, v in st:
                self.__dict__[k] = v
        else:
            for k in ["id", "ctime", "mtime", "atime", "expire"]:
                setattr(self, k, st[k])
            for item in st["data"]:
                try:
                    k, v = itemunpickler(io.BytesIO(item), session=self).load()
                except Exception:
                    # Items that fail to unpickle are dropped rather than
                    # failing the whole session.
                    continue
                self.dict[k] = v
                if hasattr(v, "sessdirty"):
                    self.dctl.add(v)
        # The proper lock is set by the thawer

    def __repr__(self):
        return "<session %s>" % self.id
104
class db(object):
    """Registry of live sessions, optionally persisted to a back store.

    *backdb*, when given, is a mapping from session id to pickled session
    data; *cookiename* and *path* configure the session cookie.
    """
    def __init__(self, backdb=None, cookiename="wrwsess", path="/"):
        # sessid -> [lock, state]; state is None while a thaw is in
        # progress, a session object once live, or "retired" once removed.
        self.live = {}
        self.cookiename = cookiename
        self.path = path
        self.lock = threading.Lock()    # guards self.live and self.cthread
        self.cthread = None
        self.freezetime = 3600          # idle seconds before freeze/evict
        self.backdb = backdb

    def clean(self):
        """Freeze idle sessions and drop them from the live table.

        A session idle for more than ``freezetime`` seconds is frozen if
        dirty and then removed.  If freezing fails it is only removed once
        it has also passed its own expiry time.
        """
        now = int(time.time())
        with self.lock:
            clist = list(self.live)
        for sessid in clist:
            with self.lock:
                entry = self.live.get(sessid)
            if entry is None:
                continue
            with entry[0]:
                sess = entry[1]
                if sess is None or sess == "retired":
                    continue
                if sess.atime + self.freezetime >= now:
                    continue
                try:
                    if sess.dirty():
                        self.freeze(sess)
                except Exception:
                    # Could not persist: keep it in memory until it expires.
                    if sess.atime + sess.expire >= now:
                        continue
                entry[1] = "retired"
                with self.lock:
                    del self.live[sessid]

    def cleanloop(self):
        """Background loop: clean every 300 s until no session is live."""
        me = threading.current_thread()
        try:
            while True:
                time.sleep(300)
                self.clean()
                with self.lock:
                    if not self.live:
                        # Deregister under the same lock checkclean() uses,
                        # so a session added right now starts a new cleaner.
                        self.cthread = None
                        return
        finally:
            with self.lock:
                if self.cthread is me:
                    self.cthread = None

    def _fetch(self, sessid):
        """Return live session *sessid*, thawing it from backdb if needed.

        Raises KeyError if the session cannot be thawed or has expired.
        """
        while True:
            now = int(time.time())
            with self.lock:
                if sessid in self.live:
                    entry = self.live[sessid]
                else:
                    # Placeholder: concurrent fetches of the same id wait on
                    # entry[0] instead of thawing it twice.
                    entry = self.live[sessid] = [threading.RLock(), None]
            with entry[0]:
                if isinstance(entry[1], session):
                    entry[1].atime = now
                    return entry[1]
                elif entry[1] == "retired":
                    # Removed from live meanwhile; retry with a fresh entry.
                    continue
                elif entry[1] is None:
                    try:
                        thawed = self.thaw(sessid)
                        if thawed.atime + thawed.expire < now:
                            raise KeyError(sessid)
                        thawed.lock = entry[0]
                        thawed.atime = now
                        entry[1] = thawed
                        return thawed
                    finally:
                        # On any failure, retire the placeholder.
                        if entry[1] is None:
                            entry[1] = "retired"
                            with self.lock:
                                del self.live[sessid]
                else:
                    raise Exception("Illegal session entry: " + repr(entry[1]))

    def checkclean(self):
        """Start the background cleaner thread unless one is registered."""
        with self.lock:
            if self.cthread is None:
                self.cthread = threading.Thread(target=self.cleanloop,
                                                daemon=True)
                self.cthread.start()

    def mksession(self, req):
        """Create a brand-new session with its own reentrant lock."""
        return session(threading.RLock())

    def mkcookie(self, req, sess):
        """Add the session-id cookie for *sess* to *req*."""
        cookie.add(req, self.cookiename, sess.id,
                   path=self.path,
                   expires=cookie.cdate(time.time() + sess.expire))

    def fetch(self, req):
        """Return the session for *req*, creating a new one if needed.

        A hook registered with ``req.oncommit`` saves the session if it is
        dirty; for a new session it also sets the cookie and registers it
        as live.
        """
        sessid = cookie.get(req, self.cookiename)
        new = False
        try:
            if sessid is None:
                raise KeyError()
            sess = self._fetch(sessid)
        except KeyError:
            sess = self.mksession(req)
            new = True

        def ckfreeze(req):
            if not sess.dirty():
                return
            if new:
                self.mkcookie(req, sess)
                with self.lock:
                    self.live[sess.id] = [sess.lock, sess]
            try:
                self.freeze(sess)
            except Exception:
                # Best effort: the session stays live in memory even if it
                # cannot be persisted (e.g. no backdb configured).
                pass
            self.checkclean()
        req.oncommit(ckfreeze)
        return sess

    def thaw(self, sessid):
        """Load and unpickle session *sessid* from backdb.

        Raises KeyError if there is no backdb, the data cannot be unpickled,
        or (assuming backdb behaves like a dict) the id is absent.
        """
        if self.backdb is None:
            raise KeyError(sessid)
        data = self.backdb[sessid]
        try:
            # SECURITY: unpickling executes arbitrary code.  sessid comes
            # from a client cookie, so backdb must only ever return data
            # that this application wrote itself.
            return pickle.loads(data)
        except Exception:
            raise KeyError(sessid)

    def freeze(self, sess):
        """Pickle *sess* into backdb and mark it clean.

        Raises TypeError if no backdb is configured.
        """
        if self.backdb is None:
            raise TypeError("no session back store configured")
        with sess.lock:
            data = pickle.dumps(sess, -1)
        self.backdb[sess.id] = data
        sess.frozen()

    def get(self, req):
        """Return the session for *req* via ``req.item(self.fetch)``."""
        return req.item(self.fetch)
249
class dirback(object):
    """Directory-backed store: one file per key, holding a ``bytes`` value.

    Used as a session back store, where keys are session ids and values are
    pickled session data.
    """
    def __init__(self, path):
        self.path = path

    def _keypath(self, key):
        """Return the file path for *key*; raise KeyError unless it is a bare name."""
        # Keys can come straight from a client cookie, so refuse anything
        # that could escape the directory (separators, absolute paths, "..").
        if (not key or key in (".", "..") or "\0" in key
                or os.sep in key or (os.altsep and os.altsep in key)):
            raise KeyError(key)
        return os.path.join(self.path, key)

    def __getitem__(self, key):
        """Return the bytes stored under *key*; KeyError if absent/unreadable."""
        try:
            with open(self._keypath(key), "rb") as inf:
                return inf.read()
        except OSError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """Store the bytes *value* under *key*, replacing any previous value."""
        import tempfile
        target = self._keypath(key)
        # Owner-only permissions: the default location lives under /tmp.
        os.makedirs(self.path, mode=0o700, exist_ok=True)
        # Write a temp file and rename it over the target so concurrent
        # readers never observe a partially written value.
        fd, tmpname = tempfile.mkstemp(dir=self.path, prefix=".tmp-")
        try:
            with os.fdopen(fd, "wb") as out:
                out.write(value)
            os.replace(tmpname, target)
        except BaseException:
            try:
                os.unlink(tmpname)
            except OSError:
                pass
            raise
266
# Module-wide default session database, persisted via dirback under the
# per-user directory /tmp/wrwsess-<uid> (os.getuid is POSIX-only).  It is
# wrapped in env.var; see the env module for how it can be overridden.
default = env.var(
    db(backdb=dirback(os.path.join("/tmp", "wrwsess-" + str(os.getuid()))))
)
268
def get(req):
    """Return the session for *req* from the module's default session db."""
    sessdb = default.val
    return sessdb.get(req)