diff options
author | Bob Halley <halley@dnspython.org> | 2021-10-31 14:16:36 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2021-11-01 09:12:17 -0700 |
commit | 5a16bfb4c227da98b6c19a4ca88da991e6a24b47 (patch) | |
tree | e7a821be8694f492e507c27742fcb3dc2f3d6c49 | |
parent | 3e5be3fc47248b9f4d4cc5c9dd81ad2ba2ee4797 (diff) | |
download | dnspython-5a16bfb4c227da98b6c19a4ca88da991e6a24b47.tar.gz |
rrset-reader PR
-rw-r--r-- | dns/zonefile.py | 240 | ||||
-rw-r--r-- | tests/test_rrset_reader.py | 131 |
2 files changed, 322 insertions, 49 deletions
diff --git a/dns/zonefile.py b/dns/zonefile.py index 92e2f0c..d3b9656 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -42,21 +42,35 @@ class Reader: """Read a DNS zone file into a transaction.""" - def __init__(self, tok, rdclass, txn, allow_include=False): + def __init__(self, tok, rdclass, txn, allow_include=False, + allow_directives=True, force_name=None, + force_ttl=None, force_rdclass=None, force_rdtype=None, + default_ttl=None): self.tok = tok (self.zone_origin, self.relativize, _) = \ txn.manager.origin_information() self.current_origin = self.zone_origin self.last_ttl = 0 self.last_ttl_known = False - self.default_ttl = 0 - self.default_ttl_known = False + if force_ttl is not None: + default_ttl = force_ttl + if default_ttl is None: + self.default_ttl = 0 + self.default_ttl_known = False + else: + self.default_ttl = default_ttl + self.default_ttl_known = True self.last_name = self.current_origin self.zone_rdclass = rdclass self.txn = txn self.saved_state = [] self.current_file = None self.allow_include = allow_include + self.allow_directives = allow_directives + self.force_name = force_name + self.force_ttl = force_ttl + self.force_rdclass = force_rdclass + self.force_rdtype = force_rdtype def _eat_line(self): while 1: @@ -64,63 +78,85 @@ class Reader: if token.is_eol_or_eof(): break + def _get_identifier(self): + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + return token + def _rr_line(self): """Process one line from a DNS zone file.""" + token = None # Name - if self.current_origin is None: - raise UnknownOrigin - token = self.tok.get(want_leading=True) - if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.current_origin) + if self.force_name is not None: + name = self.force_name else: - token = self.tok.get() - if token.is_eol_or_eof(): - # treat leading WS followed by EOL/EOF as if they were EOL/EOF. + if self.current_origin is None: + raise UnknownOrigin + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name(token, self.current_origin) + else: + token = self.tok.get() + if token.is_eol_or_eof(): + # treat leading WS followed by EOL/EOF as if they were EOL/EOF. + return + self.tok.unget(token) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() return - self.tok.unget(token) - name = self.last_name - if not name.is_subdomain(self.zone_origin): - self._eat_line() - return - if self.relativize: - name = name.relativize(self.zone_origin) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError + if self.relativize: + name = name.relativize(self.zone_origin) # TTL - ttl = None - try: - ttl = dns.ttl.from_text(token.value) + if self.force_ttl is not None: + ttl = self.force_ttl self.last_ttl = ttl self.last_ttl_known = True - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.ttl.BadTTL: - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl + else: + token = self._get_identifier() + ttl = None + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = None + except dns.ttl.BadTTL: + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + self.tok.unget(token) # Class - try: - rdclass = dns.rdataclass.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.exception.SyntaxError: - raise - except Exception: - rdclass = self.zone_rdclass - if rdclass != self.zone_rdclass: - raise dns.exception.SyntaxError("RR class is not zone's class") + if self.force_rdclass is not None: + rdclass = self.force_rdclass + else: + token = self._get_identifier() + try: + rdclass = dns.rdataclass.from_text(token.value) + except dns.exception.SyntaxError: + raise + except Exception: + rdclass = self.zone_rdclass + self.tok.unget(token) + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + # Type - try: - rdtype = dns.rdatatype.from_text(token.value) - except Exception: - raise dns.exception.SyntaxError( - "unknown rdatatype '%s'" % token.value) + if self.force_rdtype is not None: + rdtype = self.force_rdtype + # we need to unget the token we got, as there is always one + # outstanding at this point + else: + token = self._get_identifier() + try: + rdtype = dns.rdatatype.from_text(token.value) + except Exception: + raise dns.exception.SyntaxError( + "unknown rdatatype '%s'" % token.value) + try: rd = dns.rdata.from_text(rdclass, rdtype, self.tok, self.current_origin, self.relativize, @@ -341,7 +377,7 @@ class Reader: elif token.is_comment(): self.tok.get_eol() continue - elif token.value[0] == '$': + elif token.value[0] == '$' and self.allow_directives: c = token.value.upper() if c == '$TTL': token = self.tok.get() @@ -399,3 +435,109 @@ class Reader: "%s:%d: %s" % (filename, line_number, detail)) tb = sys.exc_info()[2] raise ex.with_traceback(tb) from None + + +class RRsetsReaderTransaction(dns.transaction.Transaction): + + def __init__(self, manager, replacement, read_only): + assert not read_only + super().__init__(manager, replacement, read_only) + self.rdatasets = {} + + def _get_rdataset(self, name, rdtype, covers): + return self.rdatasets.get((name, rdtype, covers)) + + def _put_rdataset(self, name, rdataset): + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + + def _delete_name(self, name): + # First remove any changes involving the name + remove = [] + for key in self.rdatasets: + if key[0] == name: + remove.append(key) + if len(remove) > 0: + for key in remove: + del self.rdatasets[key] + + def _delete_rdataset(self, name, rdtype, covers): + try: + del self.rdatasets[(name, rdtype, covers)] + except KeyError: + pass + + def _name_exists(self, name): + for (n, _, _) in self.rdatasets: + if n == name: + return True + return False + + def _changed(self): + return len(self.rdatasets) > 0 + + def _end_transaction(self, commit): + if commit and self._changed(): + rrsets = [] + for (name, _, _), rdataset in self.rdatasets.items(): + rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype, + rdataset.covers) + rrset.update(rdataset) + rrsets.append(rrset) + self.manager.set_rrsets(rrsets) + + def _set_origin(self, origin): + pass + + +class RRSetsReaderManager(dns.transaction.TransactionManager): + def __init__(self, origin=dns.name.root, relativize=False, + rdclass=dns.rdataclass.IN): + self.origin = origin + self.relativize = relativize + self.rdclass = rdclass + self.rrsets = [] + + def writer(self, replacement=False): + assert replacement == True + return RRsetsReaderTransaction(self, True, False) + + def get_class(self): + return self.rdclass + + def origin_information(self): + if self.relativize: + effective = dns.name.empty + else: + effective = self.origin + return (self.origin, self.relativize, effective) + + def set_rrsets(self, rrsets): + self.rrsets = rrsets + + +def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, + default_rdclass=dns.rdataclass.IN, + rdtype=None, default_ttl=None, idna_codec=None, + origin=dns.name.root, relativize=False): + if isinstance(origin, str): + origin = dns.name.from_text(origin, dns.name.root, idna_codec) + if isinstance(name, str): + name = dns.name.from_text(name, origin, idna_codec) + if isinstance(ttl, str): + ttl = dns.ttl.from_text(ttl) + if isinstance(default_ttl, str): + default_ttl = dns.ttl.from_text(default_ttl) + if rdclass is not None: + rdclass = dns.rdataclass.RdataClass.make(rdclass) + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + if rdtype is not None: + rdtype = dns.rdatatype.RdataType.make(rdtype) + manager = RRSetsReaderManager(origin, relativize, default_rdclass) + with manager.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, '<input>', idna_codec=idna_codec) + reader = Reader(tok, default_rdclass, txn, allow_directives=False, + force_name=name, force_ttl=ttl, force_rdclass=rdclass, + force_rdtype=rdtype, default_ttl=default_ttl) + reader.read() + return manager.rrsets + diff --git a/tests/test_rrset_reader.py b/tests/test_rrset_reader.py new file mode 100644 index 0000000..8d4255e --- /dev/null +++ b/tests/test_rrset_reader.py @@ -0,0 +1,131 @@ +import pytest + +import dns.rrset +from dns.zonefile import read_rrsets + +expected_mx_1= dns.rrset.from_text('name.', 300, 'in', 'mx', '10 a.', '20 b.') +expected_mx_2 = dns.rrset.from_text('name.', 10, 'in', 'mx', '10 a.', '20 b.') +expected_mx_3 = dns.rrset.from_text('foo.', 10, 'in', 'mx', '10 a.') +expected_mx_4 = dns.rrset.from_text('bar.', 10, 'in', 'mx', '20 b.') +expected_mx_5 = dns.rrset.from_text('foo.example.', 10, 'in', 'mx', + '10 a.example.') +expected_mx_6 = dns.rrset.from_text('bar.example.', 10, 'in', 'mx', '20 b.') +expected_mx_7 = dns.rrset.from_text('foo', 10, 'in', 'mx', '10 a') +expected_mx_8 = dns.rrset.from_text('bar', 10, 'in', 'mx', '20 b.') +expected_ns_1 = dns.rrset.from_text('name.', 300, 'in', 'ns', 'hi.') +expected_ns_2 = dns.rrset.from_text('name.', 300, 'ch', 'ns', 'hi.') + +def equal_rrsets(a, b): + # return True iff. a and b have the same rrsets regardless of order + if len(a) != len(b): + return False + for rrset in a: + if not rrset in b: + return False + return True + +def test_name_ttl_rdclass_forced(): + input='''; +mx 10 a +mx 20 b. +ns hi''' + rrsets = read_rrsets(input, name='name', ttl=300) + assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1]) + assert rrsets[0].ttl == 300 + assert rrsets[1].ttl == 300 + +def test_name_ttl_rdclass_forced_rdata_split(): + input='''; +mx 10 a +ns hi +mx 20 b.''' + rrsets = read_rrsets(input, name='name', ttl=300) + assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1]) + +def test_name_ttl_rdclass_rdtype_forced(): + input='''; +10 a +20 b.''' + rrsets = read_rrsets(input, name='name', ttl=300, rdtype='mx') + assert equal_rrsets(rrsets, [expected_mx_1]) + +def test_name_rdclass_forced(): + input = '''30 mx 10 a +10 mx 20 b. +''' + rrsets = read_rrsets(input, name='name') + assert equal_rrsets(rrsets, [expected_mx_2]) + assert rrsets[0].ttl == 10 + +def test_rdclass_forced(): + input = '''; +foo 20 mx 10 a +bar 30 mx 20 b. +''' + rrsets = read_rrsets(input) + assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4]) + +def test_rdclass_forced_with_origin(): + input = '''; +foo 20 mx 10 a +bar.example. 30 mx 20 b. +''' + rrsets = read_rrsets(input, origin='example') + assert equal_rrsets(rrsets, [expected_mx_5, expected_mx_6]) + + +def test_rdclass_forced_with_origin_relativized(): + input = '''; +foo 20 mx 10 a.example. +bar.example. 30 mx 20 b. +''' + rrsets = read_rrsets(input, origin='example', relativize=True) + assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8]) + +def test_rdclass_matching_default_tolerated(): + input = '''; +foo 20 mx 10 a.example. +bar.example. 30 in mx 20 b. +''' + rrsets = read_rrsets(input, origin='example', relativize=True, + rdclass=None) + assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8]) + +def test_rdclass_not_matching_default_rejected(): + input = '''; +foo 20 mx 10 a.example. +bar.example. 30 ch mx 20 b. +''' + with pytest.raises(dns.exception.SyntaxError): + rrsets = read_rrsets(input, origin='example', relativize=True, + rdclass=None) + +def test_default_rdclass_is_none(): + input = '' + with pytest.raises(TypeError): + rrsets = read_rrsets(input, default_rdclass=None, origin='example', + relativize=True) + +def test_name_rdclass_rdtype_force(): + # No real-world usage should do this, but it can be specified so we test it. + input = '''; +30 10 a +10 20 b. +''' + rrsets = read_rrsets(input, name='name', rdtype='mx') + assert equal_rrsets(rrsets, [expected_mx_1]) + assert rrsets[0].ttl == 10 + +def test_rdclass_rdtype_force(): + # No real-world usage should do this, but it can be specified so we test it. + input = '''; +foo 30 10 a +bar 30 20 b. +''' + rrsets = read_rrsets(input, rdtype='mx') + assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4]) + +# also weird but legal +#input5 = '''foo 30 10 a +#bar 10 20 foo. +#''' |