Hopefully fixed the ugliness in index-value duplication.
authorFredrik Tolf <fredrik@dolda2000.com>
Tue, 20 Mar 2018 20:33:22 +0000 (21:33 +0100)
committerFredrik Tolf <fredrik@dolda2000.com>
Tue, 20 Mar 2018 20:33:22 +0000 (21:33 +0100)
didex/store.py
didex/values.py

index a9bf0cd..11fe17a 100644 (file)
@@ -1,4 +1,4 @@
-import threading, pickle, inspect, atexit
+import threading, pickle, inspect, atexit, weakref
 from . import db, index, cache
 from .db import txnfun
 
@@ -47,8 +47,30 @@ def storedescs(obj):
         t.__didex_attr = ret
     return ret
 
+class icache(object):
+    def __init__(self):
+        self.d = weakref.WeakKeyDictionary()
+
+    def __getitem__(self, key):
+        obj, idx = key
+        return self.d[obj][idx]
+    def __setitem__(self, key, val):
+        obj, idx = key
+        if obj in self.d:
+            self.d[obj][idx] = val
+        else:
+            self.d[obj] = {idx: val}
+    def __delitem__(self, key):
+        obj, idx = key
+        del self.d[obj][idx]
+    def get(self, key, default=None):
+        obj, idx = key
+        if obj not in self.d:
+            return default
+        return self.d[obj].get(idx, default)
+
 class datastore(object):
-    def __init__(self, name, *, env=None, path=".", ncache=None):
+    def __init__(self, name, *, env=None, path=".", ncache=None, codec=None):
         self.name = name
         self.lk = threading.Lock()
         if env:
@@ -58,8 +80,11 @@ class datastore(object):
         self._db = None
         if ncache is None:
             ncache = cache.cache()
+        if codec is not None:
+            self._encode, self._decode = codec
         self.cache = ncache
         self.cache.load = self._load
+        self.icache = icache()
 
     def db(self):
         with self.lk:
@@ -67,15 +92,24 @@ class datastore(object):
                 self._db = self.env().db(self.name)
             return self._db
 
-    def _load(self, id):
+    def _decode(self, data):
         try:
-            return pickle.loads(self.db().get(id))
+            return pickle.loads(data)
         except:
             raise KeyError(id, "could not unpickle data")
 
     def _encode(self, obj):
         return pickle.dumps(obj)
 
+    @txnfun(lambda self: self.db().env.env)
+    def _load(self, id, *, tx):
+        loaded = self._decode(self.db().get(id, tx=tx))
+        if hasattr(loaded, "__didex_loaded__"):
+            loaded.__didex_loaded__(self, id)
+        for nm, attr in storedescs(loaded):
+            attr.loaded(id, loaded, tx)
+        return loaded
+
     def get(self, id, *, load=True):
         return self.cache.get(id, load=load)
 
@@ -110,13 +144,17 @@ class autotype(type):
     def __call__(self, *args, **kwargs):
         new = super().__call__(*args, **kwargs)
         new.id = self.store.register(new)
-        self.store.update(new.id, vfy=new) # This doesn't feel too nice.
+        # XXX? ID is not saved now, but relied upon to be __didex_loaded__ later.
         return new
 
 class autostore(object, metaclass=autotype):
     def __init__(self):
         self.id = None
 
+    def __didex_loaded__(self, store, id):
+        assert self.id is None or self.id == id
+        self.id = id
+
     def save(self):
         self.store.update(self.id, vfy=self)
 
index bff26e7..ddb8285 100644 (file)
@@ -65,8 +65,7 @@ class descbase(base):
     def __init__(self, store, indextype, name, datatype, default):
         super().__init__(store, indextype, name, datatype)
         self.default = default
-        self.mattr = "__idx_%s_new" % name
-        self.iattr = "__idx_%s_cur" % name
+        self.mattr = "__ival_%s" % name
 
     def __get__(self, obj, cls):
         if obj is None: return self
@@ -85,20 +84,24 @@ class simple(descbase):
     def register(self, id, obj, tx):
         val = self.__get__(obj, None)
         self.index(tx).put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, val))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
     def unregister(self, id, obj, tx):
-        self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         val = self.__get__(obj, None)
-        ival = getattr(obj, self.iattr)
+        ival = self.store.icache[obj, self]
         if val != ival:
             idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, val))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        val = self.__get__(obj, None)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
 class multi(descbase):
     def __init__(self, store, indextype, name, datatype):
@@ -109,30 +112,33 @@ class multi(descbase):
         idx = self.index(tx)
         for val in vals:
             idx.put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
 
     def unregister(self, id, obj, tx):
         idx = self.index(tx)
-        for val in getattr(obj, self.iattr):
+        for val in self.store.icache[obj, self]:
             idx.remove(val, id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         vals = frozenset(self.__get__(obj, None))
-        ivals = getattr(obj, self.iattr)
+        ivals = self.store.icache[obj, self]
         if vals != ivals:
             idx = self.index(tx)
             for val in ivals - vals:
                 idx.remove(val, id, tx=tx)
             for val in vals - ivals:
                 idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        vals = frozenset(self.__get__(obj, None))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
 class compound(base):
     def __init__(self, indextype, name, *parts):
         super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
         self.parts = parts
-        self.iattr = "__idx_%s_cur" % name
 
     def minim(self, *parts):
         return self.typ.minim(*parts)
@@ -148,20 +154,24 @@ class compound(base):
     def register(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
         self.index(tx).put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, val))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
     def unregister(self, id, obj, tx):
-        self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
-        ival = getattr(obj, self.iattr)
+        ival = self.store.icache[obj, self]
         if val != ival:
             idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, val))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        val = tuple(part.__get__(obj, None) for part in self.parts)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
 class idlink(object):
     def __init__(self, name, atype):