Remove a couple of debug messages.
[wrw.git] / wrw / session.py
... / ...
CommitLineData
1import threading, time, pickle, random, os, io
2from . import cookie, env, proto
3
# Public API: the session database class and the per-request accessor.
__all__ = [
    "db",
    "get",
]
5
def gennonce(length):
    """Return `length` random bytes obtained from os.urandom()."""
    raw = os.urandom(length)
    return raw
8
class itempickler(pickle.Pickler):
    """Pickler for a single session item.

    A `session` object reachable from the item is not pickled recursively;
    it is recorded out-of-band as the persistent id ``("session", <id>)``,
    which `itemunpickler` resolves back to the owning session.
    """

    def persistent_id(self, obj):
        # None means "pickle this object normally".
        if not isinstance(obj, session):
            return None
        return ("session", obj.id)
14
class itemunpickler(pickle.Unpickler):
    """Unpickler counterpart of `itempickler`.

    The keyword-only `session` argument is the session being restored. A
    persistent id ``("session", <id>)`` naming that session's id resolves
    to it; any other persistent id raises pickle.UnpicklingError.
    """

    def __init__(self, *args, session, **kwargs):
        super().__init__(*args, **kwargs)
        self.session = session

    def persistent_load(self, oid):
        if oid[0] == "session" and oid[1] == self.session.id:
            return self.session
        raise pickle.UnpicklingError("unexpected persistent id: " + repr(oid))
26
class session(object):
    """A user session: a dict-like store of values plus timing metadata.

    Values that have a ``sessdirty`` attribute are tracked in ``dctl`` and
    are asked individually (``sessdirty()`` / ``sessfrozen()``) whether they
    have changed. Any other assignment, and every deletion, marks the session
    itself dirty through ``dirtyp``.
    """

    def __init__(self, lock, expire=86400 * 7):
        self.id = proto.enhex(gennonce(16))
        self.dict = {}
        self.lock = lock
        self.ctime = self.atime = self.mtime = int(time.time())
        # Seconds after the last access (atime) before the session is
        # considered expired; the default is one week.
        self.expire = expire
        self.dctl = set()     # tracked values that report their own dirtiness
        self.dirtyp = False   # set by plain assignments and by deletions

    def dirty(self):
        """Return True if the session has changes not yet frozen."""
        for d in self.dctl:
            if d.sessdirty():
                return True
        return self.dirtyp

    def frozen(self):
        """Mark the session and every tracked value as saved."""
        for d in self.dctl:
            d.sessfrozen()
        self.dirtyp = False

    def __getitem__(self, key):
        return self.dict[key]

    def get(self, key, default=None):
        return self.dict.get(key, default)

    def __setitem__(self, key, value):
        self.dict[key] = value
        if hasattr(value, "sessdirty"):
            # The value reports its own changes through dirty().
            self.dctl.add(value)
        else:
            self.dirtyp = True

    def __delitem__(self, key):
        old = self.dict.pop(key)
        try:
            self.dctl.discard(old)
        except TypeError:
            # Unhashable values (lists, dicts, ...) can never have been
            # added to the dctl set; testing them for membership would
            # raise, which used to make such keys impossible to delete.
            pass
        self.dirtyp = True

    def __contains__(self, key):
        return key in self.dict

    def __getstate__(self):
        # Each item is pickled separately so that an item which later fails
        # to unpickle is dropped on its own (see __setstate__) instead of
        # losing the whole session.
        items = []
        for k, v in self.dict.items():
            buf = io.BytesIO()
            itempickler(buf).dump((k, v))
            items.append(buf.getvalue())
        ret = {"data": items}
        for k in ["id", "ctime", "mtime", "atime", "expire"]:
            ret[k] = getattr(self, k)
        return ret

    def __setstate__(self, st):
        # Start from a consistent empty state for both formats, so that
        # dirty()/frozen() work even when the old format lacks these fields.
        self.dict = {}
        self.dctl = set()
        self.dirtyp = False
        if isinstance(st, list):
            # Only for the old session format; remove me in due time.
            for k, v in st:
                self.__dict__[k] = v
        else:
            for k in ["id", "ctime", "mtime", "atime", "expire"]:
                setattr(self, k, st[k])
            for item in st["data"]:
                try:
                    k, v = itemunpickler(io.BytesIO(item), session=self).load()
                except Exception:
                    # Best effort: skip items that can no longer be loaded.
                    continue
                self.dict[k] = v
                if hasattr(v, "sessdirty"):
                    self.dctl.add(v)
        # The proper lock is set by the thawer

    def __repr__(self):
        return "<session %s>" % self.id
104
class db(object):
    """Cache of live sessions, optionally backed by a persistent store.

    `live` maps a session id to a two-element list ``[lock, state]``. The
    state is a `session`, None (a placeholder that _fetch() is about to thaw
    from `backdb`), or the string "retired" (the entry has been removed from
    `live`).
    """

    def __init__(self, backdb=None, cookiename="wrwsess", path="/"):
        self.live = {}
        self.cookiename = cookiename
        self.path = path
        self.lock = threading.Lock()  # guards `live` and `cthread`
        self.cthread = None
        # Seconds of inactivity after which clean() freezes and evicts a session.
        self.freezetime = 3600
        self.backdb = backdb

    def _unlink(self, sessid, entry):
        """Remove `entry` from `live`, unless it has already been replaced."""
        with self.lock:
            if self.live.get(sessid) is entry:
                del self.live[sessid]

    def clean(self):
        """Freeze and evict sessions idle for longer than `freezetime`.

        A session that cannot be frozen (e.g. there is no `backdb`) stays
        live until it has also exceeded its own `expire`, then is dropped.
        """
        now = int(time.time())
        with self.lock:
            clist = list(self.live.keys())
        for sessid in clist:
            with self.lock:
                entry = self.live.get(sessid)
            if entry is None:
                continue  # removed since the snapshot was taken
            with entry[0]:
                sess = entry[1]
                if sess is None or sess == "retired":
                    continue
                if sess.atime + self.freezetime >= now:
                    continue  # recently used
                try:
                    if sess.dirty():
                        self.freeze(sess)
                except Exception:
                    # Could not persist it: keep it in memory until it expires.
                    if sess.atime + sess.expire >= now:
                        continue
                # Either clean, successfully frozen, or expired: evict it.
                entry[1] = "retired"
                self._unlink(sessid, entry)

    def cleanloop(self):
        """Cleaner-thread body: run clean() every 300 s until `live` is empty."""
        me = threading.current_thread()
        try:
            while True:
                time.sleep(300)
                self.clean()
                with self.lock:
                    # Decide to exit and clear cthread atomically. A session
                    # inserted by fetch() before this check keeps the loop
                    # going; one inserted after it finds cthread cleared, so
                    # checkclean() starts a new cleaner.
                    if not self.live:
                        if self.cthread is me:
                            self.cthread = None
                        return
        finally:
            # On any other exit (e.g. clean() raised), allow a new cleaner to
            # be started without clobbering a newer one.
            with self.lock:
                if self.cthread is me:
                    self.cthread = None

    def _fetch(self, sessid):
        """Return the live session `sessid`, thawing it from backdb if needed.

        Raises KeyError if the session cannot be thawed or has expired.
        """
        while True:
            now = int(time.time())
            with self.lock:
                entry = self.live.get(sessid)
                if entry is None:
                    # Placeholder; it is thawed below under its own lock.
                    entry = self.live[sessid] = [threading.RLock(), None]
            with entry[0]:
                state = entry[1]
                if isinstance(state, session):
                    state.atime = now
                    return state
                if state == "retired":
                    # Dropped from live after we looked it up; look again.
                    continue
                if state is not None:
                    raise Exception("Illegal session entry: " + repr(state))
                try:
                    thawed = self.thaw(sessid)
                    if thawed.atime + thawed.expire < now:
                        raise KeyError()
                    thawed.lock = entry[0]
                    thawed.atime = now
                    entry[1] = thawed
                    return thawed
                finally:
                    if entry[1] is None:
                        # Thawing failed: drop the placeholder.
                        entry[1] = "retired"
                        self._unlink(sessid, entry)

    def checkclean(self):
        """Start the background cleaner thread unless one is already running."""
        with self.lock:
            if self.cthread is None:
                # Daemon, so the cleaner never keeps the process alive.
                self.cthread = threading.Thread(target=self.cleanloop, daemon=True)
                self.cthread.start()

    def mksession(self, req):
        """Create a new session; `req` is unused here (presumably a subclass hook)."""
        return session(threading.RLock())

    def mkcookie(self, req, sess):
        """Set the session cookie on `req`, expiring `sess.expire` seconds from now."""
        cookie.add(req, self.cookiename, sess.id,
                   path=self.path,
                   expires=cookie.cdate(time.time() + sess.expire))

    def fetch(self, req):
        """Return the session named by `req`'s cookie, or a new one.

        Registers a commit hook on `req`. If the session was modified, the
        hook sets the cookie (new sessions only), makes the session live,
        tries to freeze it and makes sure the cleaner thread is running.
        """
        sessid = cookie.get(req, self.cookiename)
        new = False
        try:
            if sessid is None:
                raise KeyError()
            sess = self._fetch(sessid)
        except KeyError:
            sess = self.mksession(req)
            new = True

        def ckfreeze(req):
            if sess.dirty():
                if new:
                    self.mkcookie(req, sess)
                with self.lock:
                    self.live[sess.id] = [sess.lock, sess]
                try:
                    self.freeze(sess)
                except Exception:
                    # Best effort: without a working backdb the session only
                    # lives in memory until clean() expires it.
                    pass
                self.checkclean()
        req.oncommit(ckfreeze)
        return sess

    def thaw(self, sessid):
        """Load session `sessid` from `backdb`; raise KeyError if unavailable."""
        if self.backdb is None:
            raise KeyError()
        data = self.backdb[sessid]
        try:
            # SECURITY NOTE(review): pickle.loads can execute arbitrary code
            # from crafted data; this is only safe if nobody but this process
            # can write to backdb.
            return pickle.loads(data)
        except Exception as exc:
            raise KeyError(sessid) from exc

    def freeze(self, sess):
        """Store `sess` in `backdb` and mark it clean; TypeError if no backdb."""
        if self.backdb is None:
            raise TypeError()
        with sess.lock:
            data = pickle.dumps(sess, -1)
            self.backdb[sess.id] = data
            sess.frozen()

    def get(self, req):
        """Return the session for `req` via req.item(self.fetch).

        Presumably req.item memoizes it per request; confirm in the request class.
        """
        return req.item(self.fetch)
249
class dirback(object):
    """Session backing store that keeps one file per key under `path`.

    Values are bytes (db.freeze stores pickle.dumps output), so files are
    read and written in binary mode.
    """

    def __init__(self, path):
        self.path = path

    def _fname(self, key):
        """Map `key` to a file inside `self.path`, or raise KeyError.

        db.fetch passes the raw cookie value through db.thaw to __getitem__,
        so keys that could name a file outside the directory (separators,
        absolute paths, "..") or one of this class's hidden temporary files
        are refused.
        """
        if (not key or key.startswith(".") or "\0" in key
                or os.path.basename(key) != key):
            raise KeyError(key)
        return os.path.join(self.path, key)

    def __getitem__(self, key):
        fname = self._fname(key)
        try:
            with open(fname, "rb") as inf:
                return inf.read()
        except IOError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        import tempfile

        fname = self._fname(key)
        os.makedirs(self.path, exist_ok=True)
        # Write to a hidden temporary file and rename it into place, so that
        # a concurrent reader never sees a truncated or partial file.
        # (mkstemp creates the file readable by the owner only.)
        fd, tmpname = tempfile.mkstemp(dir=self.path, prefix=".")
        try:
            with os.fdopen(fd, "wb") as out:
                out.write(value)
            os.replace(tmpname, fname)
        except BaseException:
            try:
                os.unlink(tmpname)
            except OSError:
                pass
            raise
266
# Module-wide default session database, persisted in a per-uid directory
# under /tmp.
# NOTE(review): /tmp is shared between users; if another local user creates
# this directory first, they control the session pickles that db.thaw loads.
# Consider a private location or an ownership check.
default = env.var(
    db(
        backdb=dirback(
            os.path.join("/tmp", "wrwsess-" + str(os.getuid()))
        )
    )
)
268
def get(req):
    """Return the session for `req` from the db held in `default`."""
    sessdb = default.val
    return sessdb.get(req)