Change frozen session format to allow for individual items to fail unpickling individ...
authorFredrik Tolf <fredrik@dolda2000.com>
Fri, 27 Sep 2024 01:52:01 +0000 (03:52 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Fri, 27 Sep 2024 01:52:01 +0000 (03:52 +0200)
wrw/session.py

index 8141827..1504d82 100644 (file)
@@ -1,4 +1,4 @@
-import threading, time, pickle, random, os
+import threading, time, pickle, random, os, io
 from . import cookie, env, proto
 
 __all__ = ["db", "get"]
@@ -6,6 +6,24 @@ __all__ = ["db", "get"]
 def gennonce(length):
     return os.urandom(length)
 
+class itempickler(pickle.Pickler):
+    def persistent_id(self, obj):
+        if isinstance(obj, session):
+            return ("session", obj.id)
+        return None
+
+class itemunpickler(pickle.Unpickler):
+    def __init__(self, *args, session, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.session = session
+
+    def persistent_load(self, oid):
+        tag = oid[0]
+        if tag == "session":
+            if oid[1] == self.session.id:
+                return self.session
+        raise pickle.UnpicklingError("unexpected persistent id: " + repr(oid))
+
 class session(object):
     def __init__(self, lock, expire=86400 * 7):
         self.id = proto.enhex(gennonce(16))
@@ -50,15 +68,37 @@ class session(object):
         return key in self.dict
 
     def __getstate__(self):
-        ret = []
-        for k, v in self.__dict__.items():
-            if k == "lock": continue
-            ret.append((k, v))
+        items = []
+        for k, v in self.dict.items():
+            buf = io.BytesIO()
+            itempickler(buf).dump((k, v))
+            items.append(buf.getvalue())
+        ret = {"data": items}
+        for k in ["id", "ctime", "mtime", "atime", "expire"]:
+            ret[k] = getattr(self, k)
         return ret
-    
+
     def __setstate__(self, st):
-        for k, v in st:
-            self.__dict__[k] = v
+        print(st)
+        if isinstance(st, list):
+            # Only for the old session format; remove me in due time.
+            for k, v in st:
+                self.__dict__[k] = v
+        else:
+            self.dict = {}
+            self.dctl = set()
+            self.dirtyp = False
+            for k in ["id", "ctime", "mtime", "atime", "expire"]:
+                setattr(self, k, st[k])
+            for item in st["data"]:
+                try:
+                    k, v = itemunpickler(io.BytesIO(item), session=self).load()
+                except Exception as exc:
+                    print(exc)
+                    continue
+                self.dict[k] = v
+                if hasattr(v, "sessdirty"):
+                    self.dctl.add(v)
         # The proper lock is set by the thawer
 
     def __repr__(self):