Added index type for case-folded strings.
[didex.git] / didex / index.py
... / ...
CommitLineData
1import struct, contextlib, math
2from . import db, lib
3from .db import bd, txnfun, dloopfun
4
# Names exported by a star-import of this module.
__all__ = [
    "maybe",
    "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr",
    "ordered",
]

# Short aliases for exception classes of the underlying database binding.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
9
class simpletype(object):
    """Index datatype defined by a pair of encode/decode callables.

    Values of this type are ordered by their natural ``<`` / ``>``
    comparison.
    """

    def __init__(self, encode, decode):
        # The callables are kept under the short names `enc` and `dec`.
        self.enc = encode
        self.dec = decode

    def encode(self, ob):
        """Return `ob` converted by the encoder callable."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Return `dat` converted back by the decoder callable."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way comparison of two values: -1, 0 or 1."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a type whose encoding is the single-value struct format `fmt`."""
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            # struct.unpack always returns a tuple; the format holds one value.
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
31
class foldtype(simpletype):
    """A `simpletype` whose values are compared after being folded.

    `fold` is applied to both operands before they are compared; encoding
    and decoding are inherited unchanged.
    """

    def __init__(self, encode, decode, fold):
        super().__init__(encode, decode)
        self.fold = fold

    def compare(self, a, b):
        """Three-way comparison of the folded forms of `a` and `b`."""
        folded_a = self.fold(a)
        folded_b = self.fold(b)
        return super().compare(folded_a, folded_b)
39
class maybe(object):
    """Wrap another index datatype so that ``None`` is also a valid value.

    ``None`` encodes to the empty byte string; any other value encodes to
    a NUL byte followed by the backing type's encoding.  ``None`` orders
    before every other value.
    """

    def __init__(self, bk):
        # bk: the backing datatype, providing encode(), decode() and compare().
        self.bk = bk

    def encode(self, ob):
        """Return the encoded form of `ob`, which may be ``None``."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Return the value encoded in `dat`, or ``None`` for empty data."""
        if dat == b"":
            return None
        # Go through the public decode() so that any backing type works,
        # not only those that happen to carry a `dec` attribute.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way comparison of two decoded values, ``None`` sorting first."""
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        # Both operands are decoded values (not encoded data), so they are
        # handed to the backing type unchanged.
        return self.bk.compare(a, b)
59
class compound(object):
    """Index datatype for fixed-length tuples of other datatypes.

    Each element is encoded by the corresponding part type and prefixed
    with a flag byte:

    * ``0x01`` / ``0x02`` stand for the `small` / `large` sentinels and
      carry no data;
    * ``0x80 | n`` precedes `n` (< 128) bytes of data;
    * ``0x00`` followed by a 4-byte big-endian length precedes longer data.
    """

    # Sentinels comparing below / above every real value of a part; used
    # to build range bounds from a prefix of the parts.
    small = object()
    large = object()

    def __init__(self, *parts):
        self.parts = parts

    def minim(self, *parts):
        """Return `parts` padded to full length with `small`."""
        return parts + (self.small,) * (len(self.parts) - len(parts))

    def maxim(self, *parts):
        """Return `parts` padded to full length with `large`."""
        return parts + (self.large,) * (len(self.parts) - len(parts))

    def encode(self, obs):
        """Return the encoded form of the tuple `obs`.

        Raises ValueError if `obs` does not have one element per part.
        """
        if len(obs) != len(self.parts):
            raise ValueError("invalid length of compound data: " + str(len(obs)) +
                             ", rather than " + str(len(self.parts)))
        buf = bytearray()
        for ob, part in zip(obs, self.parts):
            if ob is self.small:
                buf.append(0x01)
            elif ob is self.large:
                buf.append(0x02)
            else:
                dat = part.encode(ob)
                if len(dat) < 128:
                    # Short form: the length fits in the flag byte itself.
                    buf.append(0x80 | len(dat))
                else:
                    # Long form: zero flag byte plus explicit 32-bit length.
                    buf.extend(struct.pack(">BI", 0, len(dat)))
                buf.extend(dat)
        return bytes(buf)

    def decode(self, dat):
        """Return the tuple of values encoded in `dat`."""
        ret = []
        off = 0
        for part in self.parts:
            fl = dat[off]
            off += 1
            if fl & 0x80:
                ln = fl & 0x7f
            elif fl == 0x01:
                ret.append(self.small)
                continue
            elif fl == 0x02:
                ret.append(self.large)
                continue
            else:
                ln = struct.unpack(">I", dat[off:off + 4])[0]
                off += 4
            ret.append(part.decode(dat[off:off + ln]))
            off += ln
        return tuple(ret)

    def compare(self, al, bl):
        """Three-way, element-wise comparison of two decoded tuples.

        Raises ValueError if either tuple does not have one element per part.
        """
        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
            raise ValueError("invalid length of compound data: " + str(len(al)) +
                             ", " + str(len(bl)) +
                             ", rather than " + str(len(self.parts)))
        for a, b, part in zip(al, bl, self.parts):
            # Sentinels are recognized by identity only, so that values
            # with unusual equality semantics are never mistaken for them.
            a_sentinel = a is self.small or a is self.large
            b_sentinel = b is self.small or b is self.large
            if a_sentinel or b_sentinel:
                if a is b:
                    return 0
                if a is self.small:
                    return -1
                elif b is self.small:
                    return 1
                elif a is self.large:
                    return 1
                elif b is self.large:
                    return -1
            c = part.compare(a, b)
            if c != 0:
                return c
        return 0
128
def floatcmp(a, b):
    """Three-way comparison of floats in which NaN sorts before everything.

    Two NaNs compare equal; otherwise the usual numeric order applies.
    Returns -1, 0 or 1.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        if a_nan and b_nan:
            return 0
        return -1 if a_nan else 1
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
142
# Predefined datatypes.  Numbers are stored as fixed-width big-endian
# struct values, strings as UTF-8.
t_int = simpletype.struct(">q")
t_uint = simpletype.struct(">Q")
t_dbid = t_uint
t_float = simpletype.struct(">d")
# Floats need a NaN-aware ordering; override compare on this instance only.
t_float.compare = floatcmp
t_str = simpletype(lambda ob: ob.encode("utf-8"),
                   lambda dat: dat.decode("utf-8"))
# Same encoding as t_str, but ordered by the lower-cased string.
t_casestr = foldtype(lambda ob: ob.encode("utf-8"),
                     lambda dat: dat.decode("utf-8"),
                     lambda st: st.lower())
151
class index(object):
    """Common base of indices: records the owning database, the index
    name and the datatype of its keys."""

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype
157
# Sentinel for "no value given", so that None remains usable as a real key.
missing = object()
159
class ordered(index, lib.closable):
    """Ordered index mapping keys of a datatype to object IDs.

    The index is backed by a B-tree database named ``"i-" + name`` that
    keeps sorted duplicates and orders its keys with the datatype's
    compare() applied to the decoded keys.
    """

    def __init__(self, db, name, datatype, create=True):
        """Open the index `name` in `db`, creating it if `create` is true."""
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            def compare(a, b):
                # NOTE(review): presumably the database binding probes the
                # comparator with two empty *str* objects when it is
                # installed; real keys arrive as bytes -- confirm.
                if a == b == "":
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Make lookups raise `notfound` instead of returning None.
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the backing database."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over the ``(key, id)`` pairs of an index within a key range.

        `fd` / `ld` are the first (lower) and last (upper) bounds, or
        `missing` for an open end; `fi` / `li` tell whether the respective
        bound is inclusive.  `reverse` selects descending iteration.
        """

        def __init__(self, idx, fd, fi, ld, li, reverse):
            self.idx = idx
            self.typ = idx.typ
            self.cur = self.idx.bk.cursor()
            # Pending item: None when the next one must be fetched, a
            # (key, id) pair when prefetched, StopIteration when exhausted.
            self.item = None
            self.fd = fd
            self.fi = fi
            self.ld = ld
            self.li = li
            self.rev = reverse

        def close(self):
            """Close the underlying database cursor; safe to call twice."""
            if self.cur is not None:
                self.cur.close()
                self.cur = None

        def __iter__(self):
            return self

        def _decode(self, d):
            """Turn a raw (key, value) record into a (decoded key, id) pair."""
            k, v = d
            k = self.typ.decode(k)
            v = struct.unpack(">Q", v)[0]
            return k, v

        def _below_first(self, k):
            """Tell whether key `k` lies before the lower bound."""
            if self.fd is missing:
                return False
            c = self.typ.compare(k, self.fd)
            return c < 0 if self.fi else c <= 0

        def _above_last(self, k):
            """Tell whether key `k` lies beyond the upper bound."""
            if self.ld is missing:
                return False
            c = self.typ.compare(k, self.ld)
            return c > 0 if self.li else c >= 0

        @dloopfun
        def first(self):
            """Position on the first item of the range and prefetch it."""
            try:
                if self.fd is missing:
                    k, v = self._decode(self.cur.first())
                else:
                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
                    if not self.fi:
                        # Exclusive bound: skip every entry equal to it.
                        while self.typ.compare(k, self.fd) == 0:
                            k, v = self._decode(self.cur.next())
                # The first candidate may already lie past the upper bound
                # (e.g. an exact match on a key that is not present).
                if self._above_last(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        @dloopfun
        def last(self):
            """Position on the last item of the range and prefetch it."""
            try:
                if self.ld is missing:
                    k, v = self._decode(self.cur.last())
                else:
                    try:
                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
                    except notfound:
                        # Nothing at or above the bound: start from the end.
                        k, v = self._decode(self.cur.last())
                    if self.li:
                        try:
                            # Move past all entries equal to the bound ...
                            while self.typ.compare(k, self.ld) == 0:
                                k, v = self._decode(self.cur.next())
                        except notfound:
                            # ... unless they extend to the very end of the
                            # database, in which case the last record is
                            # the one wanted.
                            k, v = self._decode(self.cur.last())
                        # ... then step back onto the last entry within it.
                        while self.typ.compare(k, self.ld) > 0:
                            k, v = self._decode(self.cur.prev())
                    else:
                        while self.typ.compare(k, self.ld) >= 0:
                            k, v = self._decode(self.cur.prev())
                # The last candidate may already lie before the lower bound.
                if self._below_first(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        @dloopfun
        def next(self):
            """Advance one record forward, stopping at the upper bound."""
            try:
                k, v = self.item = self._decode(self.cur.next())
                if self._above_last(k):
                    self.item = StopIteration
            except notfound:
                self.item = StopIteration

        @dloopfun
        def prev(self):
            """Advance one record backward, stopping at the lower bound."""
            try:
                k, v = self.item = self._decode(self.cur.prev())
                if self._below_first(k):
                    self.item = StopIteration
            except notfound:
                self.item = StopIteration

        def __next__(self):
            if self.item is None:
                if not self.rev:
                    self.next()
                else:
                    self.prev()
            if self.item is StopIteration:
                raise StopIteration()
            ret, self.item = self.item, None
            return ret

        def skip(self, n=1):
            """Discard up to `n` items; stops quietly if the range ends first."""
            try:
                for i in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
        """Return a cursor over the selected part of the index.

        Exactly one kind of selection is used, in this order of precedence:
        `all` (every entry), `match` (entries with exactly that key), or a
        range given by `ge`/`gt` and/or `le`/`lt` (`ge` wins over `gt`, and
        `le` over `lt`).  `reverse` iterates in descending order.

        Raises NameError if no selection is given.
        """
        if all:
            cur = self.cursor(self, missing, True, missing, True, reverse)
        elif match is not missing:
            cur = self.cursor(self, match, True, match, True, reverse)
        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
            if ge is not missing:
                fd, fi = ge, True
            elif gt is not missing:
                fd, fi = gt, False
            else:
                fd, fi = missing, True
            if le is not missing:
                ld, li = le, True
            elif lt is not missing:
                ld, li = lt, False
            else:
                ld, li = missing, True
            cur = self.cursor(self, fd, fi, ld, li, reverse)
        else:
            raise NameError("invalid get() specification")
        done = False
        try:
            if not reverse:
                cur.first()
            else:
                cur.last()
            done = True
            return cur
        finally:
            # Do not leak the database cursor if positioning failed.
            if not done:
                cur.close()

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the pair (`key`, `id`) to the index.

        Returns True if the pair was added and False if it was already
        present.  Raises ValueError if `id` is not in the object store.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the pair (`key`, `id`) from the index.

        Returns True if the pair was removed and False if it was not
        present.  Raises ValueError if `id` is not in the object store.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True