summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2021-10-31 14:16:36 -0700
committerBob Halley <halley@dnspython.org>2021-11-01 09:12:17 -0700
commit5a16bfb4c227da98b6c19a4ca88da991e6a24b47 (patch)
treee7a821be8694f492e507c27742fcb3dc2f3d6c49
parent3e5be3fc47248b9f4d4cc5c9dd81ad2ba2ee4797 (diff)
downloaddnspython-5a16bfb4c227da98b6c19a4ca88da991e6a24b47.tar.gz
rrset-reader PR
-rw-r--r--dns/zonefile.py240
-rw-r--r--tests/test_rrset_reader.py131
2 files changed, 322 insertions, 49 deletions
diff --git a/dns/zonefile.py b/dns/zonefile.py
index 92e2f0c..d3b9656 100644
--- a/dns/zonefile.py
+++ b/dns/zonefile.py
@@ -42,21 +42,35 @@ class Reader:
"""Read a DNS zone file into a transaction."""
- def __init__(self, tok, rdclass, txn, allow_include=False):
+ def __init__(self, tok, rdclass, txn, allow_include=False,
+ allow_directives=True, force_name=None,
+ force_ttl=None, force_rdclass=None, force_rdtype=None,
+ default_ttl=None):
self.tok = tok
(self.zone_origin, self.relativize, _) = \
txn.manager.origin_information()
self.current_origin = self.zone_origin
self.last_ttl = 0
self.last_ttl_known = False
- self.default_ttl = 0
- self.default_ttl_known = False
+ if force_ttl is not None:
+ default_ttl = force_ttl
+ if default_ttl is None:
+ self.default_ttl = 0
+ self.default_ttl_known = False
+ else:
+ self.default_ttl = default_ttl
+ self.default_ttl_known = True
self.last_name = self.current_origin
self.zone_rdclass = rdclass
self.txn = txn
self.saved_state = []
self.current_file = None
self.allow_include = allow_include
+ self.allow_directives = allow_directives
+ self.force_name = force_name
+ self.force_ttl = force_ttl
+ self.force_rdclass = force_rdclass
+ self.force_rdtype = force_rdtype
def _eat_line(self):
while 1:
@@ -64,63 +78,85 @@ class Reader:
if token.is_eol_or_eof():
break
+ def _get_identifier(self):
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ return token
+
def _rr_line(self):
"""Process one line from a DNS zone file."""
+ token = None
# Name
- if self.current_origin is None:
- raise UnknownOrigin
- token = self.tok.get(want_leading=True)
- if not token.is_whitespace():
- self.last_name = self.tok.as_name(token, self.current_origin)
+ if self.force_name is not None:
+ name = self.force_name
else:
- token = self.tok.get()
- if token.is_eol_or_eof():
- # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+ if self.current_origin is None:
+ raise UnknownOrigin
+ token = self.tok.get(want_leading=True)
+ if not token.is_whitespace():
+ self.last_name = self.tok.as_name(token, self.current_origin)
+ else:
+ token = self.tok.get()
+ if token.is_eol_or_eof():
+ # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+ return
+ self.tok.unget(token)
+ name = self.last_name
+ if not name.is_subdomain(self.zone_origin):
+ self._eat_line()
return
- self.tok.unget(token)
- name = self.last_name
- if not name.is_subdomain(self.zone_origin):
- self._eat_line()
- return
- if self.relativize:
- name = name.relativize(self.zone_origin)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
+ if self.relativize:
+ name = name.relativize(self.zone_origin)
# TTL
- ttl = None
- try:
- ttl = dns.ttl.from_text(token.value)
+ if self.force_ttl is not None:
+ ttl = self.force_ttl
self.last_ttl = ttl
self.last_ttl_known = True
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.ttl.BadTTL:
- if self.default_ttl_known:
- ttl = self.default_ttl
- elif self.last_ttl_known:
- ttl = self.last_ttl
+ else:
+ token = self._get_identifier()
+ ttl = None
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = None
+ except dns.ttl.BadTTL:
+ if self.default_ttl_known:
+ ttl = self.default_ttl
+ elif self.last_ttl_known:
+ ttl = self.last_ttl
+ self.tok.unget(token)
# Class
- try:
- rdclass = dns.rdataclass.from_text(token.value)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.exception.SyntaxError:
- raise
- except Exception:
- rdclass = self.zone_rdclass
- if rdclass != self.zone_rdclass:
- raise dns.exception.SyntaxError("RR class is not zone's class")
+ if self.force_rdclass is not None:
+ rdclass = self.force_rdclass
+ else:
+ token = self._get_identifier()
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ except dns.exception.SyntaxError:
+ raise
+ except Exception:
+ rdclass = self.zone_rdclass
+ self.tok.unget(token)
+ if rdclass != self.zone_rdclass:
+ raise dns.exception.SyntaxError("RR class is not zone's class")
+
# Type
- try:
- rdtype = dns.rdatatype.from_text(token.value)
- except Exception:
- raise dns.exception.SyntaxError(
- "unknown rdatatype '%s'" % token.value)
+ if self.force_rdtype is not None:
+ rdtype = self.force_rdtype
+ # we need to unget the token we got, as there is always one
+ # outstanding at this point
+ else:
+ token = self._get_identifier()
+ try:
+ rdtype = dns.rdatatype.from_text(token.value)
+ except Exception:
+ raise dns.exception.SyntaxError(
+ "unknown rdatatype '%s'" % token.value)
+
try:
rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
self.current_origin, self.relativize,
@@ -341,7 +377,7 @@ class Reader:
elif token.is_comment():
self.tok.get_eol()
continue
- elif token.value[0] == '$':
+ elif token.value[0] == '$' and self.allow_directives:
c = token.value.upper()
if c == '$TTL':
token = self.tok.get()
@@ -399,3 +435,109 @@ class Reader:
"%s:%d: %s" % (filename, line_number, detail))
tb = sys.exc_info()[2]
raise ex.with_traceback(tb) from None
+
+
+class RRsetsReaderTransaction(dns.transaction.Transaction):
+
+ def __init__(self, manager, replacement, read_only):
+ assert not read_only
+ super().__init__(manager, replacement, read_only)
+ self.rdatasets = {}
+
+ def _get_rdataset(self, name, rdtype, covers):
+ return self.rdatasets.get((name, rdtype, covers))
+
+ def _put_rdataset(self, name, rdataset):
+ self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+
+ def _delete_name(self, name):
+ # First remove any changes involving the name
+ remove = []
+ for key in self.rdatasets:
+ if key[0] == name:
+ remove.append(key)
+ if len(remove) > 0:
+ for key in remove:
+ del self.rdatasets[key]
+
+ def _delete_rdataset(self, name, rdtype, covers):
+ try:
+ del self.rdatasets[(name, rdtype, covers)]
+ except KeyError:
+ pass
+
+ def _name_exists(self, name):
+ for (n, _, _) in self.rdatasets:
+ if n == name:
+ return True
+ return False
+
+ def _changed(self):
+ return len(self.rdatasets) > 0
+
+ def _end_transaction(self, commit):
+ if commit and self._changed():
+ rrsets = []
+ for (name, _, _), rdataset in self.rdatasets.items():
+ rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype,
+ rdataset.covers)
+ rrset.update(rdataset)
+ rrsets.append(rrset)
+ self.manager.set_rrsets(rrsets)
+
+ def _set_origin(self, origin):
+ pass
+
+
+class RRSetsReaderManager(dns.transaction.TransactionManager):
+ def __init__(self, origin=dns.name.root, relativize=False,
+ rdclass=dns.rdataclass.IN):
+ self.origin = origin
+ self.relativize = relativize
+ self.rdclass = rdclass
+ self.rrsets = []
+
+ def writer(self, replacement=False):
+ assert replacement == True
+ return RRsetsReaderTransaction(self, True, False)
+
+ def get_class(self):
+ return self.rdclass
+
+ def origin_information(self):
+ if self.relativize:
+ effective = dns.name.empty
+ else:
+ effective = self.origin
+ return (self.origin, self.relativize, effective)
+
+ def set_rrsets(self, rrsets):
+ self.rrsets = rrsets
+
+
+def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN,
+ default_rdclass=dns.rdataclass.IN,
+ rdtype=None, default_ttl=None, idna_codec=None,
+ origin=dns.name.root, relativize=False):
+ if isinstance(origin, str):
+ origin = dns.name.from_text(origin, dns.name.root, idna_codec)
+ if isinstance(name, str):
+ name = dns.name.from_text(name, origin, idna_codec)
+ if isinstance(ttl, str):
+ ttl = dns.ttl.from_text(ttl)
+ if isinstance(default_ttl, str):
+ default_ttl = dns.ttl.from_text(default_ttl)
+ if rdclass is not None:
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+ if rdtype is not None:
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ manager = RRSetsReaderManager(origin, relativize, default_rdclass)
+ with manager.writer(True) as txn:
+ tok = dns.tokenizer.Tokenizer(text, '<input>', idna_codec=idna_codec)
+ reader = Reader(tok, default_rdclass, txn, allow_directives=False,
+ force_name=name, force_ttl=ttl, force_rdclass=rdclass,
+ force_rdtype=rdtype, default_ttl=default_ttl)
+ reader.read()
+ return manager.rrsets
+
diff --git a/tests/test_rrset_reader.py b/tests/test_rrset_reader.py
new file mode 100644
index 0000000..8d4255e
--- /dev/null
+++ b/tests/test_rrset_reader.py
@@ -0,0 +1,131 @@
+import pytest
+
+import dns.rrset
+from dns.zonefile import read_rrsets
+
+expected_mx_1= dns.rrset.from_text('name.', 300, 'in', 'mx', '10 a.', '20 b.')
+expected_mx_2 = dns.rrset.from_text('name.', 10, 'in', 'mx', '10 a.', '20 b.')
+expected_mx_3 = dns.rrset.from_text('foo.', 10, 'in', 'mx', '10 a.')
+expected_mx_4 = dns.rrset.from_text('bar.', 10, 'in', 'mx', '20 b.')
+expected_mx_5 = dns.rrset.from_text('foo.example.', 10, 'in', 'mx',
+ '10 a.example.')
+expected_mx_6 = dns.rrset.from_text('bar.example.', 10, 'in', 'mx', '20 b.')
+expected_mx_7 = dns.rrset.from_text('foo', 10, 'in', 'mx', '10 a')
+expected_mx_8 = dns.rrset.from_text('bar', 10, 'in', 'mx', '20 b.')
+expected_ns_1 = dns.rrset.from_text('name.', 300, 'in', 'ns', 'hi.')
+expected_ns_2 = dns.rrset.from_text('name.', 300, 'ch', 'ns', 'hi.')
+
+def equal_rrsets(a, b):
+ # return True iff. a and b have the same rrsets regardless of order
+ if len(a) != len(b):
+ return False
+ for rrset in a:
+ if not rrset in b:
+ return False
+ return True
+
+def test_name_ttl_rdclass_forced():
+ input=''';
+mx 10 a
+mx 20 b.
+ns hi'''
+ rrsets = read_rrsets(input, name='name', ttl=300)
+ assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1])
+ assert rrsets[0].ttl == 300
+ assert rrsets[1].ttl == 300
+
+def test_name_ttl_rdclass_forced_rdata_split():
+ input=''';
+mx 10 a
+ns hi
+mx 20 b.'''
+ rrsets = read_rrsets(input, name='name', ttl=300)
+ assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1])
+
+def test_name_ttl_rdclass_rdtype_forced():
+ input=''';
+10 a
+20 b.'''
+ rrsets = read_rrsets(input, name='name', ttl=300, rdtype='mx')
+ assert equal_rrsets(rrsets, [expected_mx_1])
+
+def test_name_rdclass_forced():
+ input = '''30 mx 10 a
+10 mx 20 b.
+'''
+ rrsets = read_rrsets(input, name='name')
+ assert equal_rrsets(rrsets, [expected_mx_2])
+ assert rrsets[0].ttl == 10
+
+def test_rdclass_forced():
+ input = ''';
+foo 20 mx 10 a
+bar 30 mx 20 b.
+'''
+ rrsets = read_rrsets(input)
+ assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4])
+
+def test_rdclass_forced_with_origin():
+ input = ''';
+foo 20 mx 10 a
+bar.example. 30 mx 20 b.
+'''
+ rrsets = read_rrsets(input, origin='example')
+ assert equal_rrsets(rrsets, [expected_mx_5, expected_mx_6])
+
+
+def test_rdclass_forced_with_origin_relativized():
+ input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 mx 20 b.
+'''
+ rrsets = read_rrsets(input, origin='example', relativize=True)
+ assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8])
+
+def test_rdclass_matching_default_tolerated():
+ input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 in mx 20 b.
+'''
+ rrsets = read_rrsets(input, origin='example', relativize=True,
+ rdclass=None)
+ assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8])
+
+def test_rdclass_not_matching_default_rejected():
+ input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 ch mx 20 b.
+'''
+ with pytest.raises(dns.exception.SyntaxError):
+ rrsets = read_rrsets(input, origin='example', relativize=True,
+ rdclass=None)
+
+def test_default_rdclass_is_none():
+ input = ''
+ with pytest.raises(TypeError):
+ rrsets = read_rrsets(input, default_rdclass=None, origin='example',
+ relativize=True)
+
+def test_name_rdclass_rdtype_force():
+ # No real-world usage should do this, but it can be specified so we test it.
+ input = ''';
+30 10 a
+10 20 b.
+'''
+ rrsets = read_rrsets(input, name='name', rdtype='mx')
+ assert equal_rrsets(rrsets, [expected_mx_1])
+ assert rrsets[0].ttl == 10
+
+def test_rdclass_rdtype_force():
+ # No real-world usage should do this, but it can be specified so we test it.
+ input = ''';
+foo 30 10 a
+bar 30 20 b.
+'''
+ rrsets = read_rrsets(input, rdtype='mx')
+ assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4])
+
+# also weird but legal
+#input5 = '''foo 30 10 a
+#bar 10 20 foo.
+#'''