Merge branch 'master' into python2
[wrw.git] / wrw / session.py
1 import threading, time, pickle, random, os
2 import cookie, env
3
# Names exported by a star-import of this module.
__all__ = ["db", "get"]
5
def gennonce(length):
    """Return `length` random bytes obtained from ``os.urandom``."""
    nonce = os.urandom(length)
    return nonce
8
class session(object):
    """A single session: a key/value store that tracks unsaved changes.

    Values that have a ``sessdirty`` attribute are remembered in ``dctl``;
    they are polled through ``sessdirty()`` and notified through
    ``sessfrozen()``.  Storing any other value, or deleting a key, sets the
    session's own ``dirtyp`` flag instead.
    """

    def __init__(self, lock, expire=86400 * 7):
        """Create an empty session with a fresh random ID.

        lock   -- lock object kept as ``self.lock``; it is excluded from
                  the pickled state (see ``__getstate__``)
        expire -- kept as ``self.expire``; the default is 86400 * 7
        """
        # The "hex" codec is a Python 2 feature: 16 random bytes become a
        # 32-character hexadecimal string.
        self.id = gennonce(16).encode("hex")
        self.dict = {}
        self.lock = lock
        self.ctime = self.atime = self.mtime = int(time.time())
        self.expire = expire
        self.dctl = set()
        self.dirtyp = False

    def dirty(self):
        """Return True if any tracked value reports dirty or ``dirtyp`` is set."""
        for d in self.dctl:
            if d.sessdirty():
                return True
        return self.dirtyp

    def frozen(self):
        """Notify every tracked value via ``sessfrozen()`` and clear ``dirtyp``."""
        for d in self.dctl:
            d.sessfrozen()
        self.dirtyp = False

    def __getitem__(self, key):
        """Return the value stored under `key`; raises KeyError if absent."""
        return self.dict[key]

    def get(self, key, default=None):
        """Return the value stored under `key`, or `default` if absent."""
        return self.dict.get(key, default)

    def __setitem__(self, key, value):
        """Store `value` under `key`.

        A value with a ``sessdirty`` attribute is added to ``dctl`` and is
        expected to report its own changes; any other value marks the
        session dirty immediately.
        """
        self.dict[key] = value
        if hasattr(value, "sessdirty"):
            self.dctl.add(value)
        else:
            self.dirtyp = True

    def __delitem__(self, key):
        """Remove `key` (KeyError if absent) and mark the session dirty."""
        old = self.dict.pop(key)
        try:
            self.dctl.discard(old)
        except TypeError:
            # Unhashable values (lists, dicts, ...) can never have been put
            # into the dctl set, so there is nothing to untrack.  Without
            # this guard the membership test raised after the key had
            # already been popped, leaving the change unmarked.
            pass
        self.dirtyp = True

    def __contains__(self, key):
        """Return True if `key` is stored in the session."""
        return key in self.dict

    def __getstate__(self):
        """Return the instance attributes as (name, value) pairs, minus ``lock``."""
        ret = []
        for k, v in self.__dict__.items():
            if k == "lock":
                continue
            ret.append((k, v))
        return ret

    def __setstate__(self, st):
        """Restore attributes from the (name, value) pairs in `st`."""
        for k, v in st:
            self.__dict__[k] = v
        # The proper lock is set by the thawer

    def __repr__(self):
        return "<session %s>" % self.id
66
class db(object):
    """In-memory session registry with an optional persistent backing store.

    ``live`` maps session IDs to two-element lists ``[lock, state]``, where
    ``state`` is a `session` object, ``None`` (entry created but nothing
    thawed into it yet) or the string ``"retired"`` (the entry has been
    dropped from ``live``; whoever still holds it must look the ID up
    again).  ``self.lock`` guards ``live`` and ``cthread``; the per-entry
    lock guards the entry's state.
    """

    def __init__(self, backdb=None, cookiename="wrwsess", path="/"):
        """Set up an empty registry.

        backdb     -- mapping-like store indexed by session ID (used with
                      ``backdb[id]`` and ``backdb[id] = data``), or None
                      for no persistence
        cookiename -- name of the cookie carrying the session ID
        path       -- path passed on to ``cookie.add``
        """
        self.live = {}
        self.cookiename = cookiename
        self.path = path
        self.lock = threading.Lock()
        self.cthread = None
        self.freezetime = 3600
        self.backdb = backdb

    def clean(self):
        """Drop live sessions not accessed for more than ``freezetime``.

        A dirty session is frozen first.  If freezing fails, the session
        is kept in memory until it has passed its own ``expire`` time.
        """
        now = int(time.time())
        with self.lock:
            # Take a real copy of the IDs: the table may change as soon as
            # the lock is released.
            clist = list(self.live)
        for sessid in clist:
            with self.lock:
                try:
                    entry = self.live[sessid]
                except KeyError:
                    continue
            with entry[0]:
                rm = False
                if entry[1] == "retired":
                    pass
                elif entry[1] is None:
                    pass
                else:
                    sess = entry[1]
                    if sess.atime + self.freezetime < now:
                        try:
                            if sess.dirty():
                                self.freeze(sess)
                        except Exception:
                            # Could not persist: only give the session up
                            # once it has expired anyway.
                            if sess.atime + sess.expire < now:
                                rm = True
                        else:
                            rm = True
                if rm:
                    entry[1] = "retired"
                    with self.lock:
                        del self.live[sessid]

    def cleanloop(self):
        """Body of the cleaner thread: run ``clean`` every 300 seconds.

        The loop ends once no live sessions remain.
        """
        me = threading.current_thread()
        try:
            while True:
                time.sleep(300)
                self.clean()
                with self.lock:
                    # Test for emptiness and give up the cleaner role in
                    # one step, so that a session registered right after
                    # makes checkclean() start a new thread.
                    if not self.live:
                        self.cthread = None
                        return
        finally:
            with self.lock:
                # Only clear the slot if it is still ours; a successor may
                # already have been started.
                if self.cthread is me:
                    self.cthread = None

    def _fetch(self, sessid):
        """Return the live session for `sessid`, thawing it if necessary.

        Updates the session's ``atime``.  Raises KeyError if the session
        cannot be thawed or has expired; in that case the placeholder
        entry is retired and removed from ``live``.
        """
        while True:
            now = int(time.time())
            with self.lock:
                if sessid in self.live:
                    entry = self.live[sessid]
                else:
                    entry = self.live[sessid] = [threading.RLock(), None]
            with entry[0]:
                if isinstance(entry[1], session):
                    entry[1].atime = now
                    return entry[1]
                elif entry[1] == "retired":
                    # The entry was dropped while we waited for its lock;
                    # look the ID up afresh.
                    continue
                elif entry[1] is None:
                    try:
                        thawed = self.thaw(sessid)
                        if thawed.atime + thawed.expire < now:
                            raise KeyError()
                        thawed.lock = entry[0]
                        thawed.atime = now
                        entry[1] = thawed
                        return thawed
                    finally:
                        if entry[1] is None:
                            entry[1] = "retired"
                            with self.lock:
                                del self.live[sessid]
                else:
                    raise Exception("Illegal session entry: " + repr(entry[1]))

    def checkclean(self):
        """Start the cleaner thread as a daemon unless one is registered."""
        with self.lock:
            if self.cthread is None:
                self.cthread = threading.Thread(target=self.cleanloop)
                self.cthread.daemon = True
                self.cthread.start()

    def mksession(self, req):
        """Return a brand-new session with its own reentrant lock."""
        return session(threading.RLock())

    def mkcookie(self, req, sess):
        """Add the session cookie for `sess` to `req` via ``cookie.add``."""
        cookie.add(req, self.cookiename, sess.id,
                   path=self.path,
                   expires=cookie.cdate(time.time() + sess.expire))

    def fetch(self, req):
        """Return the session for `req`, creating a new one if needed.

        A callback is registered with ``req.oncommit``.  If the session is
        dirty when it runs, a new session gets its cookie and is entered
        into ``live``, the session is frozen on a best-effort basis, and
        the cleaner thread is started if necessary.
        """
        sessid = cookie.get(req, self.cookiename)
        new = False
        try:
            if sessid is None:
                raise KeyError()
            sess = self._fetch(sessid)
        except KeyError:
            sess = self.mksession(req)
            new = True

        def ckfreeze(req):
            if sess.dirty():
                if new:
                    self.mkcookie(req, sess)
                    with self.lock:
                        self.live[sess.id] = [sess.lock, sess]
                try:
                    self.freeze(sess)
                except Exception:
                    # Persisting is best-effort (freeze() raises TypeError
                    # when there is no backing store); the session stays
                    # dirty in memory and clean() will try again later.
                    pass
                self.checkclean()
        req.oncommit(ckfreeze)
        return sess

    def thaw(self, sessid):
        """Load and return the session stored under `sessid`.

        Raises KeyError if there is no backing store, the ID is unknown,
        the data cannot be unpickled, or it does not contain a `session`.
        """
        if self.backdb is None:
            raise KeyError(sessid)
        data = self.backdb[sessid]
        try:
            # NOTE(security): unpickling executes code embedded in the
            # data, so the backing store must only be writable by trusted
            # parties.
            thawed = pickle.loads(data)
        except Exception:
            raise KeyError(sessid)
        if not isinstance(thawed, session):
            # Anything else would only fail later in _fetch() with an
            # AttributeError instead of the KeyError callers handle.
            raise KeyError(sessid)
        return thawed

    def freeze(self, sess):
        """Pickle `sess` into the backing store and mark it clean.

        Raises TypeError if there is no backing store.
        """
        if self.backdb is None:
            raise TypeError()
        with sess.lock:
            data = pickle.dumps(sess, -1)
        self.backdb[sess.id] = data
        sess.frozen()

    def get(self, req):
        """Return the session for `req` through ``req.item(self.fetch)``."""
        # NOTE(review): req.item presumably caches the result per request;
        # confirm against the request class.
        return req.item(self.fetch)
211
class dirback(object):
    """Backing store that keeps each value in a file named after its key.

    All files live directly inside the directory ``path``.
    """

    def __init__(self, path):
        """Remember `path`; the directory is created on the first write."""
        self.path = path

    def _validkey(self, key):
        """Return True if `key` is a plain file name inside the directory.

        Keys may originate from client-supplied cookies, so anything that
        could address a file outside ``self.path`` is rejected: empty
        names, "." and "..", names containing a path separator (including
        absolute paths) and names containing a NUL character.
        """
        if not key or key in (os.curdir, os.pardir) or "\0" in key:
            return False
        return os.path.basename(key) == key

    def __getitem__(self, key):
        """Return the contents stored under `key`.

        Raises KeyError if the key is not acceptable or cannot be read.
        """
        if not self._validkey(key):
            raise KeyError(key)
        try:
            # Binary mode: the stored values are pickles.
            with open(os.path.join(self.path, key), "rb") as inf:
                return inf.read()
        except IOError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """Write `value` to the file for `key`, creating the directory.

        Raises ValueError if the key is not acceptable as a file name.
        """
        if not self._validkey(key):
            raise ValueError("invalid session key: %r" % (key,))
        try:
            os.makedirs(self.path)
        except OSError:
            # Either it already exists (possibly created concurrently),
            # which is fine, or the creation genuinely failed.
            if not os.path.isdir(self.path):
                raise
        with open(os.path.join(self.path, key), "wb") as out:
            out.write(value)
228
# Default session database: sessions are persisted by a dirback store in a
# per-user directory under /tmp, and the db is wrapped in env.var (its
# ``val`` attribute is what the module-level get() uses).
default = env.var(
    db(backdb=dirback(os.path.join("/tmp", "wrwsess-%s" % str(os.getuid())))))
230
def get(req):
    """Return the session for `req` from the default session database."""
    sessdb = default.val
    return sessdb.get(req)