from . import db, lib
from .db import bd, txnfun, dloopfun
-__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
+__all__ = ["maybe", "t_bool", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
return cls(lambda ob: struct.pack(fmt, ob),
lambda dat: struct.unpack(fmt, dat)[0])
+class foldtype(simpletype):
+ def __init__(self, encode, decode, fold):
+ super().__init__(encode, decode)
+ self.fold = fold
+
+ def compare(self, a, b):
+ return super().compare(self.fold(a), self.fold(b))
+
class maybe(object):
def __init__(self, bk):
self.bk = bk
def __init__(self, *parts):
self.parts = parts
+ small = object()
+ large = object()
+ def minim(self, *parts):
+ return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+ def maxim(self, *parts):
+ return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
def encode(self, obs):
if len(obs) != len(self.parts):
raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
buf = bytearray()
for ob, part in zip(obs, self.parts):
- dat = part.encode(ob)
- if len(dat) < 128:
- buf.append(0x80 | len(dat))
- buf.extend(dat)
+ if ob is self.small:
+ buf.append(0x01)
+ elif ob is self.large:
+ buf.append(0x02)
else:
- buf.extend(struct.pack(">i", len(dat)))
- buf.extend(dat)
+ dat = part.encode(ob)
+ if len(dat) < 128:
+ buf.append(0x80 | len(dat))
+ buf.extend(dat)
+ else:
+ buf.extend(struct.pack(">BI", 0, len(dat)))
+ buf.extend(dat)
return bytes(buf)
def decode(self, dat):
ret = []
off = 0
for part in self.parts:
- if dat[off] & 0x80:
- ln = dat[off] & 0x7f
- off += 1
+ fl = dat[off]
+ off += 1
+ if fl & 0x80:
+ ln = fl & 0x7f
+ elif fl == 0x01:
+ ret.append(self.small)
+ continue
+ elif fl == 0x02:
+ ret.append(self.large)
+ continue
else:
- ln = struct.unpack(">i", dat[off:off + 4])[0]
+ ln = struct.unpack(">I", dat[off:off + 4])[0]
off += 4
- ret.append(part.decode(dat[off:off + len]))
- off += len
+ ret.append(part.decode(dat[off:off + ln]))
+ off += ln
return tuple(ret)
def compare(self, al, bl):
if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
for a, b, part in zip(al, bl, self.parts):
+ if a in (self.small, self.large) or b in (self.small, self.large):
+ if a is b:
+ return 0
+ if a is self.small:
+ return -1
+ elif b is self.small:
+ return 1
+ elif a is self.large:
+ return 1
+ elif b is self.large:
+ return -1
c = part.compare(a, b)
if c != 0:
return c
else:
return 0
+t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"), (lambda dat: False if dat == b"x\00" else True))
t_int = simpletype.struct(">q")
t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
t_float = simpletype.struct(">d")
t_float.compare = floatcmp
t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+ (lambda st: st.lower()))
class index(object):
def __init__(self, db, name, datatype):
missing = object()
class ordered(index, lib.closable):
- def __init__(self, db, name, datatype, create=True):
+ def __init__(self, db, name, datatype, create=True, *, tx=None):
super().__init__(db, name, datatype)
- fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+ fl = bd.DB_THREAD
if create: fl |= bd.DB_CREATE
def initdb(db):
def compare(a, b):
return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
db.set_flags(bd.DB_DUPSORT)
db.set_bt_compare(compare)
- self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
+ self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
self.bk.set_get_returns_none(False)
def close(self):