| 1 | import struct, contextlib, math |
| 2 | from . import db, lib |
| 3 | from .db import bd, txnfun, dloopfun |
| 4 | |
# Public names exported by ``from <module> import *``.
__all__ = [
    "maybe",
    "t_int",
    "t_uint",
    "t_dbid",
    "t_float",
    "t_str",
    "t_casestr",
    "ordered",
]

# Short module-level aliases for the Berkeley DB exceptions used in this file.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
| 9 | |
class simpletype(object):
    """Index key type defined by a pair of encode/decode callables.

    Decoded values are ordered with the native ``<`` and ``>`` operators.
    """

    def __init__(self, encode, decode):
        # Stored under short names so the instance attributes do not shadow
        # the public encode()/decode() methods below.
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Convert a Python value to its stored form via the encode callable."""
        return self.enc(ob)

    def decode(self, dat):
        """Convert a stored form back to a Python value via the decode callable."""
        return self.dec(dat)

    def compare(self, a, b):
        """Three-way comparison: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a type that packs one value with the ``struct`` format *fmt*.

        Decoding returns the first unpacked field.
        """
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
| 31 | |
class foldtype(simpletype):
    """A simpletype whose ordering is computed on folded values.

    *fold* is applied to both operands before comparing, e.g. ``str.lower``
    for case-insensitive ordering. Encoding and decoding are unaffected.
    """

    def __init__(self, encode, decode, fold):
        super().__init__(encode, decode)
        self.fold = fold

    def compare(self, a, b):
        """Compare ``fold(a)`` with ``fold(b)`` using simpletype's ordering."""
        fold = self.fold
        folded_a = fold(a)
        folded_b = fold(b)
        return super().compare(folded_a, folded_b)
| 39 | |
class maybe(object):
    """Wrap another index type so that ``None`` is also a valid value.

    ``None`` encodes to the empty byte string; any other value encodes as a
    NUL marker byte followed by the wrapped type's encoding. ``None`` sorts
    before every other value.
    """

    def __init__(self, bk):
        # Wrapped ("backing") type; must provide encode/decode/compare.
        self.bk = bk

    def encode(self, ob):
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        if dat == b"":
            return None
        # Strip the marker byte added by encode(). Use the public decode()
        # so any backing type works (only simpletype has a ``.dec``).
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare of two *decoded* values; None sorts first.

        The marker byte only exists in the encoded form, so non-None values
        are passed to the wrapped type unchanged.
        """
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        return self.bk.compare(a, b)
| 59 | |
class compound(object):
    """Composite index type: a fixed-length tuple of component values.

    Each component is stored as a length-prefixed field so the tuple can be
    decoded again. The sentinels ``small`` and ``large`` may stand in for
    any component; they sort below / above every real value in their
    position, which lets ``minim``/``maxim`` build range bounds over a
    key prefix.
    """

    def __init__(self, *parts):
        # Component types in key order; each provides encode/decode/compare.
        self.parts = parts

    # Sentinels sorting below / above every real value at their position.
    small = object()
    large = object()

    def minim(self, *parts):
        """Pad the prefix *parts* with ``small`` up to the full tuple length."""
        return parts + tuple([self.small] * (len(self.parts) - len(parts)))

    def maxim(self, *parts):
        """Pad the prefix *parts* with ``large`` up to the full tuple length."""
        return parts + tuple([self.large] * (len(self.parts) - len(parts)))

    def encode(self, obs):
        """Encode the tuple *obs* (one value per part) to bytes.

        Field layout: 0x01 = ``small``, 0x02 = ``large``; otherwise a short
        form (0x80 | length, for payloads under 128 bytes) or a long form
        (0x00 followed by a 32-bit big-endian length), then the payload.

        Raises ValueError if ``len(obs)`` differs from the number of parts.
        """
        if len(obs) != len(self.parts):
            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
        buf = bytearray()
        for ob, part in zip(obs, self.parts):
            if ob is self.small:
                buf.append(0x01)
            elif ob is self.large:
                buf.append(0x02)
            else:
                dat = part.encode(ob)
                if len(dat) < 128:
                    buf.append(0x80 | len(dat))
                else:
                    buf.extend(struct.pack(">BI", 0, len(dat)))
                buf.extend(dat)
        return bytes(buf)

    def decode(self, dat):
        """Decode bytes produced by encode() back into a tuple."""
        ret = []
        off = 0
        for part in self.parts:
            fl = dat[off]
            off += 1
            if fl & 0x80:
                ln = fl & 0x7f
            elif fl == 0x01:
                ret.append(self.small)
                continue
            elif fl == 0x02:
                ret.append(self.large)
                continue
            else:
                ln = struct.unpack(">I", dat[off:off + 4])[0]
                off += 4
            ret.append(part.decode(dat[off:off + ln]))
            off += ln
        return tuple(ret)

    def compare(self, al, bl):
        """Lexicographic three-way compare of two decoded tuples.

        Sentinels are matched by identity. If the same sentinel appears at
        the same position in both tuples, the tuples compare equal.

        Raises ValueError if either tuple has the wrong length.
        """
        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
        small, large = self.small, self.large
        for a, b, part in zip(al, bl, self.parts):
            # Identity tests: ``in`` would fall back to ``==`` on user values.
            if a is small or a is large or b is small or b is large:
                if a is b:
                    return 0
                if a is small:
                    return -1
                elif b is small:
                    return 1
                elif a is large:
                    return 1
                elif b is large:
                    return -1
            c = part.compare(a, b)
            if c != 0:
                return c
        return 0
| 128 | |
def floatcmp(a, b):
    """Three-way float comparison that gives NaN a definite position.

    NaN sorts below every number and two NaNs compare equal. Otherwise the
    result is -1, 0 or 1 for a < b, a == b, a > b.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        # Both NaN -> 0; only a NaN -> -1; only b NaN -> 1.
        return int(b_nan) - int(a_nan)
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
| 142 | |
# Fixed-width big-endian numeric types.
t_int = simpletype.struct(">q")    # signed 64-bit integer
t_uint = simpletype.struct(">Q")   # unsigned 64-bit integer
t_dbid = t_uint                    # alias of t_uint
t_float = simpletype.struct(">d")  # IEEE 754 double
# Override the default ordering so NaN values have a consistent position.
t_float.compare = floatcmp

# UTF-8 text; t_casestr additionally orders case-insensitively via .lower().
t_str = simpletype(lambda text: text.encode("utf-8"),
                   lambda raw: raw.decode("utf-8"))
t_casestr = foldtype(lambda text: text.encode("utf-8"),
                     lambda raw: raw.decode("utf-8"),
                     lambda text: text.lower())
| 151 | |
class index(object):
    """Base for secondary indexes: owning database, index name and key type."""

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype

# Sentinel for "argument not supplied" / "no bound". It is distinct from None,
# which can be a real key value (see maybe).
missing = object()
| 159 | |
class ordered(index, lib.closable):
    """Sorted secondary index mapping keys of ``datatype`` to object ids.

    Backed by a Berkeley DB B-tree named ``"i-" + name`` with DB_DUPSORT set,
    so one key can map to several object ids. Keys are ordered by
    ``datatype.compare`` applied to decoded keys. Data items are object ids
    packed as 8-byte big-endian unsigned integers.
    """

    def __init__(self, db, name, datatype, create=True):
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            # Configures the B-tree handle passed in by db._opendb
            # (presumably before it is opened). Shadows the outer ``db``.
            def compare(a, b):
                # Raw keys are bytes under Python 3, so the old ``== ""`` test
                # never matched. Treat two empty keys as equal rather than
                # decoding them, since most types cannot decode b"".
                # NOTE(review): bsddb3 is believed to make such a test call
                # when the comparator is installed; confirm.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Make lookups raise (caught below as ``notfound``) instead of
        # returning None.
        self.bk.set_get_returns_none(False)

    def close(self):
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over a key range of an ordered index.

        Yields ``(key, object id)`` tuples. ``fd``/``fi`` are the lower bound
        and whether it is inclusive; ``ld``/``li`` are the upper bound likewise.
        ``missing`` for a bound means unbounded. With ``reverse`` the range is
        walked from the upper end downwards.

        The positioning methods are wrapped in ``dloopfun`` (from .db).
        NOTE(review): presumably it retries the call on deadlock; confirm.
        """

        def __init__(self, idx, fd, fi, ld, li, reverse):
            self.idx = idx
            self.typ = idx.typ
            self.cur = self.idx.bk.cursor()
            # Prefetched entry. Either a (key, id) tuple, None when the next
            # entry must still be fetched, or StopIteration once exhausted.
            self.item = None
            self.fd = fd
            self.fi = fi
            self.ld = ld
            self.li = li
            self.rev = reverse

        def close(self):
            if self.cur is not None:
                self.cur.close()
                self.cur = None

        def __iter__(self):
            return self

        def _decode(self, d):
            # d is a raw (key, data) pair; data is a ">Q"-packed object id.
            k, v = d
            k = self.typ.decode(k)
            v = struct.unpack(">Q", v)[0]
            return k, v

        def _pastend(self, k):
            """True if decoded key k lies beyond the upper bound (ld/li)."""
            if self.ld is missing:
                return False
            c = self.typ.compare(k, self.ld)
            return c > 0 if self.li else c >= 0

        def _pastbeg(self, k):
            """True if decoded key k lies before the lower bound (fd/fi)."""
            if self.fd is missing:
                return False
            c = self.typ.compare(k, self.fd)
            return c < 0 if self.fi else c <= 0

        @dloopfun
        def first(self):
            """Position on the lowest in-range entry and prefetch it."""
            try:
                if self.fd is missing:
                    k, v = self._decode(self.cur.first())
                else:
                    # set_range lands on the smallest key >= fd.
                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
                    if not self.fi:
                        # Exclusive bound: skip every entry (incl. duplicates) equal to fd.
                        while self.typ.compare(k, self.fd) == 0:
                            k, v = self._decode(self.cur.next())
                # The first candidate may already exceed the upper bound,
                # e.g. match=x where x itself is absent.
                self.item = StopIteration if self._pastend(k) else (k, v)
            except notfound:
                self.item = StopIteration

        @dloopfun
        def last(self):
            """Position on the highest in-range entry and prefetch it."""
            try:
                if self.ld is missing:
                    k, v = self._decode(self.cur.last())
                else:
                    try:
                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
                        if self.li:
                            # Inclusive bound: move past all entries equal to ld.
                            while self.typ.compare(k, self.ld) == 0:
                                k, v = self._decode(self.cur.next())
                    except notfound:
                        # Ran off the end (no key >= ld, or ld's entries are
                        # the last ones); the last entry is the candidate.
                        k, v = self._decode(self.cur.last())
                    while self._pastend(k):
                        k, v = self._decode(self.cur.prev())
                self.item = StopIteration if self._pastbeg(k) else (k, v)
            except notfound:
                self.item = StopIteration

        @dloopfun
        def next(self):
            """Prefetch the following entry, or StopIteration past the upper bound."""
            try:
                k, v = self._decode(self.cur.next())
                self.item = StopIteration if self._pastend(k) else (k, v)
            except notfound:
                self.item = StopIteration

        @dloopfun
        def prev(self):
            """Prefetch the preceding entry, or StopIteration before the lower bound."""
            try:
                k, v = self._decode(self.cur.prev())
                self.item = StopIteration if self._pastbeg(k) else (k, v)
            except notfound:
                self.item = StopIteration

        def __next__(self):
            if self.item is None:
                if not self.rev:
                    self.next()
                else:
                    self.prev()
            if self.item is StopIteration:
                raise StopIteration()
            ret, self.item = self.item, None
            return ret

        def skip(self, n=1):
            """Advance by up to n entries, stopping silently at the end."""
            try:
                for i in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
        """Return a positioned cursor over the selected index entries.

        Selection, checked in this order:
          all=True     -- every entry;
          match=k      -- entries whose key compares equal to k;
          ge/gt/le/lt  -- a key range; ge takes precedence over gt and le over
                          lt, and an omitted side is unbounded.
        With reverse=True the cursor iterates from the high end downwards.
        The caller is responsible for closing the returned cursor.

        Raises NameError when no selection is given (kept for compatibility).
        """
        if all:
            cur = self.cursor(self, missing, True, missing, True, reverse)
        elif match is not missing:
            cur = self.cursor(self, match, True, match, True, reverse)
        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
            if ge is not missing:
                fd, fi = ge, True
            elif gt is not missing:
                fd, fi = gt, False
            else:
                fd, fi = missing, True
            if le is not missing:
                ld, li = le, True
            elif lt is not missing:
                ld, li = lt, False
            else:
                ld, li = missing, True
            cur = self.cursor(self, fd, fi, ld, li, reverse)
        else:
            raise NameError("invalid get() specification")
        done = False
        try:
            if not reverse:
                cur.first()
            else:
                cur.last()
            done = True
            return cur
        finally:
            # Don't leak the DB cursor if positioning raised.
            if not done:
                cur.close()

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the (key, object id) pair inside transaction ``tx``.

        Returns True if added, False if the pair was already present.
        Raises ValueError if ``id`` is not in the object database.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            # DB_NODUPDATA rejects an identical key/data pair.
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the (key, object id) pair inside transaction ``tx``.

        Returns True if removed, False if the pair was not present.
        Raises ValueError if ``id`` is not in the object database.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True