| 1 | # ldd - DNS implementation in Python |
| 2 | # Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com> |
| 3 | # |
| 4 | # This program is free software; you can redistribute it and/or modify |
| 5 | # it under the terms of the GNU General Public License as published by |
| 6 | # the Free Software Foundation; either version 2 of the License, or |
| 7 | # (at your option) any later version. |
| 8 | # |
| 9 | # This program is distributed in the hope that it will be useful, |
| 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | # GNU General Public License for more details. |
| 13 | # |
| 14 | # You should have received a copy of the GNU General Public License |
| 15 | # along with this program; if not, write to the Free Software |
| 16 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA |
| 17 | |
| 18 | import threading |
| 19 | import time |
| 20 | |
| 21 | import resolver |
| 22 | import proto |
| 23 | import rec |
| 24 | |
class nxdmark:
    """Cache entry recording that a name resolved to NXDOMAIN.

    Attributes:
        expire -- absolute time in whole seconds (the resolver builds it as
                  int(time.time()) + ttl); once the current time passes it,
                  the negative answer is discarded.
        auth   -- authority-section records of the NXDOMAIN response, handed
                  back to clients that hit this cached negative answer.
    """
    def __init__(self, expire, auth):
        self.expire, self.auth = expire, auth
| 29 | |
class cacheresolver(resolver.resolver):
    """Resolver that answers from a local cache, falling back upstream.

    ``self.cache`` maps a domain name to one of two things:

    * a list of ``(expire, rtype, data, auth)`` tuples. ``expire`` is an
      absolute int(time.time())-based timestamp. ``auth`` is the list of NS
      records from the authority section of the response that supplied the
      record.
    * an ``nxdmark`` recording a cached NXDOMAIN answer for the name.

    Every access to ``self.cache`` is done while holding ``self.cachelock``.
    """

    def __init__(self, resolver):
        # The parameter keeps its original name for keyword callers; inside
        # this method it shadows the ``resolver`` module.
        self.resolver = resolver  # upstream resolver, queried via squery()
        self.cache = dict()
        self.cachelock = threading.Lock()

    def getcached(self, name, rtype = proto.QTANY):
        """Return the still-valid cached records for *name*.

        *rtype* selects the record types to return. It may be
        ``proto.QTANY`` (the default) for all types, a single numeric type,
        a single type name (resolved with ``rec.rtypebyname``), or an
        iterable mixing numbers and names.

        Returns one of:
        * a list of ``(rec.rr, auth)`` pairs, where each rr's TTL is the
          number of seconds left before it expires;
        * an empty list when nothing usable is cached;
        * the ``nxdmark`` itself when an unexpired NXDOMAIN is cached.
        """
        self.cachelock.acquire()
        try:
            entry = self.cache.get(name)
            if entry is None:
                return []
            now = int(time.time())
            if isinstance(entry, nxdmark):
                if entry.expire < now:
                    # The negative answer has expired; forget it.
                    self.cache[name] = []
                    return []
                return entry
            # Build the set of wanted numeric types; None means "any type".
            if rtype == proto.QTANY:
                wanted = None
            elif isinstance(rtype, int):
                wanted = set([rtype])
            elif isinstance(rtype, str):
                wanted = set([rec.rtypebyname(rtype)])
            else:
                wanted = set()
                for rt in rtype:
                    if isinstance(rt, str):
                        rt = rec.rtypebyname(rt)
                    wanted.add(rt)
            ret = []
            for exp, rtid, data, auth in entry:
                if exp > now and (wanted is None or rtid in wanted):
                    ret.append((rec.rr((name, rtid), exp - now, data), auth))
            return ret
        finally:
            self.cachelock.release()

    def dolookup(self, name, rtype):
        """Query the upstream resolver for *name* and cache the result.

        Returns one of:
        * None when the upstream query fails (``resolver.servfail`` or
          ``resolver.unreachable``) or returns no response;
        * an ``nxdmark``, also stored in the cache under *name*, when the
          response is NXDOMAIN;
        * otherwise the upstream response itself, after every record it
          contains has been cached.
        """
        try:
            res = self.resolver.squery(name, rtype)
        except (resolver.servfail, resolver.unreachable):
            # The exception classes must be a parenthesized tuple. Python 2
            # reads "except A, B:" as "catch A and bind it to B", which would
            # miss B and overwrite resolver.unreachable.
            return None
        if res is None:
            return None
        if res.rescode == proto.NXDOMAIN:
            # Negative-cache lifetime: the minttl field of an SOA record in
            # the authority section if present, otherwise 300 seconds.
            ttl = 300
            for rr in res.aulist:
                if rr.head.istype("SOA"):
                    ttl = rr.data["minttl"]
            nc = nxdmark(int(time.time()) + ttl, res.aulist)
            self.cachelock.acquire()
            try:
                self.cache[name] = nc
            finally:
                self.cachelock.release()
            return nc
        now = int(time.time())
        # Materialize once: the records are walked several times below.
        allrrs = list(res.allrrs())
        # NS records of the authority section, attached to every cached record.
        nslist = [ns for ns in res.aulist if ns.head.istype("NS")]
        self.cachelock.acquire()
        try:
            alltypes = set(rr.head.rtype for rr in allrrs)
            # This response supersedes any cached records of the same types
            # for the names it mentions.
            for rname in set(rr.head.name for rr in allrrs):
                old = self.cache.get(rname)
                if isinstance(old, nxdmark):
                    # A positive answer overrides a cached NXDOMAIN.
                    self.cache[rname] = []
                elif old is not None:
                    self.cache[rname] = [cl for cl in old if cl[1] not in alltypes]
            for rr in allrrs:
                entries = self.cache.setdefault(rr.head.name, [])
                entries.append((now + rr.ttl, rr.head.rtype, rr.data, list(nslist)))
            return res
        finally:
            self.cachelock.release()

    def addcached(self, packet, cis):
        """Copy cached records into *packet*.

        *cis* is a list of ``(rr, auth)`` pairs as returned by getcached().
        Each rr is added to the answer section. Each NS record in its
        ``auth`` list is added to the authority section. Any cached A/AAAA
        records for those name servers go into the additional section.
        """
        for item, auth in cis:
            packet.addan(item)
            for ns in auth:
                packet.addau(ns)
                glue = self.getcached(ns.data["nsname"], ["A", "AAAA"])
                # getcached() returns an nxdmark instead of a list when the
                # server name is negatively cached; add nothing in that case.
                if isinstance(glue, list):
                    for addr, _ in glue:
                        packet.addad(addr)

    def resolve(self, packet):
        """Build a response to *packet*, using the cache where possible.

        For each question:
        * cached records for the name are used if present;
        * otherwise a cached CNAME is followed: its records are added to the
          answer and the lookup continues with its target;
        * if nothing usable is cached, the upstream resolver is queried
          through dolookup().

        A single-question packet gets an NXDOMAIN response for a cached or
        upstream NXDOMAIN, and SERVFAIL for an upstream failure. With
        several questions, such questions are skipped and the remaining
        ones are still answered.
        """
        res = proto.responsefor(packet)
        single = len(packet.qlist) == 1
        for q in packet.qlist:
            name = q.name
            rtype = q.rtype
            nxd = None
            seen = set()  # names already visited, to stop on CNAME loops
            while True:
                cis = self.getcached(name, rtype)
                if isinstance(cis, nxdmark):
                    nxd = cis
                    break
                if len(cis) > 0:
                    break
                cics = self.getcached(name, "CNAME")
                if not isinstance(cics, list) or len(cics) == 0:
                    break
                seen.add(name)
                target = cics[0][0].data["priname"]
                if target in seen:
                    # The cached CNAME chain loops; fall back to upstream.
                    break
                self.addcached(res, cics)
                name = target
            if nxd is not None:
                if single:
                    res.rescode = proto.NXDOMAIN
                    # Copy so later edits to the response can't alter the cache.
                    res.aulist = list(nxd.auth)
                    return res
                # Several questions: skip this one instead of retrying forever.
                continue
            if len(cis) > 0:
                self.addcached(res, cis)
                continue
            tres = self.dolookup(name, rtype)
            if isinstance(tres, nxdmark):
                if single:
                    res.rescode = proto.NXDOMAIN
                    res.aulist = list(tres.auth)
                    return res
            elif tres is None:
                if single:
                    res.rescode = proto.SERVFAIL
                    return res
            elif tres.rescode == 0:
                res.merge(tres)
        return res