data: Replace usymbol/string equivalence with obj __getattr__ implementation.
[coe.git] / coe / bin.py
1 from . import data
2
# Primary tag values.  A tag byte is built as (secondary << 3) | primary,
# so the primary occupies the low three bits of every tag.
T_END = 0   # terminates a container's element list
T_INT = 1   # variable-length integer, or a back-reference (see INT_REF)
T_STR = 2   # NUL-terminated UTF-8 string, or a symbol (see STR_SYM)
T_BIT = 3   # length-prefixed raw bytes
T_NIL = 4   # None, or a boolean (see NIL_FALSE / NIL_TRUE)
T_SYM = 5   # symbol with a non-empty namespace
T_CON = 6   # container (see CON_*)

# Secondary value for T_INT: the integer is an index into the reference
# table rather than a value in its own right.
INT_REF = 1

# Secondary value for T_STR: the string names a symbol in the empty
# namespace.
STR_SYM = 1

# Secondary values for T_BIT.
# NOTE(review): nothing in this file reads or writes these; presumably they
# mark binary/decimal float payloads -- confirm against the format spec.
BIT_BFLOAT = 1
BIT_DFLOAT = 2

# Secondary values for T_CON, selecting the kind of container.
CON_SEQ = 0
CON_SET = 1
CON_MAP = 2
CON_OBJ = 3

# Secondary values for T_NIL; a secondary of 0 stands for None.
NIL_FALSE = 1
NIL_TRUE = 2
25
class encoder(object):
    """Serialize Python values into the tagged binary format.

    Every value is introduced by a one-byte tag, (secondary << 3) | primary.
    When back-references are enabled, every tag written through writetag()
    consumes one reference index, and a value that has already been written
    is emitted again as a T_INT/INT_REF pointing at that index.

    NOTE(review): the reference table is keyed on id(datum), so values must
    stay alive for as long as the encoder is reused; an id may be recycled
    once its object has been garbage collected.
    """

    def __init__(self, *, backrefs=True):
        """Create an encoder.

        backrefs -- when true, track written values and emit back-references
                    for repeated ones.
        """
        self.backrefs = backrefs
        # id(datum) -> reference index of the tag first written for it.
        self.reftab = {}
        # Reference index that the next written tag will receive.
        self.nextref = 0
        # Namespace string -> reference index of a symbol that spelled the
        # namespace out in full.
        self.nstab = {}

    @staticmethod
    def enctag(pri, sec):
        """Return the single tag byte for primary `pri' and secondary `sec'."""
        return bytes([(sec << 3) | pri])

    def writetag(self, dst, pri, sec, datum):
        """Write a tag to `dst' and allocate a reference index for it.

        If `datum' is not None and has not been seen before, it is recorded
        under the allocated index.  Returns the index, or None when
        back-references are disabled.
        """
        dst.write(self.enctag(pri, sec))
        if not self.backrefs:
            return None
        ref = self.nextref
        self.nextref += 1
        if datum is not None and id(datum) not in self.reftab:
            self.reftab[id(datum)] = ref
        return ref

    @staticmethod
    def encint(x):
        """Encode the integer `x' as a signed, little-endian base-128 varint.

        Each byte carries seven payload bits; the high bit marks that more
        bytes follow.  Bit 0x40 of the final byte is the sign bit.  Returns
        a bytearray.
        """
        ret = bytearray()
        while True:
            b = x & 0x7f
            x >>= 7
            # Stop once the remaining value is pure sign extension and the
            # sign bit of the emitted byte agrees with it.
            if (x == 0 and (b & 0x40) == 0) or (x == -1 and (b & 0x40) != 0):
                ret.append(b)
                return ret
            ret.append(0x80 | b)

    @staticmethod
    def writestr(dst, text):
        """Write `text' to `dst' as UTF-8 followed by a NUL terminator."""
        dst.write(text.encode("utf-8"))
        dst.write(b'\0')

    def dumpseq(self, dst, seq):
        """Write every element of `seq', followed by an end tag."""
        for v in seq:
            self.dump(dst, v)
        dst.write(self.enctag(T_END, 0))

    def dumpmap(self, dst, val):
        """Write the items of mapping `val' as alternating keys and values,
        followed by an end tag."""
        for k, v in val.items():
            self.dump(dst, k)
            self.dump(dst, v)
        dst.write(self.enctag(T_END, 0))

    def dump(self, dst, datum):
        """Write the encoding of `datum' to the binary stream `dst'.

        Raises ValueError if `datum' is of a type that cannot be encoded.
        """
        ref = self.reftab.get(id(datum))
        if ref is not None:
            dst.write(self.enctag(T_INT, INT_REF))
            dst.write(self.encint(ref))
            return
        # The singletons must be matched by identity: 0 == False and
        # 1 == True, so equality tests would encode those integers as
        # booleans.
        if datum is None:
            self.writetag(dst, T_NIL, 0, None)
        elif datum is False:
            self.writetag(dst, T_NIL, NIL_FALSE, None)
        elif datum is True:
            self.writetag(dst, T_NIL, NIL_TRUE, None)
        elif isinstance(datum, int):
            self.writetag(dst, T_INT, 0, None)
            dst.write(self.encint(datum))
        elif isinstance(datum, str):
            self.writetag(dst, T_STR, 0, datum)
            self.writestr(dst, datum)
        elif isinstance(datum, (bytes, bytearray)):
            self.writetag(dst, T_BIT, 0, datum)
            dst.write(self.encint(len(datum)))
            dst.write(datum)
        elif isinstance(datum, data.symbol):
            if datum.ns == "":
                self.writetag(dst, T_STR, STR_SYM, datum)
                self.writestr(dst, datum.name)
            else:
                nsref = self.nstab.get(datum.ns)
                if nsref is None:
                    # First symbol in this namespace: spell the namespace
                    # out and remember which reference carries it.
                    nsref = self.writetag(dst, T_SYM, 0, datum)
                    dst.write(b'\0')
                    self.writestr(dst, datum.ns)
                    self.writestr(dst, datum.name)
                    if nsref is not None:
                        self.nstab[datum.ns] = nsref
                else:
                    # Namespace already written: refer to the earlier symbol.
                    self.writetag(dst, T_SYM, 0, datum)
                    dst.write(b'\x01')
                    dst.write(self.encint(nsref))
                    self.writestr(dst, datum.name)
        elif isinstance(datum, list):
            self.writetag(dst, T_CON, CON_SEQ, datum)
            self.dumpseq(dst, datum)
        elif isinstance(datum, set):
            self.writetag(dst, T_CON, CON_SET, datum)
            self.dumpseq(dst, datum)
        elif isinstance(datum, dict):
            self.writetag(dst, T_CON, CON_MAP, datum)
            self.dumpmap(dst, datum)
        elif isinstance(datum, data.obj):
            self.writetag(dst, T_CON, CON_OBJ, datum)
            # The type name (or None) precedes the attribute map.
            self.dump(dst, getattr(type(datum), "typename", None))
            self.dumpmap(dst, datum.__dict__)
        else:
            raise ValueError("unsupported object type: " + repr(datum))
139
def dump(dst, datum):
    """Encode `datum' onto the stream `dst' with a fresh encoder.

    Returns `dst' so that the call can be chained.
    """
    enc = encoder()
    enc.dump(dst, datum)
    return dst
143
class fmterror(Exception):
    """Raised when encoded data does not follow the expected format."""
146
class eoferror(fmterror):
    """Format error raised when the input ends in the middle of a value."""

    def __init__(self):
        message = "unexpected end-of-data"
        super(eoferror, self).__init__(message)
150
class referror(fmterror):
    """Format error raised for a back-reference to an unknown index."""

    def __init__(self):
        message = "bad backref"
        super(referror, self).__init__(message)
154
class decoder(object):
    """Deserialize values from the tagged binary format.

    Every decoded value is appended to a reference table in the order its
    tag was read, so that later T_INT/INT_REF tags can refer back to it by
    index.
    """

    def __init__(self):
        # Decoded values, indexed by the order in which their tags appeared.
        self.reftab = []
        # Type name -> type created by makeobjtype() for objects of that name.
        self.namedtypes = {}

    @staticmethod
    def byte(fp):
        """Read one byte from `fp' and return it as an int.

        Raises eoferror if the stream is exhausted.
        """
        b = fp.read(1)
        if b == b"":
            raise eoferror()
        return b[0]

    @staticmethod
    def loadint(fp):
        """Read a signed, little-endian base-128 varint from `fp'.

        Each byte contributes seven bits; the high bit marks continuation
        and bit 0x40 of the final byte is the sign bit.
        """
        ret = 0
        p = 0
        while True:
            b = decoder.byte(fp)
            ret += (b & 0x7f) << p
            p += 7
            if (b & 0x80) == 0:
                break
        if (b & 0x40) != 0:
            # Sign-extend from the total number of bits read.
            ret = ret - (1 << p)
        return ret

    @staticmethod
    def loadstr(fp):
        """Read a NUL-terminated UTF-8 string from `fp'."""
        buf = bytearray()
        while True:
            b = decoder.byte(fp)
            if b == 0:
                break
            buf.append(b)
        return buf.decode("utf-8")

    def loadsym(self, fp):
        """Read the body of a T_SYM value and return the symbol.

        A header byte with its low bit set means the namespace is taken from
        a previously decoded symbol, given by reference index; otherwise the
        namespace follows as a string.  Raises fmterror if the reference is
        out of range or does not denote a symbol.
        """
        h = self.byte(fp)
        if h & 0x1:
            nsref = self.loadint(fp)
            if not 0 <= nsref < len(self.reftab):
                raise fmterror("illegal namespace ref: " + str(nsref))
            nssym = self.reftab[nsref]
            if not isinstance(nssym, data.symbol):
                raise fmterror("illegal namespace ref: " + str(nsref))
            ns = nssym.ns
        else:
            ns = self.loadstr(fp)
        nm = self.loadstr(fp)
        return data.symbol.get(ns, nm)

    def loadlist(self, fp, buf):
        """Append values read from `fp' to the list `buf' until an end tag;
        return `buf'."""
        while True:
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            buf.append(self.loadtagged(fp, tag))

    def loadset(self, fp, buf):
        """Add values read from `fp' to the set `buf' until an end tag;
        return `buf'.

        Raises fmterror if a member turns out to be unhashable.
        """
        while True:
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            item = self.loadtagged(fp, tag)
            try:
                buf.add(item)
            except TypeError as exc:
                raise fmterror("unhashable set member: " + repr(item)) from exc

    def loadmap(self, fp, buf):
        """Store key/value pairs read from `fp' in the dict `buf' until an
        end tag; return `buf'.

        A key that is followed directly by the end tag is dropped.
        """
        while True:
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            key = self.loadtagged(fp, tag)
            tag = self.byte(fp)
            if tag == T_END:
                return buf
            buf[key] = self.loadtagged(fp, tag)

    def makeobjtype(self, nm):
        """Create the type used for decoded objects whose type name is `nm'."""
        return data.namedtype(str(nm), (data.obj, object), {}, typename=nm)

    def loadobj(self, fp, ref=False):
        """Read the body of a CON_OBJ container: a type name followed by an
        attribute map.

        When `ref' is true, a reference-table slot is reserved before the
        name is read, so that the object receives the index of its own tag,
        and it is filled in before the attributes are read so that they may
        refer back to the object.
        """
        if ref:
            refid = len(self.reftab)
            self.reftab.append(None)
        nm = self.load(fp)
        typ = self.namedtypes.get(nm)
        if typ is None:
            typ = self.namedtypes[nm] = self.makeobjtype(nm)
        ret = typ()
        if ref:
            self.reftab[refid] = ret
        ret.__dict__.update(self.loadmap(fp, {}))
        return ret

    def addref(self, obj):
        """Append `obj' to the reference table and return it."""
        self.reftab.append(obj)
        return obj

    def loadtagged(self, fp, tag):
        """Decode and return the value introduced by the tag byte `tag'.

        Raises fmterror on malformed input (referror for an out-of-range
        back-reference, eoferror when the stream ends early).
        """
        pri, sec = (tag & 0x7), (tag & 0xf8) >> 3
        if pri == T_END:
            raise fmterror("unexpected end-tag")
        elif pri == T_INT:
            if sec == INT_REF:
                idx = self.loadint(fp)
                if not 0 <= idx < len(self.reftab):
                    raise referror()
                return self.reftab[idx]
            return self.addref(self.loadint(fp))
        elif pri == T_STR:
            ret = self.loadstr(fp)
            if sec == STR_SYM:
                return self.addref(data.symbol.get("", ret))
            return self.addref(ret)
        elif pri == T_BIT:
            ln = self.loadint(fp)
            # A negative size would make read() consume the rest of the
            # stream instead of failing.
            if ln < 0:
                raise fmterror("illegal bit-string length: " + str(ln))
            ret = fp.read(ln)
            if len(ret) < ln:
                raise eoferror()
            return self.addref(ret)
        elif pri == T_NIL:
            if sec == NIL_TRUE:
                return self.addref(True)
            elif sec == NIL_FALSE:
                return self.addref(False)
            return self.addref(None)
        elif pri == T_SYM:
            return self.addref(self.loadsym(fp))
        elif pri == T_CON:
            # Containers are registered before their members are read so
            # that members can refer back to the container.
            if sec == CON_MAP:
                return self.loadmap(fp, self.addref({}))
            elif sec == CON_OBJ:
                return self.loadobj(fp, ref=True)
            elif sec == CON_SET:
                return self.loadset(fp, self.addref(set()))
            else:
                return self.loadlist(fp, self.addref([]))
        else:
            raise fmterror("unknown primary: " + str(pri))

    def load(self, fp):
        """Read one complete value from `fp' and return it."""
        tag = self.byte(fp)
        return self.loadtagged(fp, tag)
293
def load(fp):
    """Decode one value from the binary stream `fp' with a fresh decoder
    and return it."""
    return decoder().load(fp)