Fixed index opening thread-safety by including it in local transaction.
authorFredrik Tolf <fredrik@dolda2000.com>
Mon, 3 Aug 2015 00:39:14 +0000 (02:39 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Mon, 3 Aug 2015 00:39:14 +0000 (02:39 +0200)
didex/db.py
didex/index.py
didex/values.py

index 33eb0d8..539fc85 100644 (file)
@@ -125,21 +125,18 @@ class database(object):
         self.env = env
         self.mode = mode
         self.fnm = name
-        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+        fl = bd.DB_THREAD
         if create:
             fl |= bd.DB_CREATE
         self.cf = self._opendb("cf", bd.DB_HASH, fl)
         self.ob = self._opendb("ob", bd.DB_HASH, fl)
 
-    def _opendb(self, dnm, typ, fl, init=None):
+    @txnfun(lambda self: self.env.env)
+    def _opendb(self, dnm, typ, fl, init=None, *, tx):
         ret = bd.DB(self.env.env)
         if init: init(ret)
-        while True:
-            try:
-                ret.open(self.fnm, dnm, typ, fl, self.mode)
-            except deadlock:
-                continue
-            return ret
+        ret.open(self.fnm, dnm, typ, fl, self.mode, txn=tx.tx)
+        return ret
 
     @txnfun(lambda self: self.env.env)
     def _nextseq(self, *, tx):
index c2c55b6..5b5a5bc 100644 (file)
@@ -158,9 +158,9 @@ class index(object):
 missing = object()
 
 class ordered(index, lib.closable):
-    def __init__(self, db, name, datatype, create=True):
+    def __init__(self, db, name, datatype, create=True, *, tx=None):
         super().__init__(db, name, datatype)
-        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+        fl = bd.DB_THREAD
         if create: fl |= bd.DB_CREATE
         def initdb(db):
             def compare(a, b):
@@ -168,7 +168,7 @@ class ordered(index, lib.closable):
                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
             db.set_flags(bd.DB_DUPSORT)
             db.set_bt_compare(compare)
-        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
+        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
         self.bk.set_get_returns_none(False)
 
     def close(self):
index ecb78dd..a4473f0 100644 (file)
@@ -31,14 +31,14 @@ class base(storedesc):
         self.idx = None
         self.lk = threading.Lock()
 
-    def index(self):
+    def index(self, tx):
         with self.lk:
             if self.idx is None:
-                self.idx = self.indextype(self.store.db(), self.name, self.typ)
+                self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx)
             return self.idx
 
     def get(self, **kwargs):
-        return cursor(self.index().get(**kwargs), self.store)
+        return cursor(self.index(None).get(**kwargs), self.store)
 
     def get1(self, *, check=False, default=KeyError, **kwargs):
         with self.get(**kwargs) as cursor:
@@ -84,18 +84,18 @@ class simple(descbase):
 
     def register(self, id, obj, tx):
         val = self.__get__(obj, None)
-        self.index().put(val, id, tx=tx)
+        self.index(tx).put(val, id, tx=tx)
         tx.postcommit(lambda: setattr(obj, self.iattr, val))
 
     def unregister(self, id, obj, tx):
-        self.index().remove(getattr(obj, self.iattr), id, tx=tx)
+        self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
         tx.postcommit(lambda: delattr(obj, self.iattr))
 
     def update(self, id, obj, tx):
         val = self.__get__(obj, None)
         ival = getattr(obj, self.iattr)
         if val != ival:
-            idx = self.index()
+            idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
             tx.postcommit(lambda: setattr(obj, self.iattr, val))
@@ -106,13 +106,13 @@ class multi(descbase):
 
     def register(self, id, obj, tx):
         vals = frozenset(self.__get__(obj, None))
-        idx = self.index()
+        idx = self.index(tx)
         for val in vals:
             idx.put(val, id, tx=tx)
         tx.postcommit(lambda: setattr(obj, self.iattr, vals))
 
     def unregister(self, id, obj, tx):
-        idx = self.index()
+        idx = self.index(tx)
         for val in getattr(obj, self.iattr):
             idx.remove(val, id, tx=tx)
         tx.postcommit(lambda: delattr(obj, self.iattr))
@@ -121,7 +121,7 @@ class multi(descbase):
         vals = frozenset(self.__get__(obj, None))
         ivals = getattr(obj, self.iattr)
         if vals != ivals:
-            idx = self.index()
+            idx = self.index(tx)
             for val in ivals - vals:
                 idx.remove(val, id, tx=tx)
             for val in vals - ivals:
@@ -147,18 +147,18 @@ class compound(base):
 
     def register(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
-        self.index().put(val, id, tx=tx)
+        self.index(tx).put(val, id, tx=tx)
         tx.postcommit(lambda: setattr(obj, self.iattr, val))
 
     def unregister(self, id, obj, tx):
-        self.index().remove(getattr(obj, self.iattr), id, tx=tx)
+        self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
         tx.postcommit(lambda: delattr(obj, self.iattr))
 
     def update(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
         ival = getattr(obj, self.iattr)
         if val != ival:
-            idx = self.index()
+            idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
             tx.postcommit(lambda: setattr(obj, self.iattr, val))