| 1 | import struct, contextlib, math |
| 2 | from . import db, lib |
| 3 | from .db import bd, txnfun |
| 4 | |
# Berkeley DB exceptions used below: ``deadlock`` operations are retried,
# ``notfound`` means the requested key / position does not exist.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
| 7 | |
class simpletype(object):
    """Key type built from a pair of plain encode/decode callables.

    ``encode`` maps a Python value to its stored representation and
    ``decode`` maps it back.  Decoded values are ordered with the native
    ``<`` / ``>`` operators.
    """

    def __init__(self, encode, decode):
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the stored representation of *ob*."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Return the Python value stored as *dat*."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way compare of decoded values: -1, 0 or 1."""
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    @classmethod
    def struct(cls, fmt):
        """Build a key type that (un)packs one value with struct format *fmt*.

        Decoding returns the first field unpacked from the data.
        """
        def pack_one(ob):
            return struct.pack(fmt, ob)

        def unpack_first(dat):
            fields = struct.unpack(fmt, dat)
            return fields[0]

        return cls(pack_one, unpack_first)
| 29 | |
class maybe(object):
    """Nullable wrapper around another key type *bk*.

    ``None`` encodes as the empty byte string; any other value encodes as a
    single NUL marker byte followed by ``bk``'s encoding.  Because the empty
    encoding differs from every marked one, wrappers can be nested.
    ``compare`` orders ``None`` before every other value.
    """

    def __init__(self, bk):
        self.bk = bk

    def encode(self, ob):
        """Return b"" for None, else the marker byte plus ``bk``'s encoding."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Inverse of ``encode``: b"" decodes to None."""
        if dat == b"":
            return None
        # Use the public decode() rather than simpletype's ``dec`` attribute
        # so that any key type -- including a nested ``maybe`` -- can be wrapped.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare of *decoded* values; None sorts first.

        Callers decode keys before comparing, so non-None values are passed
        straight to the wrapped type (there is no marker byte to strip).
        """
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        return self.bk.compare(a, b)
| 49 | |
def floatcmp(a, b):
    """Three-way compare of floats giving NaN a fixed place in the order.

    NaN sorts before every number and two NaNs compare equal; all other
    values use the ordinary numeric ordering.  Returns -1, 0 or 1.
    """
    a_is_nan = math.isnan(a)
    b_is_nan = math.isnan(b)
    if a_is_nan or b_is_nan:
        # bool arithmetic: both NaN -> 0, only a -> -1, only b -> 1.
        return b_is_nan - a_is_nan
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
| 63 | |
# Ready-made key types: big-endian signed / unsigned 64-bit integers,
# 64-bit doubles (ordered with floatcmp, so NaN sorts first), and UTF-8 text.
t_int, t_uint, t_float = map(simpletype.struct, (">q", ">Q", ">d"))
t_float.compare = floatcmp
t_str = simpletype(
    encode=lambda ob: ob.encode("utf-8"),
    decode=lambda dat: dat.decode("utf-8"),
)
| 69 | |
class index(object):
    """Base state shared by index implementations.

    Holds the database object, the index name and the key datatype (an
    object providing encode/decode/compare).
    """

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype
| 75 | |
# Sentinel default for ordered.get()'s keyword arguments, so that None can
# still be passed as a real lookup value (e.g. with ``maybe`` key types).
missing = object()
| 77 | |
class ordered(index, lib.closable):
    """Ordered secondary index mapping keys to object ids.

    Backed by a Berkeley DB B-tree named ``"i-" + name`` with sorted
    duplicates (``DB_DUPSORT``), so one key may map to several object ids.
    Keys are ordered by ``datatype.compare`` applied to decoded keys.
    Object ids are stored as 8-byte big-endian unsigned integers (``">Q"``).
    """

    def __init__(self, db, name, datatype, create=True):
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(bdb):
            # Configures the raw DB handle; presumably invoked by _opendb
            # before the handle is opened -- verify against db._opendb.
            def compare(a, b):
                # Two empty keys are equal without decoding.  Testing
                # emptiness (rather than ``== ""``) also matches the bytes
                # keys handed over under Python 3; decoding an empty key
                # with a struct-based type would otherwise fail.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))

            bdb.set_flags(bd.DB_DUPSORT)
            bdb.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Lookups raise ``notfound`` instead of returning None; the code
        # below relies on that.
        self.bk.set_get_returns_none(False)

    def close(self):
        self.bk.close()

    class cursor(lib.closable):
        """Iterator of decoded ``(key, object id)`` pairs over a DB cursor.

        ``item`` is the current raw ``(key, value)`` pair, or None once the
        iterator is exhausted.  ``stop`` is a predicate on the decoded key;
        iteration ends as soon as it returns true.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx = idx
            self.cur = cur
            self.item = item
            self.stop = stop

        def close(self):
            if self.cur is not None:
                self.cur.close()

        def __iter__(self):
            return self

        def peek(self):
            """Return the current ``(key, id)`` without advancing.

            Raises StopIteration when exhausted, or when ``stop`` accepts the
            current key (which also marks the iterator exhausted).
            """
            if self.item is None:
                raise StopIteration()
            rk, rv = self.item
            rk = self.idx.typ.decode(rk)
            rv = struct.unpack(">Q", rv)[0]
            if self.stop(rk):
                self.item = None
                raise StopIteration()
            return rk, rv

        def __next__(self):
            """Return the current pair and advance the DB cursor.

            Deadlocks while advancing are retried; reaching the end of the
            DB cursor leaves the iterator exhausted.
            """
            rk, rv = self.peek()
            try:
                while True:
                    try:
                        self.item = self.cur.next()
                        break
                    except deadlock:
                        continue
            except notfound:
                self.item = None
            return rk, rv

        def skip(self, n=1):
            """Advance past up to *n* entries, stopping quietly at the end."""
            try:
                for _ in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a ``cursor`` over the index entries selected by the keywords.

        One selection mode is used, checked in this order:

        * ``match`` -- entries whose key compares equal to *match*;
        * ``all`` -- every entry, starting from the first key;
        * ``ge``/``gt`` (lower bound) and/or ``lt``/``le`` (upper bound) --
          a key range.  ``ge`` wins over ``gt`` and ``lt`` over ``le`` when
          both are given.

        An empty cursor is returned when no starting position exists.  The
        returned cursor owns the DB cursor and releases it in its close().
        A deadlock while positioning restarts the lookup from scratch.

        Raises:
            NameError: if no selection keyword was given.
        """
        while True:
            try:
                cur = self.bk.cursor()
                # Set once ownership of ``cur`` passes to the returned cursor;
                # otherwise ``cur`` is closed on the way out.
                done = False
                try:
                    if match is not missing:
                        try:
                            k, v = cur.set(self.typ.encode(match))
                        except notfound:
                            return self.cursor(None, None, None, None)
                        done = True
                        return self.cursor(self, cur, (k, v),
                                           lambda o: self.typ.compare(o, match) != 0)
                    elif all:
                        try:
                            k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        done = True
                        return self.cursor(self, cur, (k, v), lambda o: False)
                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
                        skip = False
                        try:
                            if ge is not missing:
                                k, v = cur.set_range(self.typ.encode(ge))
                            elif gt is not missing:
                                # Position at the first key >= gt; keys equal
                                # to gt are skipped below.
                                k, v = cur.set_range(self.typ.encode(gt))
                                skip = True
                            else:
                                k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        if lt is not missing:
                            stop = lambda o: self.typ.compare(o, lt) >= 0
                        elif le is not missing:
                            stop = lambda o: self.typ.compare(o, le) > 0
                        else:
                            stop = lambda o: False
                        ret = self.cursor(self, cur, (k, v), stop)
                        if skip:
                            try:
                                while self.typ.compare(ret.peek()[0], gt) == 0:
                                    next(ret)
                            except StopIteration:
                                pass
                        done = True
                        return ret
                    else:
                        raise NameError("invalid get() specification")
                finally:
                    if not done:
                        cur.close()
            except deadlock:
                continue

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add an index entry mapping *key* to object *id*.

        ``tx`` is presumably supplied by the ``txnfun`` decorator; its ``.tx``
        is used as the DB transaction.

        Returns True if the entry was added, False if that exact
        (key, id) pair was already present.

        Raises:
            ValueError: if no object with *id* exists in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the index entry mapping *key* to object *id*.

        Returns True if the entry was removed, False if it was not present.

        Raises:
            ValueError: if no object with *id* exists in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True