-import struct, contextlib
+import struct, contextlib, math
from . import db, lib
-from .db import bd
+from .db import bd, txnfun, dloopfun
+
+__all__ = ["maybe", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
return cls(lambda ob: struct.pack(fmt, ob),
lambda dat: struct.unpack(fmt, dat)[0])
+class foldtype(simpletype):
+ def __init__(self, encode, decode, fold):
+ super().__init__(encode, decode)
+ self.fold = fold
+
+ def compare(self, a, b):
+ return super().compare(self.fold(a), self.fold(b))
+
class maybe(object):
def __init__(self, bk):
self.bk = bk
else:
return self.bk.compare(a[1:], b[1:])
-t_int = simpletype.struct(">Q")
+class compound(object):
+ def __init__(self, *parts):
+ self.parts = parts
+
+ small = object()
+ large = object()
+ def minim(self, *parts):
+ return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+ def maxim(self, *parts):
+ return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
+ def encode(self, obs):
+ if len(obs) != len(self.parts):
+ raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
+ buf = bytearray()
+ for ob, part in zip(obs, self.parts):
+ if ob is self.small:
+ buf.append(0x01)
+ elif ob is self.large:
+ buf.append(0x02)
+ else:
+ dat = part.encode(ob)
+ if len(dat) < 128:
+ buf.append(0x80 | len(dat))
+ buf.extend(dat)
+ else:
+ buf.extend(struct.pack(">BI", 0, len(dat)))
+ buf.extend(dat)
+ return bytes(buf)
+ def decode(self, dat):
+ ret = []
+ off = 0
+ for part in self.parts:
+ fl = dat[off]
+ off += 1
+ if fl & 0x80:
+ ln = fl & 0x7f
+ elif fl == 0x01:
+ ret.append(self.small)
+ continue
+ elif fl == 0x02:
+ ret.append(self.large)
+ continue
+ else:
+ ln = struct.unpack(">I", dat[off:off + 4])[0]
+ off += 4
+ ret.append(part.decode(dat[off:off + ln]))
+ off += ln
+ return tuple(ret)
+ def compare(self, al, bl):
+ if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
+ raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
+ for a, b, part in zip(al, bl, self.parts):
+ if a in (self.small, self.large) or b in (self.small, self.large):
+ if a is b:
+ return 0
+ if a is self.small:
+ return -1
+ elif b is self.small:
+ return 1
+ elif a is self.large:
+ return 1
+ elif b is self.large:
+ return -1
+ c = part.compare(a, b)
+ if c != 0:
+ return c
+ return 0
+
+def floatcmp(a, b):
+ if math.isnan(a) and math.isnan(b):
+ return 0
+ elif math.isnan(a):
+ return -1
+ elif math.isnan(b):
+ return 1
+ elif a < b:
+ return -1
+ elif a > b:
+ return 1
+ else:
+ return 0
+
+t_int = simpletype.struct(">q")
+t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
+t_float = simpletype.struct(">d")
+t_float.compare = floatcmp
+t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+ (lambda st: st.lower()))
class index(object):
def __init__(self, db, name, datatype):
missing = object()
class ordered(index, lib.closable):
- def __init__(self, db, name, datatype, duplicates, create=True):
+ def __init__(self, db, name, datatype, create=True):
super().__init__(db, name, datatype)
- self.dup = duplicates
fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
if create: fl |= bd.DB_CREATE
def initdb(db):
self.bk.close()
class cursor(lib.closable):
- def __init__(self, idx, cur, item, stop):
+ def __init__(self, idx, fd, fi, ld, li, reverse):
self.idx = idx
- self.cur = cur
- self.item = item
- self.stop = stop
+ self.typ = idx.typ
+ self.cur = self.idx.bk.cursor()
+ self.item = None
+ self.fd = fd
+ self.fi = fi
+ self.ld = ld
+ self.li = li
+ self.rev = reverse
def close(self):
if self.cur is not None:
self.cur.close()
+ self.cur = None
def __iter__(self):
return self
- def peek(self):
- if self.item is None:
- raise StopIteration()
- rk, rv = self.item
- rk = self.idx.typ.decode(rk)
- rv = struct.unpack(">Q", rv)[0]
- if self.stop(rk):
- self.item = None
- raise StopIteration()
- return rk, rv
+ def _decode(self, d):
+ k, v = d
+ k = self.typ.decode(k)
+ v = struct.unpack(">Q", v)[0]
+ return k, v
- def __next__(self):
- rk, rv = self.peek()
+ @dloopfun
+ def first(self):
try:
- while True:
+ if self.fd is missing:
+ self.item = self._decode(self.cur.first())
+ else:
+ k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
+ if not self.fi:
+ while self.typ.compare(k, self.fd) == 0:
+ k, v = self._decode(self.cur.next())
+ self.item = k, v
+ except notfound:
+ self.item = StopIteration
+
+ @dloopfun
+ def last(self):
+ try:
+ if self.ld is missing:
+ self.item = self._decode(self.cur.last())
+ else:
try:
- self.item = self.cur.next()
- break
- except deadlock:
- continue
+ k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+ except notfound:
+ k, v = self._decode(self.cur.last())
+ if self.li:
+ while self.typ.compare(k, self.ld) == 0:
+ k, v = self._decode(self.cur.next())
+ while self.typ.compare(k, self.ld) > 0:
+ k, v = self._decode(self.cur.prev())
+ else:
+ while self.typ.compare(k, self.ld) >= 0:
+ k, v = self._decode(self.cur.prev())
+ self.item = k, v
except notfound:
- self.item = None
- return rk, rv
+ self.item = StopIteration
+
+ @dloopfun
+ def next(self):
+ try:
+ k, v = self.item = self._decode(self.cur.next())
+ if (self.ld is not missing and
+ ((self.li and self.typ.compare(k, self.ld) > 0) or
+ (not self.li and self.typ.compare(k, self.ld) >= 0))):
+ self.item = StopIteration
+ except notfound:
+ self.item = StopIteration
+
+ @dloopfun
+ def prev(self):
+ try:
+ self.item = self._decode(self.cur.prev())
+ if (self.fd is not missing and
+ ((self.fi and self.typ.compare(k, self.fd) < 0) or
+ (not self.fi and self.typ.compare(k, self.fd) <= 0))):
+ self.item = StopIteration
+ except notfound:
+ self.item = StopIteration
+
+ def __next__(self):
+ if self.item is None:
+ if not self.rev:
+ self.next()
+ else:
+ self.prev()
+ if self.item is StopIteration:
+ raise StopIteration()
+ ret, self.item = self.item, None
+ return ret
def skip(self, n=1):
try:
except StopIteration:
return
- def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
- while True:
- try:
- cur = self.bk.cursor()
- done = False
- try:
- if match is not missing:
- try:
- k, v = cur.set(self.typ.encode(match))
- except notfound:
- return self.cursor(None, None, None, None)
- else:
- done = True
- return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
- elif all:
- try:
- k, v = cur.first()
- except notfound:
- return self.cursor(None, None, None, None)
- else:
- done = True
- return self.cursor(self, cur, (k, v), lambda o: False)
- elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
- skip = False
- try:
- if ge is not missing:
- k, v = cur.set_range(self.typ.encode(ge))
- elif gt is not missing:
- k, v = cur.set_range(self.typ.encode(gt))
- skip = True
- else:
- k, v = cur.first()
- except notfound:
- return self.cursor(None, None, None, None)
- if lt is not missing:
- stop = lambda o: self.typ.compare(o, lt) >= 0
- elif le is not missing:
- stop = lambda o: self.typ.compare(o, le) > 0
- else:
- stop = lambda o: False
- ret = self.cursor(self, cur, (k, v), stop)
- if skip:
- try:
- while self.typ.compare(ret.peek()[0], gt) == 0:
- next(ret)
- except StopIteration:
- pass
- done = True
- return ret
- else:
- raise NameError("invalid get() specification")
- finally:
- if not done:
- cur.close()
- except deadlock:
- continue
+ def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
+ if all:
+ cur = self.cursor(self, missing, True, missing, True, reverse)
+ elif match is not missing:
+ cur = self.cursor(self, match, True, match, True, reverse)
+ elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
+ if ge is not missing:
+ fd, fi = ge, True
+ elif gt is not missing:
+ fd, fi = gt, False
+ else:
+ fd, fi = missing, True
+ if le is not missing:
+ ld, li = le, True
+ elif lt is not missing:
+ ld, li = lt, False
+ else:
+ ld, li = missing, True
+ cur = self.cursor(self, fd, fi, ld, li, reverse)
+ else:
+ raise NameError("invalid get() specification")
+ done = False
+ try:
+ if not reverse:
+ cur.first()
+ else:
+ cur.last()
+ done = True
+ return cur
+ finally:
+ if not done:
+ cur.close()
- def put(self, key, id):
- while True:
- try:
- with db.txn(self.db.env.env) as tx:
- obid = struct.pack(">Q", id)
- if not self.db.ob.has_key(obid, txn=tx.tx):
- raise ValueError("no such object in database: " + str(id))
- try:
- self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
- except bd.DBKeyExistError:
- return False
- tx.commit()
- return True
- except deadlock:
- continue
+ @txnfun(lambda self: self.db.env.env)
+ def put(self, key, id, *, tx):
+ obid = struct.pack(">Q", id)
+ if not self.db.ob.has_key(obid, txn=tx.tx):
+ raise ValueError("no such object in database: " + str(id))
+ try:
+ self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
+ except bd.DBKeyExistError:
+ return False
+ return True
- def remove(self, key, id):
- while True:
+ @txnfun(lambda self: self.db.env.env)
+ def remove(self, key, id, *, tx):
+ obid = struct.pack(">Q", id)
+ if not self.db.ob.has_key(obid, txn=tx.tx):
+ raise ValueError("no such object in database: " + str(id))
+ cur = self.bk.cursor(txn=tx.tx)
+ try:
try:
- with db.txn(self.db.env.env) as tx:
- obid = struct.pack(">Q", id)
- if not self.db.ob.has_key(obid, txn=tx.tx):
- raise ValueError("no such object in database: " + str(id))
- cur = self.bk.cursor(txn=tx.tx)
- try:
- try:
- cur.get_both(self.typ.encode(key), obid)
- except notfound:
- return False
- cur.delete()
- finally:
- cur.close()
- tx.commit()
- return True
- except deadlock:
- continue
+ cur.get_both(self.typ.encode(key), obid)
+ except notfound:
+ return False
+ cur.delete()
+ finally:
+ cur.close()
+ return True