From 79be1cdc3f094156fe6db4c7131ee1051415cd92 Mon Sep 17 00:00:00 2001 From: Fredrik Tolf Date: Fri, 27 Sep 2024 03:52:01 +0200 Subject: [PATCH] Change frozen session format to allow for individual items to fail unpickling individually. --- wrw/session.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/wrw/session.py b/wrw/session.py index 8141827..1504d82 100644 --- a/wrw/session.py +++ b/wrw/session.py @@ -1,4 +1,4 @@ -import threading, time, pickle, random, os +import threading, time, pickle, random, os, io from . import cookie, env, proto __all__ = ["db", "get"] @@ -6,6 +6,24 @@ __all__ = ["db", "get"] def gennonce(length): return os.urandom(length) +class itempickler(pickle.Pickler): + def persistent_id(self, obj): + if isinstance(obj, session): + return ("session", obj.id) + return None + +class itemunpickler(pickle.Unpickler): + def __init__(self, *args, session, **kwargs): + super().__init__(*args, **kwargs) + self.session = session + + def persistent_load(self, oid): + tag = oid[0] + if tag == "session": + if oid[1] == self.session.id: + return self.session + raise pickle.UnpicklingError("unexpected persistent id: " + repr(oid)) + class session(object): def __init__(self, lock, expire=86400 * 7): self.id = proto.enhex(gennonce(16)) @@ -50,15 +68,37 @@ class session(object): return key in self.dict def __getstate__(self): - ret = [] - for k, v in self.__dict__.items(): - if k == "lock": continue - ret.append((k, v)) + items = [] + for k, v in self.dict.items(): + buf = io.BytesIO() + itempickler(buf).dump((k, v)) + items.append(buf.getvalue()) + ret = {"data": items} + for k in ["id", "ctime", "mtime", "atime", "expire"]: + ret[k] = getattr(self, k) return ret - + def __setstate__(self, st): - for k, v in st: - self.__dict__[k] = v + print(st) + if isinstance(st, list): + # Only for the old session format; remove me in due time. + for k, v in st: + self.__dict__[k] = v + else: + self.dict = {} + self.dctl = set() + self.dirtyp = False + for k in ["id", "ctime", "mtime", "atime", "expire"]: + setattr(self, k, st[k]) + for item in st["data"]: + try: + k, v = itemunpickler(io.BytesIO(item), session=self).load() + except Exception as exc: + print(exc) + continue + self.dict[k] = v + if hasattr(v, "sessdirty"): + self.dctl.add(v) # The proper lock is set by the thawer def __repr__(self): -- 2.11.0