diff options
| author | Bob Halley <halley@dnspython.org> | 2021-10-25 17:37:00 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-10-25 17:37:00 -0700 |
| commit | 99166c2c2915c594df5cbbb2fffa6075309ae703 (patch) | |
| tree | 6fd73c36471bfdc290ba09d7ff84c2bcb4d0a5f1 /dns | |
| parent | e298b0d231db0444746886252c9a48ce8fce364d (diff) | |
| parent | a6ba21c6f07d4fbc481d21a3a660e5ecfef32476 (diff) | |
| download | dnspython-99166c2c2915c594df5cbbb2fffa6075309ae703.tar.gz | |
Merge pull request #694 from rthalley/continue_on_error
Continue on error
Diffstat (limited to 'dns')
| -rw-r--r-- | dns/message.py | 221 |
1 files changed, 131 insertions, 90 deletions
diff --git a/dns/message.py b/dns/message.py index 75faee2..8e6f5cc 100644 --- a/dns/message.py +++ b/dns/message.py @@ -108,6 +108,12 @@ class MessageSection(dns.enum.IntEnum): return 3 +class MessageError: + def __init__(self, exception, offset): + self.exception = exception + self.offset = offset + + DEFAULT_EDNS_PAYLOAD = 1232 MAX_CHAIN = 16 @@ -132,6 +138,7 @@ class Message: self.origin = None self.tsig_ctx = None self.index = {} + self.errors = [] @property def question(self): @@ -873,11 +880,14 @@ class _WireReader: ignore_trailing: Ignore trailing junk at end of request? multi: Is this message part of a multi-message sequence? DNS dynamic updates. + continue_on_error: try to extract as much information as possible from + the message, accumulating MessageErrors in the *errors* attribute instead of + raising them. """ def __init__(self, wire, initialize_message, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False): + keyring=None, multi=False, continue_on_error=False): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -886,6 +896,8 @@ class _WireReader: self.ignore_trailing = ignore_trailing self.keyring = keyring self.multi = multi + self.continue_on_error = continue_on_error + self.errors = [] def _get_question(self, section_number, qcount): """Read the next *qcount* records from the wire data and add them to @@ -902,11 +914,14 @@ class _WireReader: self.message.find_rrset(section, qname, rdclass, rdtype, create=True, force_unique=True) + def _add_error(self, e): + self.errors.append(MessageError(e, self.parser.current)) + def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. - section: the section of the message to which to add records + section_number: the section of the message to which to add records count: the number of records to read """ @@ -929,55 +944,65 @@ class _WireReader: (rdclass, rdtype, deleting, empty) = \ self.message._parse_rr_header(section_number, name, rdclass, rdtype) - if empty: - if rdlen > 0: - raise dns.exception.FormError - rd = None - covers = dns.rdatatype.NONE - else: - with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True - if rdtype == dns.rdatatype.OPT: - self.message.opt = dns.rrset.from_rdata(name, ttl, rd) - elif rdtype == dns.rdatatype.TSIG: - if self.keyring is None: - raise UnknownTSIGKey('got signed message without keyring') - if isinstance(self.keyring, dict): - key = self.keyring.get(absolute_name) - if isinstance(key, bytes): - key = dns.tsig.Key(absolute_name, key, rd.algorithm) - elif callable(self.keyring): - key = self.keyring(self.message, absolute_name) + try: + rdata_start = self.parser.current + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE else: - key = self.keyring - if key is None: - raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) - else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) - if rd is not None: - if ttl > 0x7fffffff: - ttl = 0 - rrset.add(rd, ttl) + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: + if self.keyring is None: + raise UnknownTSIGKey('got signed message without ' + 'keyring') + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) + else: + key = self.keyring + if key is None: + raise UnknownTSIGKey("key '%s' unknown" % name) + self.message.keyring = key + self.message.tsig_ctx = \ + dns.tsig.validate(self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, + rd) + else: + rrset = self.message.find_rrset(section, name, + rdclass, rdtype, covers, + deleting, True, + force_unique) + if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 + rrset.add(rd, ttl) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + self.parser.seek(rdata_start + rdlen) + else: + raise def read(self): """Read a wire format DNS message and build a dns.message.Message @@ -993,69 +1018,82 @@ class _WireReader: self.initialize_message(self.message) self.one_rr_per_rrset = \ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) - self._get_question(MessageSection.QUESTION, qcount) - if self.question_only: - return self.message - self._get_section(MessageSection.ANSWER, ancount) - self._get_section(MessageSection.AUTHORITY, aucount) - self._get_section(MessageSection.ADDITIONAL, adcount) - if not self.ignore_trailing and self.parser.remaining() != 0: - raise TrailingJunk - if self.multi and self.message.tsig_ctx and not self.message.had_tsig: - self.message.tsig_ctx.update(self.parser.wire) + try: + self._get_question(MessageSection.QUESTION, qcount) + if self.question_only: + return self.message + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: + raise TrailingJunk + if self.multi and self.message.tsig_ctx and \ + not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + else: + raise return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False): - """Convert a DNS wire format message into a message - object. + ignore_trailing=False, raise_on_truncation=False, + continue_on_error=False): + """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use - if the message is signed. + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the + message is signed. - *request_mac*, a ``bytes``. If the message is a response to a - TSIG-signed request, *request_mac* should be set to the MAC of - that request. + *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed + request, *request_mac* should be set to the MAC of that request. - *xfr*, a ``bool``, should be set to ``True`` if this message is part of - a zone transfer. + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a + zone transfer. - *origin*, a ``dns.name.Name`` or ``None``. If the message is part - of a zone transfer, *origin* should be the origin name of the - zone. If not ``None``, names will be relativized to the origin. + *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone + transfer, *origin* should be the origin name of the zone. If not ``None``, + names will be relativized to the origin. *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG context, used when validating zone transfers. - *multi*, a ``bool``, should be set to ``True`` if this message is - part of a multiple message sequence. + *multi*, a ``bool``, should be set to ``True`` if this message is part of a + multiple message sequence. + + *question_only*, a ``bool``. If ``True``, read only up to the end of the + question section. - *question_only*, a ``bool``. If ``True``, read only up to - the end of the question section. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its - own RRset. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of + the message. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the + TC bit is set. - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even + if errors occur. Erroneous rdata will be ignored. Errors will be + accumulated as a list of MessageError objects in the message's ``errors`` + attribute. This option is recommended only for DNS analysis tools, or for + use in a server as part of an error handling path. The default is + ``False``. Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. - Raises ``dns.message.TrailingJunk`` if there were octets in the message - past the end of the proper DNS message, and *ignore_trailing* is ``False``. + Raises ``dns.message.TrailingJunk`` if there were octets in the message past + the end of the proper DNS message, and *ignore_trailing* is ``False``. - Raises ``dns.message.BadEDNS`` if an OPT record was in the - wrong section, or occurred more than once. + Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or + occurred more than once. - Raises ``dns.message.BadTSIG`` if a TSIG record was not the last - record of the additional data section. + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of + the additional data section. Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is ``True``. @@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, message.tsig_ctx = tsig_ctx reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi) + one_rr_per_rrset, ignore_trailing, keyring, multi, + continue_on_error) try: m = reader.read() except dns.exception.FormError: @@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, # have to do this check here too. if m.flags & dns.flags.TC and raise_on_truncation: raise Truncated(message=m) + if continue_on_error: + m.errors = reader.errors return m |
