36e5731e0f20d5b10383fe779c4eba3131332381
[ldd.git] / ldd / rescache.py
1 #    ldd - DNS implementation in Python
2 #    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3 #
4 #    This program is free software; you can redistribute it and/or modify
5 #    it under the terms of the GNU General Public License as published by
6 #    the Free Software Foundation; either version 2 of the License, or
7 #    (at your option) any later version.
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program; if not, write to the Free Software
16 #    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17
18 import threading
19 import time
20
21 import resolver
22 import proto
23 import rec
24
class nxdmark:
    """Marker stored in the cache for a negative (NXDOMAIN) answer.

    expire -- absolute expiry time, compared against int(time.time())
    auth   -- authority-section RRs to hand back with the NXDOMAIN reply
    """
    def __init__(self, expire, auth):
        self.expire, self.auth = expire, auth
29
class cacheresolver(resolver.resolver):
    """Caching wrapper around another resolver.

    self.cache maps an owner name either to an nxdmark (a cached NXDOMAIN
    answer) or to a list of (expire, rtype, data, authns) tuples.  Here
    expire is an absolute int(time.time())-based expiry time, and authns is
    the list of NS RRs from the authority section of the reply the record
    came from.  All access to self.cache happens with self.cachelock held.
    """

    # Maximum number of cached CNAME hops followed for a single question,
    # so that a CNAME loop in the cache cannot make resolve() spin forever.
    maxcnamehops = 16

    def __init__(self, resolver):
        # The parameter shadows the `resolver` module only inside __init__.
        self.resolver = resolver
        self.cache = dict()
        self.cachelock = threading.Lock()

    @staticmethod
    def _rtypematcher(rtype):
        """Return a predicate testing a numeric RR type against rtype.

        rtype may be proto.QTANY (match everything), a numeric type, a
        type name resolved with rec.rtypebyname(), or an iterable of
        numeric types and/or type names.  A name for which rtypebyname()
        returns a false value is kept as the raw name.
        """
        if rtype == proto.QTANY:
            return lambda rt: True
        if isinstance(rtype, int):
            return lambda rt: rt == rtype
        if isinstance(rtype, str):
            rtid = rec.rtypebyname(rtype)
            return lambda rt: rt == rtid
        rtset = set()
        for item in rtype:
            if isinstance(item, str):
                rtset.add(rec.rtypebyname(item) or item)
            else:
                rtset.add(item)
        return lambda rt: rt in rtset

    def getcached(self, name, rtype = proto.QTANY):
        """Look up name in the cache.

        Returns the cached nxdmark if a negative answer for name is cached
        and has not expired.  Otherwise returns a list of (rr, authns)
        pairs, one per unexpired cached record whose type matches rtype.
        Each rr carries the remaining TTL.  The result is [] when nothing
        usable is cached.
        """
        self.cachelock.acquire()
        try:
            entry = self.cache.get(name)
            if entry is None:
                return []
            now = int(time.time())
            if isinstance(entry, nxdmark):
                if entry.expire < now:
                    del self.cache[name]
                    return []
                return entry
            # Prune expired records so that entries for names that keep
            # being queried do not grow without bound.
            live = [cl for cl in entry if cl[0] > now]
            if live:
                self.cache[name] = live
            else:
                del self.cache[name]
            cond = self._rtypematcher(rtype)
            return [(rec.rr((name, trd), exp - now, data), auth)
                    for exp, trd, data, auth in live if cond(trd)]
        finally:
            self.cachelock.release()

    @staticmethod
    def _negttl(res):
        """Return the negative-caching TTL for an NXDOMAIN reply.

        Following RFC 2308 section 5, this is the smaller of an SOA
        record's own TTL and its "minttl" field.  It defaults to 300
        seconds when the authority section carries no SOA.
        """
        ttl = 300
        for rr in res.aulist:
            if rr.head.istype("SOA"):
                ttl = min(rr.ttl, rr.data["minttl"])
        return ttl

    def dolookup(self, name, rtype):
        """Forward a query to the wrapped resolver and cache the reply.

        Returns None if the upstream resolver raises servfail/unreachable
        or returns None.  For an NXDOMAIN reply it returns an nxdmark,
        which is also stored in the cache.  Otherwise it returns the reply
        packet, after caching every RR it contains.
        """
        try:
            res = self.resolver.squery(name, rtype)
        except (resolver.servfail, resolver.unreachable):
            # A tuple is required: "except A, B:" catches only A and binds
            # the caught instance to the module attribute resolver.B.
            return None
        if res is None:
            return None
        if res.rescode == proto.NXDOMAIN:
            nc = nxdmark(int(time.time()) + self._negttl(res), res.aulist)
            self.cachelock.acquire()
            try:
                self.cache[name] = nc
            finally:
                self.cachelock.release()
            return nc
        now = int(time.time())
        allrrs = list(res.allrrs())
        # NS records from the authority section, attached to every cached RR.
        authns = [rr for rr in res.aulist if rr.head.istype("NS")]
        # Collect the RR types present in this reply per owner name, so that
        # only those RRsets replace cached data for that same name.
        newtypes = {}
        for rr in allrrs:
            newtypes.setdefault(rr.head.name, set()).add(rr.head.rtype)
        # NOTE(review): every RR in the reply is cached, with no check that
        # its owner name is related to the question.
        self.cachelock.acquire()
        try:
            for rrname, types in newtypes.items():
                old = self.cache.get(rrname)
                if old is None or isinstance(old, nxdmark):
                    # A positive answer supersedes a cached NXDOMAIN.
                    kept = []
                else:
                    kept = [cl for cl in old if cl[1] not in types]
                self.cache[rrname] = kept
            for rr in allrrs:
                self.cache[rr.head.name].append(
                    (now + rr.ttl, rr.head.rtype, rr.data, authns))
            return res
        finally:
            self.cachelock.release()

    def addcached(self, packet, cis):
        """Add cached (rr, authns) pairs, as returned by getcached(), to packet.

        Each rr goes to the answer section and its NS records go to the
        authority section.  Any cached A/AAAA records for those name servers
        go to the additional section.  An NS record object shared by several
        answers is added (with its glue) only once.
        """
        seen = set()
        for answer, authns in cis:
            packet.addan(answer)
            for ns in authns:
                if id(ns) in seen:
                    continue
                seen.add(id(ns))
                packet.addau(ns)
                glue = self.getcached(ns.data["nsname"], ["A", "AAAA"])
                if isinstance(glue, list):
                    for addr, _ in glue:
                        packet.addad(addr)

    def resolve(self, packet):
        """Build a response for packet from the cache and the wrapped resolver.

        For each question, cached CNAMEs are followed (at most
        maxcnamehops hops) and added to the answer.  Cached answers are used
        when present; otherwise the question is forwarded with dolookup().
        For a single-question packet, NXDOMAIN and lookup failure are
        reported in the response code (NXDOMAIN / SERVFAIL).  For packets
        with several questions, such questions are skipped.
        """
        res = proto.responsefor(packet)
        single = len(packet.qlist) == 1
        for q in packet.qlist:
            name = q.name
            rtype = q.rtype
            hops = 0
            while True:
                cis = self.getcached(name, rtype)
                if isinstance(cis, nxdmark) or len(cis) > 0:
                    break
                cics = self.getcached(name, "CNAME")
                if isinstance(cics, nxdmark) or len(cics) == 0:
                    break
                hops += 1
                if hops > self.maxcnamehops:
                    # Overlong (probably looping) CNAME chain.
                    cis = None
                    break
                self.addcached(res, cics)
                name = cics[0][0].data["priname"]
            if cis is None:
                if single:
                    res.rescode = proto.SERVFAIL
                    return res
                continue
            if isinstance(cis, nxdmark):
                if single:
                    res.rescode = proto.NXDOMAIN
                    # Copy so later changes to the response cannot alter
                    # the cached nxdmark.
                    res.aulist = list(cis.auth)
                    return res
                continue
            if len(cis) > 0:
                self.addcached(res, cis)
                continue
            tres = self.dolookup(name, rtype)
            if tres is None:
                if single:
                    res.rescode = proto.SERVFAIL
                    return res
            elif isinstance(tres, nxdmark):
                if single:
                    res.rescode = proto.NXDOMAIN
                    res.aulist = list(tres.auth)
                    return res
            elif tres.rescode == 0:
                res.merge(tres)
        return res