bin: Added coding and decoding of floats.
[coe.git] / coe / bin.py
1 import io, math
2 from . import data
3
4 T_END = 0
5 T_INT = 1
6 T_STR = 2
7 T_BIT = 3
8 T_NIL = 4
9 T_SYM = 5
10 T_CON = 6
11
12 INT_REF = 1
13
14 STR_SYM = 1
15
16 BIT_BFLOAT = 1
17 BIT_DFLOAT = 2
18
19 CON_SEQ = 0
20 CON_SET = 1
21 CON_MAP = 2
22 CON_OBJ = 3
23
24 NIL_FALSE = 1
25 NIL_TRUE = 2
26
27 class encoder(object):
28     def __init__(self, *, backrefs=True):
29         self.backrefs = backrefs
30         self.reftab = {}
31         self.nextref = 0
32         self.nstab = {}
33
34     @staticmethod
35     def enctag(pri, sec):
36         return bytes([(sec << 3) | pri])
37
38     def writetag(self, dst, pri, sec, datum):
39         dst.write(self.enctag(pri, sec))
40         if self.backrefs:
41             ref = self.nextref
42             self.nextref += 1
43             if datum is not None and id(datum) not in self.reftab:
44                 self.reftab[id(datum)] = ref
45             return ref
46         return None
47
48     @staticmethod
49     def encint(x):
50         ret = bytearray()
51         if x >= 0:
52             b = x & 0x7f
53             x >>= 7
54             while (x > 0) or (b & 0x40) != 0:
55                 ret.append(0x80 | b)
56                 b = x & 0x7f
57                 x >>= 7
58             ret.append(b)
59         elif x < 0:
60             b = x & 0x7f
61             x >>= 7
62             while x < -1 or (b & 0x40) == 0:
63                 ret.append(0x80 | b)
64                 b = x & 0x7f
65                 x >>= 7
66             ret.append(b)
67         return ret
68
69     @staticmethod
70     def writestr(dst, text):
71         dst.write(text.encode("utf-8"))
72         dst.write(b'\0')
73
74     @staticmethod
75     def writefloat(dst, x):
76         if x == 0.0:
77             mnt, exp = (0, 0)
78         elif math.isinf(x):
79             if x > 0:
80                 mnt, exp = (0, 2)
81             else:
82                 mnt, exp = (0, 3)
83         elif math.isnan(x):
84             mnt, exp = (0, 4)
85         else:
86             mnt, exp = math.frexp(x)
87             mnt *= 2
88             exp -= 1
89             while mnt != int(mnt):
90                 mnt *= 2
91             mnt = int(mnt)
92         buf = bytearray()
93         buf.extend(encoder.encint(mnt))
94         buf.extend(encoder.encint(exp))
95         dst.write(encoder.encint(len(buf)))
96         dst.write(buf)
97
98     def dumpseq(self, dst, seq):
99         for v in seq:
100             self.dump(dst, v)
101         dst.write(self.enctag(T_END, 0))
102
103     def dumpmap(self, dst, val):
104         for k, v in val.items():
105             self.dump(dst, k)
106             self.dump(dst, v)
107         dst.write(self.enctag(T_END, 0))
108
109     def dump(self, dst, datum):
110         ref = self.reftab.get(id(datum))
111         if ref is not None:
112             dst.write(self.enctag(T_INT, INT_REF))
113             dst.write(self.encint(ref))
114             return
115         if datum == None:
116             self.writetag(dst, T_NIL, 0, None)
117         elif datum is False:
118             self.writetag(dst, T_NIL, NIL_FALSE, None)
119         elif datum is True:
120             self.writetag(dst, T_NIL, NIL_TRUE, None)
121         elif isinstance(datum, int):
122             self.writetag(dst, T_INT, 0, None)
123             dst.write(self.encint(datum))
124         elif isinstance(datum, str):
125             self.writetag(dst, T_STR, 0, datum)
126             self.writestr(dst, datum)
127         elif isinstance(datum, (bytes, bytearray)):
128             self.writetag(dst, T_BIT, 0, datum)
129             dst.write(self.encint(len(datum)))
130             dst.write(datum)
131         elif isinstance(datum, float):
132             self.writetag(dst, T_BIT, BIT_BFLOAT, datum)
133             self.writefloat(dst, datum)
134         elif isinstance(datum, data.symbol):
135             if datum.ns == "":
136                 self.writetag(dst, T_STR, STR_SYM, datum)
137                 self.writestr(dst, datum.name)
138             else:
139                 nsref = self.nstab.get(datum.ns)
140                 if nsref is None:
141                     nsref = self.writetag(dst, T_SYM, 0, datum)
142                     dst.write(b'\0')
143                     self.writestr(dst, datum.ns)
144                     self.writestr(dst, datum.name)
145                     if nsref is not None:
146                         self.nstab[datum.ns] = nsref
147                 else:
148                     self.writetag(dst, T_SYM, 0, datum)
149                     dst.write(b'\x01')
150                     dst.write(self.encint(nsref))
151                     self.writestr(dst, datum.name)
152         elif isinstance(datum, list):
153             self.writetag(dst, T_CON, CON_SEQ, datum)
154             self.dumpseq(dst, datum)
155         elif isinstance(datum, set):
156             self.writetag(dst, T_CON, CON_SET, datum)
157             self.dumpseq(dst, datum)
158         elif isinstance(datum, dict):
159             self.writetag(dst, T_CON, CON_MAP, datum)
160             self.dumpmap(dst, datum)
161         elif isinstance(datum, data.obj):
162             self.writetag(dst, T_CON, CON_OBJ, datum)
163             self.dump(dst, getattr(type(datum), "typename", None))
164             self.dumpmap(dst, datum.__dict__)
165         else:
166             raise ValueError("unsupported object type: " + repr(datum))
167
168 def dump(dst, datum):
169     encoder().dump(dst, datum)
170     return dst
171
172 class fmterror(Exception):
173     pass
174
175 class eoferror(fmterror):
176     def __init__(self):
177         super().__init__("unexpected end-of-data")
178
179 class referror(fmterror):
180     def __init__(self):
181         super().__init__("bad backref")
182
183 class decoder(object):
184     def __init__(self):
185         self.reftab = []
186         self.namedtypes = {}
187
188     @staticmethod
189     def byte(fp):
190         b = fp.read(1)
191         if b == b"":
192             raise eoferror()
193         return b[0]
194
195     @staticmethod
196     def loadint(fp):
197         ret = 0
198         p = 0
199         while True:
200             b = decoder.byte(fp)
201             ret += (b & 0x7f) << p
202             p += 7
203             if (b & 0x80) == 0:
204                 break
205         if (b & 0x40) != 0:
206             ret = ret - (1 << p)
207         return ret
208
209     @staticmethod
210     def loadstr(fp):
211         buf = bytearray()
212         while True:
213             b = decoder.byte(fp)
214             if b == 0:
215                 break
216             buf.append(b)
217         return buf.decode("utf-8")
218
219     def loadsym(self, fp):
220         h = self.byte(fp)
221         if h & 0x1:
222             nsref = self.loadint(fp)
223             if not 0 <= nsref < len(self.reftab):
224                 raise fmterror("illegal namespace ref: " + str(nsref))
225             nssym = self.reftab[nsref]
226             if not isinstance(nssym, data.symbol):
227                 raise fmterror("illegal namespace ref: " + str(nsref))
228             ns = nssym.ns
229         else:
230             ns = self.loadstr(fp)
231         nm = self.loadstr(fp)
232         ret = data.symbol.get(ns, nm)
233         return ret
234
235     def loadlist(self, fp, buf):
236         while True:
237             tag = self.byte(fp)
238             if tag == T_END:
239                 return buf
240             buf.append(self.loadtagged(fp, tag))
241
242     def loadmap(self, fp, buf):
243         while True:
244             tag = self.byte(fp)
245             if tag == T_END:
246                 return buf
247             key = self.loadtagged(fp, tag)
248             tag = self.byte(fp)
249             if tag == T_END:
250                 return buf
251             buf[key] = self.loadtagged(fp, tag)
252
253     def makeobjtype(self, nm):
254         return data.namedtype.make(str(nm), (data.obj, object), {}, typename=nm)
255
256     def loadobj(self, fp, ref=False):
257         if ref:
258             refid = len(self.reftab)
259             self.reftab.append(None)
260         nm = self.load(fp)
261         typ = self.namedtypes.get(nm)
262         if typ is None:
263             typ = self.namedtypes[nm] = self.makeobjtype(nm)
264         ret = typ()
265         if ref:
266             self.reftab[refid] = ret
267         # st = fp.tell()
268         # print(">", nm, hex(st))
269         ret.__dict__.update(self.loadmap(fp, {}))
270         # print("<", nm, hex(fp.tell()), hex(st))
271         return ret
272
273     def addref(self, obj):
274         self.reftab.append(obj)
275         return obj
276
277     def loadtagged(self, fp, tag):
278         pri, sec = (tag & 0x7), (tag & 0xf8) >> 3
279         if pri == T_END:
280             raise fmterror("unexpected end-tag")
281         elif pri == T_INT:
282             if sec == INT_REF:
283                 idx = self.loadint(fp)
284                 if not 0 <= idx < len(self.reftab):
285                     raise referror()
286                 # print(idx, self.reftab[idx], hex(fp.tell()))
287                 return self.reftab[idx]
288             return self.addref(self.loadint(fp))
289         elif pri == T_STR:
290             ret = self.loadstr(fp)
291             if sec == STR_SYM:
292                 return self.addref(data.symbol.get("", ret))
293             return self.addref(ret)
294         elif pri == T_BIT:
295             ln = self.loadint(fp)
296             ret = fp.read(ln)
297             if len(ret) < ln:
298                 raise eoferror()
299             if sec == BIT_BFLOAT:
300                 buf = io.BytesIO(ret)
301                 mnt = self.loadint(buf)
302                 exp = self.loadint(buf)
303                 if mnt == 0:
304                     if exp == 0:
305                         return 0.0
306                     elif exp == 1:
307                         return -0.0
308                     elif exp == 2:
309                         return float("inf")
310                     elif exp == 3:
311                         return -float("inf")
312                     else:
313                         return float("nan")
314                 else:
315                     ret = math.ldexp(mnt, exp - (mnt.bit_length() - 1))
316                 return self.addref(ret)
317             return self.addref(ret)
318         elif pri == T_NIL:
319             if sec == NIL_TRUE:
320                 return self.addref(True)
321             elif sec == NIL_FALSE:
322                 return self.addref(False)
323             return self.addref(None)
324         elif pri == T_SYM:
325             return self.addref(self.loadsym(fp))
326         elif pri == T_CON:
327             if sec == CON_MAP:
328                 return self.loadmap(fp, self.addref({}))
329             elif sec == CON_OBJ:
330                 return self.loadobj(fp, ref=True)
331             else:
332                 return self.loadlist(fp, self.addref([]))
333         else:
334             raise fmterror("unknown primary: " + str(pri))
335
336     def load(self, fp):
337         tag = self.byte(fp)
338         return self.loadtagged(fp, tag)
339
340 def load(fp):
341     return decoder().load(fp)