else:
return self.bk.compare(a[1:], b[1:])
+class compound(object):
+ def __init__(self, *parts):
+ self.parts = parts
+
+ def encode(self, obs):
+ if len(obs) != len(self.parts):
+ raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
+ buf = bytearray()
+ for ob, part in zip(obs, self.parts):
+ dat = part.encode(ob)
+ if len(dat) < 128:
+ buf.append(0x80 | len(dat))
+ buf.extend(dat)
+ else:
+ buf.extend(struct.pack(">i", len(dat)))
+ buf.extend(dat)
+ return bytes(buf)
+ def decode(self, dat):
+ ret = []
+ off = 0
+ for part in self.parts:
+ if dat[off] & 0x80:
+ ln = dat[off] & 0x7f
+ off += 1
+ else:
+ ln = struct.unpack(">i", dat[off:off + 4])[0]
+ off += 4
+ ret.append(part.decode(dat[off:off + len]))
+ off += len
+ return tuple(ret)
+ def compare(self, al, bl):
+ if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
+ raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
+ for a, b, part in zip(al, bl, self.parts):
+ c = part.compare(a, b)
+ if c != 0:
+ return c
+ return 0
+
def floatcmp(a, b):
if math.isnan(a) and math.isnan(b):
return 0
import threading
-from . import store, lib
+from . import store, lib, index
from .store import storedesc
__all__ = ["simple", "multi"]
self.bk.skip(n)
class base(storedesc):
- def __init__(self, store, indextype, name, datatype, default):
+ def __init__(self, store, indextype, name, datatype):
self.store = store
self.indextype = indextype
self.name = name
self.typ = datatype
- self.default = default
self.idx = None
self.lk = threading.Lock()
- self.mattr = "__idx_%s_new" % name
- self.iattr = "__idx_%s_cur" % name
def index(self):
with self.lk:
self.idx = self.indextype(self.store.db(), self.name, self.typ)
return self.idx
+ def get(self, **kwargs):
+ return cursor(self.index().get(**kwargs), self.store)
+
+class descbase(base):
+ def __init__(self, store, indextype, name, datatype, default):
+ super().__init__(store, indextype, name, datatype)
+ self.default = default
+ self.mattr = "__idx_%s_new" % name
+ self.iattr = "__idx_%s_cur" % name
+
def __get__(self, obj, cls):
if obj is None: return self
return getattr(obj, self.mattr, self.default)
def __delete__(self, obj):
delattr(obj, self.mattr)
- def get(self, **kwargs):
- return cursor(self.index().get(**kwargs), self.store)
-
-class simple(base):
+class simple(descbase):
def __init__(self, store, indextype, name, datatype, default=None):
super().__init__(store, indextype, name, datatype, default)
idx.put(val, id, tx=tx)
tx.postcommit(lambda: setattr(obj, self.iattr, val))
-class multi(base):
+class multi(descbase):
def __init__(self, store, indextype, name, datatype):
super().__init__(store, indextype, name, datatype, ())
for val in vals - ivals:
idx.put(val, id, tx=tx)
tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+
+class compound(base):
+ def __init__(self, indextype, name, *parts):
+ super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
+ self.parts = parts
+ self.iattr = "__idx_%s_cur" % name
+
+ def register(self, id, obj, tx):
+ val = tuple(part.__get__(obj, None) for part in self.parts)
+ self.index().put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, val))
+
+ def unregister(self, id, obj, tx):
+ self.index().remove(getattr(obj, self.iattr), id, tx=tx)
+ tx.postcommit(lambda: delattr(obj, self.iattr))
+
+ def update(self, id, obj, tx):
+ val = tuple(part.__get__(obj, None) for part in self.parts)
+ ival = getattr(obj, self.iattr)
+ if val != ival:
+ idx = self.index()
+ idx.remove(ival, id, tx=tx)
+ idx.put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, val))