]> git.dolda2000.com Git - wrw.git/blob - wrw/sp/util.py
6b4ab8d988a4c027ef54f25188c04a0042eb7d29
[wrw.git] / wrw / sp / util.py
1 import itertools, io
2 from .. import dispatch
3 from . import cons
4
def findnsnames(el):
    """Choose a prefix for every namespace used in an element tree.

    The tree under *el* is walked depth-first (pre-order).  Each distinct
    ``ns`` seen on a ``cons.element`` is named ``"n1"``, ``"n2"``, ... in
    order of first appearance.  Afterwards exactly one namespace becomes
    the unprefixed default (mapped to ``None``):
    - the null namespace, if any element uses it;
    - otherwise, the root element's namespace.

    Returns a dict mapping namespace -> prefix string or ``None``.
    """
    mapping = {}
    counter = 1

    def visit(node):
        nonlocal counter
        if not isinstance(node, cons.element):
            return
        if node.ns not in mapping:
            mapping[node.ns] = "n%d" % counter
            counter += 1
        for child in node.children:
            visit(child)

    visit(el)
    default = None if None in mapping else el.ns
    mapping[default] = None
    return mapping
21
def flatiter(root, short=True):
    """Flatten an element tree into a stream of ``(event, node)`` pairs.

    Events produced:
      ">"  start of an element (always emitted for *root* itself)
      "/"  a childless element, when *short* is true (self-closing form)
      "<"  end of an element
      ""   a ``cons.text`` node
      "!"  a ``cons.raw`` node
      "-"  a ``cons.comment`` node

    The walk is iterative (explicit stack of ``(element, next child
    index)``), so deep trees do not hit the recursion limit.

    Raises Exception if a child is none of the node types above.
    """
    yield ">", root
    stack = [(root, 0)]
    while stack:
        el, i = stack[-1]
        if i >= len(el.children):
            yield "<", el
            stack.pop()
            continue
        ch = el.children[i]
        stack[-1] = el, i + 1
        if isinstance(ch, cons.element):
            if short and len(ch.children) == 0:
                yield "/", ch
            else:
                yield ">", ch
                stack.append((ch, 0))
        elif isinstance(ch, cons.text):
            yield "", ch
        elif isinstance(ch, cons.raw):
            yield "!", ch
        elif isinstance(ch, cons.comment):
            yield "-", ch
        else:
            # repr() of the offending child: concatenating the element object
            # itself to a str would raise TypeError and hide this message.
            raise Exception("Unknown object in element tree: " + repr(ch))
47
class formatter(object):
    """Serialize an element tree to XML bytes, one event at a time.

    A formatter consumes an iterator of ``(event, node)`` pairs (as made by
    ``flatiter``, optionally preceded by a ``"^"`` start event) and is
    itself an iterator.  Each ``next()`` handles one event and returns the
    bytes produced for it, encoded with ``charset``.

    Event codes dispatched by ``handle``:
      ">"  start tag          "/"  self-closing tag    "<"  end tag
      ""   escaped text       "!"  raw code (verbatim) "-"  comment
      "^"  document start     "$"  end of stream (generated by __next__)

    ``nsnames`` maps namespace -> prefix.  The namespace mapped to ``None``
    is written unprefixed.  Subclasses may define a ``defns`` class
    attribute (namespace -> preferred prefix), which ``nsname`` consults
    along the MRO.
    """

    def __init__(self, src, nsnames=None, charset="utf-8"):
        self.src = src                # (event, node) iterator; None once exhausted
        self.nsnames = nsnames or {}  # namespace -> prefix (None = unprefixed)
        self.nextns = 1               # counter for generated "nN" prefixes
        self.first = False            # when True, the next tag emits xmlns declarations
        self.buf = bytearray()        # bytes pending for the current __next__ call
        self.charset = charset

    def write(self, text):
        """Encode *text* with ``charset`` and append it to the buffer."""
        self.buf.extend(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write *buf* with ``&``, ``<`` and ``>`` escaped as entities."""
        buf = buf.replace('&', "&amp;")  # must come first to avoid double-escaping
        buf = buf.replace('<', "&lt;")
        buf = buf.replace('>', "&gt;")
        self.write(buf)

    def __iter__(self):
        return self

    def elname(self, el):
        """Return the tag name of *el*, prefixed per ``nsnames`` if needed."""
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + ":" + el.name

    def attrval(self, v):
        """Write attribute value *v*, quoted and escaped.

        Double quotes are used unless *v* contains one, in which case
        single quotes are used.  The chosen quote character is escaped
        inside the value.
        """
        qc, qt = ("'", "&apos;") if '"' in v else ('"', "&quot;")
        self.write(qc)
        v = v.replace('&', "&amp;")
        v = v.replace('<', "&lt;")
        v = v.replace('>', "&gt;")
        v = v.replace(qc, qt)
        self.write(v)
        self.write(qc)

    def attr(self, k, v):
        """Write one ``k="v"`` attribute (name is written verbatim)."""
        self.write(k)
        self.write("=")
        self.attrval(v)

    def attrs(self, attrs):
        """Write each ``(name, value)`` pair in *attrs*, space-separated."""
        for k, v in attrs:
            self.write(" ")
            self.attr(k, v)

    def inittag(self, el):
        """Write ``<name`` and the attributes of *el*, without the closing ``>``.

        On the first tag after ``start`` (``self.first``), a declaration
        for every entry of ``nsnames`` is appended.  The unprefixed
        namespace is declared as ``xmlns`` and the others as
        ``xmlns:prefix``.  The null namespace (``None``) is not declared.

        Raises Exception if the null namespace was given a prefix.
        """
        self.write("<" + self.elname(el))
        attrs = el.attrs.items()
        if self.first:
            nsnames = []
            for ns, name in self.nsnames.items():
                if ns is None:
                    if name is not None:
                        # str() so a non-string prefix still produces this message
                        raise Exception("null namespace must have null name, not " + str(name))
                    continue
                nsnames.append(("xmlns" if name is None else ("xmlns:" + name), ns))
            attrs = itertools.chain(attrs, nsnames)
            self.first = False
        self.attrs(attrs)

    def starttag(self, el):
        """Write the start tag ``<name ...>`` for *el*."""
        self.inittag(el)
        self.write(">")

    def shorttag(self, el):
        """Write the self-closing tag ``<name ... />`` for *el*."""
        self.inittag(el)
        self.write(" />")

    def endtag(self, el):
        """Write the end tag ``</name>`` for *el*."""
        self.write("</" + self.elname(el) + ">")

    def text(self, el):
        """Write text node *el* with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write *el* verbatim, without escaping."""
        self.write(el)

    def comment(self, el):
        """Write *el* as an XML comment.

        NOTE(review): the content is not checked for "--", which XML
        forbids inside comments.
        """
        self.write("<!-- " + str(el) + " -->")

    def start(self, el):
        """Handle the "^" event and set ``first`` for the next tag.

        Writes the XML declaration.  If *el* is a ``cons.doctype``, also
        writes a DOCTYPE line.  Setting ``first`` makes the next tag carry
        the namespace declarations.
        """
        self.write('<?xml version="1.0" encoding="' + self.charset + '" ?>\n')
        if isinstance(el, cons.doctype):
            self.write('<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (el.rootname,
                                                             el.pubid,
                                                             el.dtdid))
        self.first = True

    def end(self, el):
        """Handle the "$" end-of-stream event (no output by default)."""
        pass

    def handle(self, ev, el):
        """Dispatch event *ev* for node *el* to the matching writer method.

        Unknown event codes are ignored.
        """
        if ev == ">":
            self.starttag(el)
        elif ev == "/":
            self.shorttag(el)
        elif ev == "<":
            self.endtag(el)
        elif ev == "":
            self.text(el)
        elif ev == "!":
            self.rawcode(el)
        elif ev == "-":
            self.comment(el)
        elif ev == "^":
            self.start(el)
        elif ev == "$":
            self.end(el)

    def __next__(self):
        """Handle the next source event and return the bytes it produced.

        When the source runs out, a final "$" event is handled once (its
        output is returned, possibly ``b""``).  After that, StopIteration
        is raised.
        """
        if self.src is None:
            raise StopIteration()
        try:
            ev, el = next(self.src)
        except StopIteration:
            self.src = None
            ev, el = "$", None
        self.handle(ev, el)
        ret = bytes(self.buf)
        self.buf.clear()
        return ret

    def nsname(self, el):
        """Pick a prefix for the namespace of *el*.

        Returns:
        - the first ``defns`` entry for ``el.ns`` found along the class MRO;
        - ``None`` for the null namespace;
        - otherwise a fresh ``"nN"`` name (each call consumes a number).
        """
        for t in type(self).__mro__:
            ret = getattr(t, "defns", {}).get(el.ns, None)
            if ret is not None:
                return ret
        if el.ns is None:
            return None
        ret = "n" + str(self.nextns)
        self.nextns += 1
        return ret

    def findnsnames(self, root):
        """Fill ``self.nsnames`` with a prefix for each namespace under *root*.

        Prefixes come from ``nsname`` in depth-first order of first use.
        If no namespace ended up unprefixed (``None``), the root's
        namespace is made the unprefixed default.
        """
        fnames = {}  # namespace -> prefix
        rnames = {}  # prefix -> namespace
        def proc(el):
            if isinstance(el, cons.element):
                if el.ns not in fnames:
                    nm = self.nsname(el)
                    fnames[el.ns] = nm
                    rnames[nm] = el.ns
                for ch in el.children:
                    proc(ch)
        proc(root)
        if None not in rnames:
            fnames[root.ns] = None
            rnames[None] = root.ns
        self.nsnames = fnames

    @classmethod
    def output(cls, out, root, nsnames=None, doctype=None, **kw):
        """Write *root* as a complete XML document to the binary stream *out*.

        *doctype* may be a ``cons.doctype`` or a 2-sequence.  A 2-sequence
        is passed, after ``root.name``, to ``cons.doctype`` (presumably
        public id and DTD id; confirm against ``cons``).  If *nsnames* is
        None, prefixes are computed with ``findnsnames``.  Extra keywords
        go to the constructor (e.g. ``charset``).
        """
        if isinstance(doctype, cons.doctype):
            pass
        elif doctype is not None:
            doctype = cons.doctype(root.name, doctype[0], doctype[1])
        src = itertools.chain(iter([("^", doctype)]), flatiter(root))
        self = cls(src=src, nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        # start() also sets this on the "^" event; kept so a subclass
        # overriding start() without calling super() still gets declarations.
        self.first = True
        for piece in self:
            out.write(piece)

    @classmethod
    def fragment(cls, out, root, nsnames=None, **kw):
        """Write *root* to *out* as a bare fragment.

        No XML declaration, DOCTYPE or xmlns declarations are written,
        since ``first`` is never set.
        """
        self = cls(src=flatiter(root), nsnames=nsnames, **kw)
        if nsnames is None:
            self.findnsnames(root)
        for piece in self:
            out.write(piece)

    @classmethod
    def format(cls, root, **kw):
        """Return *root* serialized as a complete document, as ``bytes``."""
        buf = io.BytesIO()
        cls.output(buf, root, **kw)
        return buf.getvalue()
228
class indenter(formatter):
    """Formatter that pretty-prints with line breaks and indentation.

    Each element starts on its own line, indented by ``indent`` per
    nesting level.  An element with at least one direct ``cons.text``
    child is "inline": no breaks are inserted anywhere inside it, so that
    mixed content does not gain whitespace.

    Note that ``indent`` is the first positional parameter.  The remaining
    arguments go to ``formatter.__init__``, and ``src`` etc. are normally
    passed by keyword, as ``formatter.output`` does.
    """

    def __init__(self, indent="  ", *args, **kw):
        super().__init__(*args, **kw)
        self.indent = indent
        self.col = 0             # current output column (textindenter wraps on it)
        self.curind = ""         # indentation string of the current nesting level
        self.atbreak = True      # True right after a break; suppresses duplicate breaks
        self.inline = False      # True inside an element that has text children
        self.stack = []          # saved (element, curind, inline) per open element
        self.last = None, None   # previously handled (event, node)
        self.lastendbr = True    # whether the last end tag was put on its own line

    def write(self, text):
        """Encode and buffer *text*, keeping ``col`` in step with the output.

        Newlines inside *text* are written as-is (no indentation follows
        them) and reset the column count.
        """
        lines = text.split("\n")
        if len(lines) > 1:
            for ln in lines[:-1]:
                self.buf.extend(ln.encode(self.charset))
                self.buf.extend(b"\n")
            self.col = 0
        self.buf.extend(lines[-1].encode(self.charset))
        self.col += len(lines[-1])
        self.atbreak = False

    def br(self):
        """Start a new line at the current indentation, unless just after one."""
        if not self.atbreak:
            self.buf.extend(("\n" + self.curind).encode(self.charset))
            # The indentation just written occupies columns.  Resetting to 0
            # made textindenter's maxcol check undercount on indented lines.
            self.col = len(self.curind)
            self.atbreak = True

    def inlinep(self, el):
        """Return True if *el* has at least one direct ``cons.text`` child."""
        return any(isinstance(ch, cons.text) for ch in el.children)

    def push(self, el):
        """Save *el* with the current indentation and inline state."""
        self.stack.append((el, self.curind, self.inline))

    def pop(self):
        """Restore the indentation and inline state saved by ``push``."""
        el, self.curind, self.inline = self.stack.pop()
        return el

    def starttag(self, el):
        """Break before *el* (unless inline), then open it one level deeper."""
        if not self.inline:
            # A start tag right after a same-named end tag that was itself put
            # on its own line stays on that line (e.g. "</li><li>").
            samesib = (self.last[0] == "<" and self.last[1].name == el.name and self.lastendbr)
            if not samesib:
                self.br()
        self.push(el)
        self.inline = self.inline or self.inlinep(el)
        self.curind += self.indent
        super().starttag(el)

    def shorttag(self, el):
        """Break before a self-closing tag unless inside inline content."""
        if not self.inline:
            self.br()
        super().shorttag(el)

    def endtag(self, el):
        """Close *el*, restoring the enclosing indentation and inline state.

        The end tag goes on its own line, except in two cases where it
        stays on the current line:
        - *el*'s content was inline;
        - it directly follows *el*'s own start tag (an empty element in
          long form).
        """
        il = self.inline
        self.pop()
        if il or (self.last[0] == ">" and self.last[1] == el):
            self.lastendbr = False
        else:
            self.br()
            self.lastendbr = True
        super().endtag(el)

    def start(self, el):
        """Write the document prologue; the declaration's newline counts as a break."""
        super().start(el)
        self.atbreak = True

    def end(self, el):
        """Finish the output with a trailing line break."""
        self.br()

    def handle(self, ev, el):
        """Dispatch the event, then remember it for the next layout decision."""
        super().handle(ev, el)
        self.last = ev, el
307
class textindenter(indenter):
    """Indenter that also word-wraps text content at column ``maxcol``."""

    maxcol = 70

    def text(self, el):
        """Write text node *el*, breaking lines at whitespace near ``maxcol``.

        Each break drops a run of whitespace and starts a new line via
        ``br``.  A word longer than the remaining room is never split.  It
        is written whole and broken after, at the first following
        whitespace if there is one.
        """
        left = str(el)
        # Check `left` first: with an empty string and col already past
        # maxcol, the break-point search would index left[0] (IndexError).
        while left and len(left) + self.col > self.maxcol:
            i = self._breakpoint(left)
            if i is None:
                break
            self.quotewrite(left[:i])
            self.br()
            left = left[i + 1:].lstrip()
        self.quotewrite(left)

    def _breakpoint(self, left):
        """Return the index in non-empty *left* to break at, or None if no whitespace fits.

        Search backwards from the last column that still fits, then
        forwards past it.
        """
        bp = max(self.maxcol - self.col, 0)
        for i in range(bp, -1, -1):
            if left[i].isspace():
                # Back up to the start of the whitespace run so the line
                # written before the break carries no trailing whitespace.
                while i > 0 and left[i - 1].isspace():
                    i -= 1
                return i
        for i in range(bp + 1, len(left)):
            if left[i].isspace():
                return i
        return None
336
class response(dispatch.restart):
    """Dispatch restart that replies with a serialized element tree.

    Subclasses must define ``ctype`` (the Content-Type header value).
    They may override ``charset``, ``doctype`` and ``formatter``, the
    class whose ``format`` classmethod renders the body bytes.
    """

    charset = "utf-8"
    doctype = None
    formatter = indenter

    def __init__(self, root):
        super().__init__()
        self.root = root  # element tree to render as the response body

    @property
    def ctype(self):
        """Content-Type of the response; subclasses must override this."""
        raise Exception("a subclass of wrw.sp.util.response must override ctype")

    def handle(self, req):
        """Render ``self.root``, set the response headers and return the body chunks."""
        body = self.formatter.format(self.root,
                                     doctype=self.doctype,
                                     charset=self.charset)
        content_type = self.ctype
        req.ohead["Content-Type"] = content_type
        body_length = len(body)
        req.ohead["Content-Length"] = body_length
        return [body]