summaryrefslogtreecommitdiff
path: root/dns
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2021-10-25 17:37:00 -0700
committerGitHub <noreply@github.com>2021-10-25 17:37:00 -0700
commit99166c2c2915c594df5cbbb2fffa6075309ae703 (patch)
tree6fd73c36471bfdc290ba09d7ff84c2bcb4d0a5f1 /dns
parente298b0d231db0444746886252c9a48ce8fce364d (diff)
parenta6ba21c6f07d4fbc481d21a3a660e5ecfef32476 (diff)
downloaddnspython-99166c2c2915c594df5cbbb2fffa6075309ae703.tar.gz
Merge pull request #694 from rthalley/continue_on_error
Continue on error
Diffstat (limited to 'dns')
-rw-r--r--dns/message.py221
1 files changed, 131 insertions, 90 deletions
diff --git a/dns/message.py b/dns/message.py
index 75faee2..8e6f5cc 100644
--- a/dns/message.py
+++ b/dns/message.py
@@ -108,6 +108,12 @@ class MessageSection(dns.enum.IntEnum):
return 3
+class MessageError:
+ def __init__(self, exception, offset):
+ self.exception = exception
+ self.offset = offset
+
+
DEFAULT_EDNS_PAYLOAD = 1232
MAX_CHAIN = 16
@@ -132,6 +138,7 @@ class Message:
self.origin = None
self.tsig_ctx = None
self.index = {}
+ self.errors = []
@property
def question(self):
@@ -873,11 +880,14 @@ class _WireReader:
ignore_trailing: Ignore trailing junk at end of request?
multi: Is this message part of a multi-message sequence?
DNS dynamic updates.
+ continue_on_error: try to extract as much information as possible from
+ the message, accumulating MessageErrors in the *errors* attribute instead of
+ raising them.
"""
def __init__(self, wire, initialize_message, question_only=False,
one_rr_per_rrset=False, ignore_trailing=False,
- keyring=None, multi=False):
+ keyring=None, multi=False, continue_on_error=False):
self.parser = dns.wire.Parser(wire)
self.message = None
self.initialize_message = initialize_message
@@ -886,6 +896,8 @@ class _WireReader:
self.ignore_trailing = ignore_trailing
self.keyring = keyring
self.multi = multi
+ self.continue_on_error = continue_on_error
+ self.errors = []
def _get_question(self, section_number, qcount):
"""Read the next *qcount* records from the wire data and add them to
@@ -902,11 +914,14 @@ class _WireReader:
self.message.find_rrset(section, qname, rdclass, rdtype,
create=True, force_unique=True)
+ def _add_error(self, e):
+ self.errors.append(MessageError(e, self.parser.current))
+
def _get_section(self, section_number, count):
"""Read the next I{count} records from the wire data and add them to
the specified section.
- section: the section of the message to which to add records
+ section_number: the section of the message to which to add records
count: the number of records to read
"""
@@ -929,55 +944,65 @@ class _WireReader:
(rdclass, rdtype, deleting, empty) = \
self.message._parse_rr_header(section_number,
name, rdclass, rdtype)
- if empty:
- if rdlen > 0:
- raise dns.exception.FormError
- rd = None
- covers = dns.rdatatype.NONE
- else:
- with self.parser.restrict_to(rdlen):
- rd = dns.rdata.from_wire_parser(rdclass, rdtype,
- self.parser,
- self.message.origin)
- covers = rd.covers()
- if self.message.xfr and rdtype == dns.rdatatype.SOA:
- force_unique = True
- if rdtype == dns.rdatatype.OPT:
- self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
- elif rdtype == dns.rdatatype.TSIG:
- if self.keyring is None:
- raise UnknownTSIGKey('got signed message without keyring')
- if isinstance(self.keyring, dict):
- key = self.keyring.get(absolute_name)
- if isinstance(key, bytes):
- key = dns.tsig.Key(absolute_name, key, rd.algorithm)
- elif callable(self.keyring):
- key = self.keyring(self.message, absolute_name)
+ try:
+ rdata_start = self.parser.current
+ if empty:
+ if rdlen > 0:
+ raise dns.exception.FormError
+ rd = None
+ covers = dns.rdatatype.NONE
else:
- key = self.keyring
- if key is None:
- raise UnknownTSIGKey("key '%s' unknown" % name)
- self.message.keyring = key
- self.message.tsig_ctx = \
- dns.tsig.validate(self.parser.wire,
- key,
- absolute_name,
- rd,
- int(time.time()),
- self.message.request_mac,
- rr_start,
- self.message.tsig_ctx,
- self.multi)
- self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
- else:
- rrset = self.message.find_rrset(section, name,
- rdclass, rdtype, covers,
- deleting, True,
- force_unique)
- if rd is not None:
- if ttl > 0x7fffffff:
- ttl = 0
- rrset.add(rd, ttl)
+ with self.parser.restrict_to(rdlen):
+ rd = dns.rdata.from_wire_parser(rdclass, rdtype,
+ self.parser,
+ self.message.origin)
+ covers = rd.covers()
+ if self.message.xfr and rdtype == dns.rdatatype.SOA:
+ force_unique = True
+ if rdtype == dns.rdatatype.OPT:
+ self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+ elif rdtype == dns.rdatatype.TSIG:
+ if self.keyring is None:
+ raise UnknownTSIGKey('got signed message without '
+ 'keyring')
+ if isinstance(self.keyring, dict):
+ key = self.keyring.get(absolute_name)
+ if isinstance(key, bytes):
+ key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+ elif callable(self.keyring):
+ key = self.keyring(self.message, absolute_name)
+ else:
+ key = self.keyring
+ if key is None:
+ raise UnknownTSIGKey("key '%s' unknown" % name)
+ self.message.keyring = key
+ self.message.tsig_ctx = \
+ dns.tsig.validate(self.parser.wire,
+ key,
+ absolute_name,
+ rd,
+ int(time.time()),
+ self.message.request_mac,
+ rr_start,
+ self.message.tsig_ctx,
+ self.multi)
+ self.message.tsig = dns.rrset.from_rdata(absolute_name, 0,
+ rd)
+ else:
+ rrset = self.message.find_rrset(section, name,
+ rdclass, rdtype, covers,
+ deleting, True,
+ force_unique)
+ if rd is not None:
+ if ttl > 0x7fffffff:
+ ttl = 0
+ rrset.add(rd, ttl)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ self.parser.seek(rdata_start + rdlen)
+ else:
+ raise
def read(self):
"""Read a wire format DNS message and build a dns.message.Message
@@ -993,69 +1018,82 @@ class _WireReader:
self.initialize_message(self.message)
self.one_rr_per_rrset = \
self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
- self._get_question(MessageSection.QUESTION, qcount)
- if self.question_only:
- return self.message
- self._get_section(MessageSection.ANSWER, ancount)
- self._get_section(MessageSection.AUTHORITY, aucount)
- self._get_section(MessageSection.ADDITIONAL, adcount)
- if not self.ignore_trailing and self.parser.remaining() != 0:
- raise TrailingJunk
- if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
- self.message.tsig_ctx.update(self.parser.wire)
+ try:
+ self._get_question(MessageSection.QUESTION, qcount)
+ if self.question_only:
+ return self.message
+ self._get_section(MessageSection.ANSWER, ancount)
+ self._get_section(MessageSection.AUTHORITY, aucount)
+ self._get_section(MessageSection.ADDITIONAL, adcount)
+ if not self.ignore_trailing and self.parser.remaining() != 0:
+ raise TrailingJunk
+ if self.multi and self.message.tsig_ctx and \
+ not self.message.had_tsig:
+ self.message.tsig_ctx.update(self.parser.wire)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ else:
+ raise
return self.message
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
tsig_ctx=None, multi=False,
question_only=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False):
- """Convert a DNS wire format message into a message
- object.
+ ignore_trailing=False, raise_on_truncation=False,
+ continue_on_error=False):
+ """Convert a DNS wire format message into a message object.
- *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use
- if the message is signed.
+ *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
+ message is signed.
- *request_mac*, a ``bytes``. If the message is a response to a
- TSIG-signed request, *request_mac* should be set to the MAC of
- that request.
+ *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed
+ request, *request_mac* should be set to the MAC of that request.
- *xfr*, a ``bool``, should be set to ``True`` if this message is part of
- a zone transfer.
+ *xfr*, a ``bool``, should be set to ``True`` if this message is part of a
+ zone transfer.
- *origin*, a ``dns.name.Name`` or ``None``. If the message is part
- of a zone transfer, *origin* should be the origin name of the
- zone. If not ``None``, names will be relativized to the origin.
+ *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone
+ transfer, *origin* should be the origin name of the zone. If not ``None``,
+ names will be relativized to the origin.
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
ongoing TSIG context, used when validating zone transfers.
- *multi*, a ``bool``, should be set to ``True`` if this message is
- part of a multiple message sequence.
+ *multi*, a ``bool``, should be set to ``True`` if this message is part of a
+ multiple message sequence.
+
+ *question_only*, a ``bool``. If ``True``, read only up to the end of the
+ question section.
- *question_only*, a ``bool``. If ``True``, read only up to
- the end of the question section.
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its
- own RRset.
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of
+ the message.
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the
+ TC bit is set.
- *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
- the TC bit is set.
+ *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even
+ if errors occur. Erroneous rdata will be ignored. Errors will be
+ accumulated as a list of MessageError objects in the message's ``errors``
+ attribute. This option is recommended only for DNS analysis tools, or for
+ use in a server as part of an error handling path. The default is
+ ``False``.
Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
long.
- Raises ``dns.message.TrailingJunk`` if there were octets in the message
- past the end of the proper DNS message, and *ignore_trailing* is ``False``.
+ Raises ``dns.message.TrailingJunk`` if there were octets in the message past
+ the end of the proper DNS message, and *ignore_trailing* is ``False``.
- Raises ``dns.message.BadEDNS`` if an OPT record was in the
- wrong section, or occurred more than once.
+ Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or
+ occurred more than once.
- Raises ``dns.message.BadTSIG`` if a TSIG record was not the last
- record of the additional data section.
+ Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of
+ the additional data section.
Raises ``dns.message.Truncated`` if the TC flag is set and
*raise_on_truncation* is ``True``.
@@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
message.tsig_ctx = tsig_ctx
reader = _WireReader(wire, initialize_message, question_only,
- one_rr_per_rrset, ignore_trailing, keyring, multi)
+ one_rr_per_rrset, ignore_trailing, keyring, multi,
+ continue_on_error)
try:
m = reader.read()
except dns.exception.FormError:
@@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
# have to do this check here too.
if m.flags & dns.flags.TC and raise_on_truncation:
raise Truncated(message=m)
+ if continue_on_error:
+ m.errors = reader.errors
return m