Call formatter.node from formatter.fragment, as is proper.
[wrw.git] / wrw / sp / util.py
1 import cons
2
def findnsnames(el):
    """Map every namespace used in the tree under `el` to a prefix.

    Namespaces are numbered u"n1", u"n2", ... in the order in which a
    depth-first walk of the tree first meets them.  Afterwards one
    namespace is mapped to None instead of a prefix: the None namespace
    if any element uses it, otherwise the namespace of `el` itself.

    Returns the resulting dictionary.
    """
    prefixes = {}

    def visit(node):
        # Only elements are inspected and descended into; any other kind
        # of node is skipped.
        if not isinstance(node, cons.element):
            return
        if node.ns not in prefixes:
            # Exactly one entry is added per prefix handed out, so the
            # table size plus one is always the next unused number.
            prefixes[node.ns] = u"n" + unicode(len(prefixes) + 1)
        for child in node.children:
            visit(child)

    visit(el)
    # el.ns is only looked up when no element used the None namespace.
    unprefixed = None if None in prefixes else el.ns
    prefixes[unprefixed] = None
    return prefixes
19
class formatter(object):
    """Write an element tree to an output stream as XML markup.

    Constructor arguments:

    out     -- object whose write() method receives the encoded output.
    root    -- root node of the tree; its name is used in the DOCTYPE.
    nsnames -- dictionary mapping namespace identifiers to prefixes, a
               prefix of None meaning "no prefix".  Computed with
               findnsnames(root) when omitted.
    charset -- encoding applied to all text before it is written, and
               announced in the XML declaration.
    doctype -- optional pair whose two items are written as the two
               quoted strings of a PUBLIC DOCTYPE declaration.
    """

    # Replacements applied to character data by quotewrite().
    _textescapes = {
        u'&': u"&amp;",
        u'<': u"&lt;",
        u'>': u"&gt;",
    }

    def __init__(self, out, root, nsnames=None, charset="utf-8", doctype=None):
        self.root = root
        if nsnames is None:
            nsnames = findnsnames(root)
        self.nsnames = nsnames
        self.out = out
        self.charset = charset
        self.doctype = doctype

    def write(self, text):
        """Encode `text` with the configured charset and write it out."""
        self.out.write(text.encode(self.charset))

    def quotewrite(self, buf):
        """Write `buf` as character data, escaping '&', '<' and '>'."""
        for ch in buf:
            # '&' must become '&amp;': a bare ampersand would start an
            # entity reference and corrupt the document.
            self.write(self._textescapes.get(ch, ch))

    def text(self, el):
        """Write a text node with markup characters escaped."""
        self.quotewrite(el)

    def rawcode(self, el):
        """Write a raw node verbatim, without any escaping."""
        self.write(el)

    def attrval(self, buf):
        """Write `buf` as a quoted and escaped attribute value.

        Double quotes delimit the value unless it contains a double
        quote itself, in which case single quotes are used.  The chosen
        delimiter is escaped wherever it occurs inside the value.
        """
        qc, qt = (u"'", u"&apos;") if u'"' in buf else (u'"', u"&quot;")
        self.write(qc)
        for ch in buf:
            if ch == u'&':
                self.write(u"&amp;")
            elif ch == u'<':
                self.write(u"&lt;")
            elif ch == u'>':
                self.write(u"&gt;")
            elif ch == qc:
                self.write(qt)
            else:
                self.write(ch)
        self.write(qc)

    def attr(self, k, v):
        """Write a single attribute as name=value."""
        self.write(k)
        self.write(u'=')
        self.attrval(v)

    def _writeattrs(self, el, extra):
        """Write el's own attributes, then `extra`, each after a space."""
        # items() yields the same pairs as iteritems() on Python 2 and
        # also exists on Python 3.
        for k, v in el.attrs.items():
            self.write(u' ')
            self.attr(k, v)
        for k, v in extra.items():
            self.write(u' ')
            self.attr(k, v)

    def shorttag(self, el, **extra):
        """Write `el` as a self-closing tag; `extra` adds attributes."""
        self.write(u'<' + self.elname(el))
        self._writeattrs(el, extra)
        self.write(u" />")

    def elname(self, el):
        """Return el's name, qualified by its namespace prefix if any.

        Raises KeyError if el's namespace is missing from nsnames.
        """
        ns = self.nsnames[el.ns]
        if ns is None:
            return el.name
        else:
            return ns + u':' + el.name

    def starttag(self, el, **extra):
        """Write the opening tag of `el`; `extra` adds attributes."""
        self.write(u'<' + self.elname(el))
        self._writeattrs(el, extra)
        self.write(u'>')

    def endtag(self, el):
        """Write the closing tag of `el`."""
        self.write(u'</' + self.elname(el) + u'>')

    def longtag(self, el, **extra):
        """Write `el` as opening tag, children and closing tag.

        `extra` supplies additional attributes for the opening tag.  It
        has to be accepted here because element() forwards it.
        """
        self.starttag(el, **extra)
        for ch in el.children:
            self.node(ch)
        self.endtag(el)

    def element(self, el, **extra):
        """Write `el`, self-closing when it has no children."""
        if len(el.children) == 0:
            self.shorttag(el, **extra)
        else:
            self.longtag(el, **extra)

    def node(self, el):
        """Write one tree node according to its type.

        Raises Exception for objects that are neither elements, text
        nodes nor raw nodes.
        """
        if isinstance(el, cons.element):
            self.element(el)
        elif isinstance(el, cons.text):
            self.text(el)
        elif isinstance(el, cons.raw):
            self.rawcode(el)
        else:
            # %r rather than string concatenation, which would itself
            # fail with TypeError for objects that are not strings.
            raise Exception("Unknown object in element tree: %r" % (el,))

    def start(self):
        """Write a complete document for the root node.

        Emits the XML declaration, the DOCTYPE if one was configured,
        and then the root element carrying an xmlns declaration for
        every namespace in nsnames other than None.
        """
        self.write(u'<?xml version="1.0" encoding="' + self.charset + u'" ?>\n')
        if self.doctype:
            self.write(u'<!DOCTYPE %s PUBLIC "%s" "%s">\n' % (self.root.name,
                                                              self.doctype[0],
                                                              self.doctype[1]))
        extra = {}
        for uri, nm in self.nsnames.items():
            if uri is None:
                continue
            if nm is None:
                extra[u"xmlns"] = uri
            else:
                extra[u"xmlns:" + nm] = uri
        self.element(self.root, **extra)

    @classmethod
    def output(cls, out, el, *args, **kw):
        """Format `el` onto `out` as a complete document.

        Further arguments are handed to the constructor of `cls`; `out`
        and `el` are always passed to it by keyword.
        """
        cls(out=out, root=el, *args, **kw).start()

    @classmethod
    def fragment(cls, out, el, *args, **kw):
        """Format only the node `el` onto `out`, without a prologue.

        Further arguments are handed to the constructor of `cls`; `out`
        and `el` are always passed to it by keyword.
        """
        cls(out=out, root=el, *args, **kw).node(el)

    def update(self, **ch):
        """Return a shallow copy of this formatter with `ch` overriding
        the named attributes.  __init__ is not run for the copy."""
        ret = type(self).__new__(type(self))
        ret.__dict__.update(self.__dict__)
        ret.__dict__.update(ch)
        return ret
152
class iwriter(object):
    """Wrapper around an output object that tracks the current column.

    `col` counts the items written since the last newline, and `atbol`
    is true while nothing has been written since the start or since the
    last call to indent().
    """

    def __init__(self, out):
        self.out = out
        self.atbol = True
        self.col = 0

    def write(self, buf):
        """Pass `buf` to the wrapped output one item at a time, keeping
        the column count up to date."""
        for piece in buf:
            # A newline resets the column; anything else advances it.
            self.col = 0 if piece == '\n' else self.col + 1
            self.out.write(piece)
        self.atbol = False

    def indent(self, indent):
        """Begin a fresh line prefixed with `indent`.

        Does nothing when the writer already sits at the beginning of a
        line.  A newline is only emitted if the current line is not
        empty.
        """
        if not self.atbol:
            if self.col != 0:
                self.write('\n')
            self.write(indent)
            # write() cleared the flag; restore it so that repeated
            # calls do not indent twice.
            self.atbol = True
175
class indenter(formatter):
    """Formatter variant that breaks lines and indents nested elements.

    `indent` is the string added per nesting level.  All remaining
    arguments are passed on to the formatter constructor.
    """

    def __init__(self, indent=u"  ", *args, **kw):
        super(indenter, self).__init__(*args, **kw)
        # Route all output through an iwriter so that the current column
        # and beginning-of-line state are known.
        self.out = iwriter(self.out)
        self.indent = indent
        self.curind = u""

    def simple(self, el):
        """Return True if every child of `el` is a text node."""
        return all(isinstance(child, cons.text) for child in el.children)

    def longtag(self, el, **extra):
        """Write `el` with its children.

        Children of an element containing only text are written inline.
        Otherwise they are written by a copy of this formatter whose
        indentation is one level deeper, and the closing tag is placed
        on a new line at the current level.
        """
        self.starttag(el, **extra)
        if self.simple(el):
            for child in el.children:
                self.node(child)
        else:
            inner = self.update(curind=self.curind + self.indent)
            inner.reindent()
            for child in el.children:
                inner.node(child)
            self.reindent()
        self.endtag(el)

    def element(self, el, **extra):
        """Write `el`, then break the line if it has grown beyond 80
        columns and `el` held nothing but text."""
        super(indenter, self).element(el, **extra)
        overlong = self.out.col > 80
        if overlong and self.simple(el):
            self.reindent()

    def reindent(self):
        """Ask the writer for a new line at the current indentation."""
        prefix = self.curind.encode(self.charset)
        self.out.indent(prefix)

    def start(self):
        """Write the whole document followed by a final newline."""
        super(indenter, self).start()
        self.write('\n')