summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-08-06 20:02:24 -0700
committerBob Halley <halley@dnspython.org>2020-08-07 17:03:26 -0700
commita18c51228fbc399f6582c969993c32d3b8cdc6b0 (patch)
treef087acaf1bd572bfcdc0cb82567284ee680e4083
parent3aa0379a50c75647320edc9db190b4f27fb3c269 (diff)
downloaddnspython-svcb.tar.gz
SVCB and HTTPS checkpointsvcb
-rw-r--r--dns/rdataset.py4
-rw-r--r--dns/rdatatype.py4
-rw-r--r--dns/rdtypes/IN/HTTPS.py6
-rw-r--r--dns/rdtypes/IN/SVCB.py6
-rw-r--r--dns/rdtypes/__init__.py1
-rw-r--r--dns/rdtypes/svcbbase.py518
-rw-r--r--doc/whatsnew.rst2
-rw-r--r--tests/test_svcb.py239
8 files changed, 778 insertions, 2 deletions
diff --git a/dns/rdataset.py b/dns/rdataset.py
index 0e47139..b91d288 100644
--- a/dns/rdataset.py
+++ b/dns/rdataset.py
@@ -268,7 +268,7 @@ class Rdataset(dns.set.Set):
want_shuffle = False
else:
rdclass = self.rdclass
- file.seek(0, 2)
+ file.seek(0, io.SEEK_END)
if len(self) == 0:
name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
@@ -292,7 +292,7 @@ class Rdataset(dns.set.Set):
file.seek(start - 2)
stuff = struct.pack("!H", end - start)
file.write(stuff)
- file.seek(0, 2)
+ file.seek(0, io.SEEK_END)
return len(self)
def match(self, rdclass, rdtype, covers):
diff --git a/dns/rdatatype.py b/dns/rdatatype.py
index 740752e..a6b5d64 100644
--- a/dns/rdatatype.py
+++ b/dns/rdatatype.py
@@ -78,6 +78,8 @@ class RdataType(dns.enum.IntEnum):
CDNSKEY = 60
OPENPGPKEY = 61
CSYNC = 62
+ SVCB = 64
+ HTTPS = 65
SPF = 99
UNSPEC = 103
EUI48 = 108
@@ -276,6 +278,8 @@ CDS = RdataType.CDS
CDNSKEY = RdataType.CDNSKEY
OPENPGPKEY = RdataType.OPENPGPKEY
CSYNC = RdataType.CSYNC
+SVCB = RdataType.SVCB
+HTTPS = RdataType.HTTPS
SPF = RdataType.SPF
UNSPEC = RdataType.UNSPEC
EUI48 = RdataType.EUI48
diff --git a/dns/rdtypes/IN/HTTPS.py b/dns/rdtypes/IN/HTTPS.py
new file mode 100644
index 0000000..ad67897
--- /dev/null
+++ b/dns/rdtypes/IN/HTTPS.py
@@ -0,0 +1,6 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import dns.rdtypes.svcbbase
+
+class HTTPS(dns.rdtypes.svcbbase.SVCBBase):
+ """HTTPS record"""
diff --git a/dns/rdtypes/IN/SVCB.py b/dns/rdtypes/IN/SVCB.py
new file mode 100644
index 0000000..8effeb8
--- /dev/null
+++ b/dns/rdtypes/IN/SVCB.py
@@ -0,0 +1,6 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import dns.rdtypes.svcbbase
+
+class SVCB(dns.rdtypes.svcbbase.SVCBBase):
+ """SVCB record"""
diff --git a/dns/rdtypes/__init__.py b/dns/rdtypes/__init__.py
index ccc848c..0783aa5 100644
--- a/dns/rdtypes/__init__.py
+++ b/dns/rdtypes/__init__.py
@@ -24,5 +24,6 @@ __all__ = [
'euibase',
'mxbase',
'nsbase',
+ 'svcbbase',
'util'
]
diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py
new file mode 100644
index 0000000..ac0c5cb
--- /dev/null
+++ b/dns/rdtypes/svcbbase.py
@@ -0,0 +1,518 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import base64
+import enum
+import io
+import struct
+
+import dns.enum
+import dns.exception
+import dns.ipv4
+import dns.ipv6
+import dns.name
+import dns.rdata
+import dns.tokenizer
+import dns.wire
+
+# Until there is an RFC, this module is experimental and may be changed in
+# incompatible ways.
+
+
+class UnknownParamKey(dns.exception.DNSException):
+ """Unknown SVCB ParamKey"""
+
+
+class ParamKey(dns.enum.IntEnum):
+ """SVCB ParamKey"""
+
+ MANDATORY = 0
+ ALPN = 1
+ NO_DEFAULT_ALPN = 2
+ PORT = 3
+ IPV4HINT = 4
+ ECHCONFIG = 5
+ IPV6HINT = 6
+
+ @classmethod
+ def _maximum(cls):
+ return 65535
+
+ @classmethod
+ def _short_name(cls):
+ return "SVCBParamKey"
+
+ @classmethod
+ def _prefix(cls):
+ return "KEY"
+
+ @classmethod
+ def _unknown_exception_class(cls):
+ return UnknownParamKey
+
+
+class Emptiness(enum.IntEnum):
+ NEVER = 0
+ ALWAYS = 1
+ ALLOWED = 2
+
+
+def _validate_key(key):
+ force_generic = False
+ if isinstance(key, bytes):
+ # We decode to latin-1 so we get 0-255 as valid and do NOT interpret
+ # UTF-8 sequences
+ key = key.decode('latin-1')
+ if isinstance(key, str):
+ if key.lower().startswith('key'):
+ force_generic = True
+ if key[3:].startswith('0') and len(key) != 4:
+ # key has leading zeros
+ raise ValueError('leading zeros in key')
+ key = key.replace('-', '_')
+ return (ParamKey.make(key), force_generic)
+
+def key_to_text(key):
+ return ParamKey.to_text(key).replace('_', '-').lower()
+
+# Like rdata escapify, but escapes ',' too.
+
+_escaped = b'",\\'
+
+def _escapify(qstring):
+ text = ''
+ for c in qstring:
+ if c in _escaped:
+ text += '\\' + chr(c)
+ elif c >= 0x20 and c < 0x7F:
+ text += chr(c)
+ else:
+ text += '\\%03d' % c
+ return text
+
+def _unescape(value, list_mode=False):
+ if value == '':
+ return value
+ items = []
+ unescaped = b''
+ l = len(value)
+ i = 0
+ while i < l:
+ c = value[i]
+ i += 1
+ if c == ',' and list_mode:
+ if len(unescaped) == 0:
+ raise ValueError('list item cannot be empty')
+ items.append(unescaped)
+ unescaped = b''
+ continue
+ if c == '\\':
+ if i >= l: # pragma: no cover (can't happen via tokenizer get())
+ raise dns.exception.UnexpectedEnd
+ c = value[i]
+ i += 1
+ if c.isdigit():
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c2 = value[i]
+ i += 1
+ if i >= l:
+ raise dns.exception.UnexpectedEnd
+ c3 = value[i]
+ i += 1
+ if not (c2.isdigit() and c3.isdigit()):
+ raise dns.exception.SyntaxError
+ c = chr(int(c) * 100 + int(c2) * 10 + int(c3))
+ unescaped += c.encode()
+ if len(unescaped) > 0:
+ items.append(unescaped)
+ elif list_mode:
+ raise ValueError('trailing comma')
+ if list_mode:
+ return items
+ else:
+ return items[0]
+
+
+class Param:
+ """Abstract base class for SVCB parameters"""
+
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.NEVER
+
+class GenericParam(Param):
+ """Generic SVCB parameter
+ """
+ def __init__(self, value):
+ self.value = value
+
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.ALLOWED
+
+ @classmethod
+ def from_value(cls, value):
+ if value is None or len(value) == 0:
+ return None
+ else:
+ return cls(_unescape(value))
+
+ def to_text(self):
+ return '"' + _escapify(self.value) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ value = parser.get_bytes(parser.remaining())
+ if len(value) == 0:
+ return None
+ else:
+ return cls(value)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(self.value)
+
+
+class MandatoryParam(Param):
+ def __init__(self, keys):
+ # check for duplicates
+ self.keys = sorted([_validate_key(key)[0] for key in keys])
+ prior_k = None
+ for k in self.keys:
+ if k == prior_k:
+ raise ValueError(f'duplicate key {k}')
+ prior_k = k
+ if k == ParamKey.MANDATORY:
+ raise ValueError('listed the mandatory key as mandatory')
+
+ @classmethod
+ def from_value(cls, value):
+ keys = [k.encode() for k in value.split(',')]
+ return cls(keys)
+
+ def to_text(self):
+ return '"' + ','.join([key_to_text(key) for key in self.keys]) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ keys = []
+ last_key = -1
+ while parser.remaining() > 0:
+ key = parser.get_uint16()
+ if key < last_key:
+ raise dns.exception.FormError('manadatory keys not ascending')
+ last_key = key
+ keys.append(key)
+ return cls(keys)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for key in self.keys:
+ file.write(struct.pack('!H', key))
+
+class ALPNParam(Param):
+ def __init__(self, ids):
+ for id in ids:
+ if len(id) == 0:
+ raise dns.exception.FormError('empty ALPN')
+ if len(id) > 255:
+ raise ValueError('ALPN id too long')
+ self.ids = ids
+
+ @classmethod
+ def from_value(cls, value):
+ return cls(_unescape(value, True))
+
+ def to_text(self):
+ return '"' + ','.join([_escapify(id) for id in self.ids]) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ ids = []
+ while parser.remaining() > 0:
+ id = parser.get_counted_bytes()
+ ids.append(id)
+ return cls(ids)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for id in self.ids:
+ file.write(struct.pack('!B', len(id)))
+ file.write(id)
+
+class NoDefaultALPNParam(Param):
+ # We don't ever expect to instantiate this class, but we need
+ # a from_value() and a from_wire_parser(), so we just return None
+ # from the class methods when things are OK.
+
+ @classmethod
+ def emptiness(cls):
+ return Emptiness.ALWAYS
+
+ @classmethod
+ def from_value(cls, value):
+ if value is None or value == '':
+ return None
+ else:
+ raise ValueError('no-default-alpn with non-empty value')
+
+ def to_text(self):
+ raise NotImplementedError
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ if parser.remaining() != 0:
+ raise dns.exception.FormError
+ return None
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ raise NotImplementedError
+
+
+class PortParam(Param):
+ def __init__(self, port):
+ self.port = port
+
+ @classmethod
+ def from_value(cls, value):
+ value = int(value)
+ if value < 0 or value > 65535:
+ raise ValueError('port out-of-range')
+ return cls(value)
+
+ def to_text(self):
+ return f'"{self.port}"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ port = parser.get_uint16()
+ return cls(port)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(struct.pack('!H', self.port))
+
+
+class IPv4HintParam(Param):
+ def __init__(self, addresses):
+ self.addresses = addresses
+
+ @classmethod
+ def from_value(cls, value):
+ addresses = value.split(',')
+ for address in addresses:
+ # check validity
+ dns.ipv4.inet_aton(address)
+ return cls(addresses)
+
+ def to_text(self):
+ return '"' + ','.join(self.addresses) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ addresses = []
+ while parser.remaining() > 0:
+ ip = parser.get_bytes(4)
+ addresses.append(dns.ipv4.inet_ntoa(ip))
+ return cls(addresses)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for address in self.addresses:
+ file.write(dns.ipv4.inet_aton(address))
+
+
+class IPv6HintParam(Param):
+ def __init__(self, addresses):
+ self.addresses = addresses
+
+ @classmethod
+ def from_value(cls, value):
+ addresses = value.split(',')
+ for address in addresses:
+ # check validity
+ dns.ipv6.inet_aton(address)
+ return cls(addresses)
+
+ def to_text(self):
+ return '"' + ','.join(self.addresses) + '"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ addresses = []
+ while parser.remaining() > 0:
+ ip = parser.get_bytes(16)
+ addresses.append(dns.ipv6.inet_ntoa(ip))
+ return cls(addresses)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ for address in self.addresses:
+ file.write(dns.ipv6.inet_aton(address))
+
+
+class ECHConfigParam(Param):
+ def __init__(self, echconfig):
+ self.echconfig = echconfig
+
+ @classmethod
+ def from_value(cls, value):
+ if '\\' in value:
+ raise ValueError('escape in ECHConfig value')
+ value = base64.b64decode(value.encode())
+ return cls(value)
+
+ def to_text(self):
+ b64 = base64.b64encode(self.echconfig).decode('ascii')
+ return f'"{b64}"'
+
+ @classmethod
+ def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
+ value = parser.get_bytes(parser.remaining())
+ return cls(value)
+
+ def to_wire(self, file, origin=None): # pylint: disable=W0613
+ file.write(self.echconfig)
+
+
+_class_for_key = {
+ ParamKey.MANDATORY: MandatoryParam,
+ ParamKey.ALPN: ALPNParam,
+ ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam,
+ ParamKey.PORT: PortParam,
+ ParamKey.IPV4HINT: IPv4HintParam,
+ ParamKey.ECHCONFIG: ECHConfigParam,
+ ParamKey.IPV6HINT: IPv6HintParam,
+}
+
+
+def _validate_and_define(params, key, value):
+ (key, force_generic) = _validate_key(_unescape(key))
+ if key in params:
+ raise SyntaxError(f'duplicate key "{key}"')
+ cls = _class_for_key.get(key, GenericParam)
+ emptiness = cls.emptiness()
+ if value is None:
+ if emptiness == Emptiness.NEVER:
+ raise SyntaxError('value cannot be empty')
+ value = cls.from_value(value)
+ else:
+ if force_generic:
+ value = cls.from_wire_parser(dns.wire.Parser(_unescape(value)))
+ else:
+ value = cls.from_value(value)
+ params[key] = value
+
+
+class SVCBBase(dns.rdata.Rdata):
+
+ """Base class for SVCB-like records"""
+
+ # see: draft-ietf-dnsop-svcb-https-01
+
+ __slots__ = ['priority', 'target', 'params']
+
+ def __init__(self, rdclass, rdtype, priority, target, params):
+ super().__init__(rdclass, rdtype)
+ object.__setattr__(self, 'priority', priority)
+ object.__setattr__(self, 'target', target)
+ object.__setattr__(self, 'params', params)
+ # Make sure any paramater listed as mandatory is present in the
+ # record.
+ mandatory = params.get(ParamKey.MANDATORY)
+ if mandatory:
+ for key in mandatory.keys:
+ # Note we have to say "not in" as we have None as a value
+ # so a get() and a not None test would be wrong.
+ if key not in params:
+ raise ValueError(f'key {key} declared mandatory but not'
+ 'present')
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ target = self.target.choose_relativity(origin, relativize)
+ params = []
+ for key in sorted(self.params.keys()):
+ value = self.params[key]
+ if value is None:
+ params.append(key_to_text(key))
+ else:
+ kv = key_to_text(key) + '=' + value.to_text()
+ params.append(kv)
+ if len(params) > 0:
+ space = ' '
+ else:
+ space = ''
+ return '%d %s%s%s' % (self.priority, target, space, ' '.join(params))
+
+ @classmethod
+ def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
+ relativize_to=None):
+ priority = tok.get_uint16()
+ target = tok.get_name(origin, relativize, relativize_to)
+ params = {}
+ while True:
+ token = tok.get()
+ if token.is_eol_or_eof():
+ tok.unget(token)
+ break
+ if token.ttype != dns.tokenizer.IDENTIFIER:
+ raise SyntaxError('parameter is not an identifier')
+ equals = token.value.find('=')
+ if equals == len(token.value) - 1:
+ # 'key=', so next token should be a quoted string without
+ # any intervening whitespace.
+ key = token.value[:-1]
+ token = tok.get(want_leading=True)
+ if token.ttype != dns.tokenizer.QUOTED_STRING:
+ raise SyntaxError('whitespace after =')
+ value = token.value
+ elif equals > 0:
+ # key=value
+ key = token.value[:equals]
+ value = token.value[equals + 1:]
+ if len(value) == 0:
+ raise SyntaxError('unquoted parameter value cannot '
+ 'be empty')
+ elif equals == 0:
+ # =key
+ raise SyntaxError('parameter cannot start with "="')
+ else:
+ # key
+ key = token.value
+ value = None
+ _validate_and_define(params, key, value)
+ return cls(rdclass, rdtype, priority, target, params)
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ file.write(struct.pack("!H", self.priority))
+ self.target.to_wire(file, compress, origin, canonicalize)
+ for key in sorted(self.params):
+ file.write(struct.pack("!H", key))
+ value = self.params[key]
+ # placeholder for length (or actual length of empty values)
+ file.write(struct.pack("!H", 0))
+ if value is None:
+ continue
+ else:
+ start = file.tell()
+ value.to_wire(file, origin)
+ end = file.tell()
+ assert end - start < 65536
+ file.seek(start - 2)
+ stuff = struct.pack("!H", end - start)
+ file.write(stuff)
+ file.seek(0, io.SEEK_END)
+
+ @classmethod
+ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
+ priority = parser.get_uint16()
+ target = parser.get_name(origin)
+ if priority == 0 and parser.remaining() != 0:
+ raise dns.exception.FormError('parameters in AliasMode')
+ params = {}
+ prior_key = -1
+ while parser.remaining() > 0:
+ key = parser.get_uint16()
+ if key < prior_key:
+ raise dns.exception.FormError('keys not in order')
+ prior_key = key
+ vlen = parser.get_uint16()
+ pcls = _class_for_key.get(key, GenericParam)
+ with parser.restrict_to(vlen):
+ value = pcls.from_wire_parser(parser, origin)
+ params[key] = value
+ return cls(rdclass, rdtype, priority, target, params)
diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst
index bc31f19..823dc86 100644
--- a/doc/whatsnew.rst
+++ b/doc/whatsnew.rst
@@ -22,6 +22,8 @@ What's New in dnspython
* The default EDNS payload size has changed from 1280 to 1232.
+* The SVCB and HTTPS RR types are now supported.
+
2.0.0
-----
diff --git a/tests/test_svcb.py b/tests/test_svcb.py
new file mode 100644
index 0000000..d2f24b2
--- /dev/null
+++ b/tests/test_svcb.py
@@ -0,0 +1,239 @@
+import unittest
+
+import dns.rdata
+
+class SVCBTestCase(unittest.TestCase):
+ def check_valid_inputs(self, inputs):
+ expected = inputs[0]
+ for text in inputs:
+ rr = dns.rdata.from_text('IN', 'SVCB', text)
+ new_text = rr.to_text()
+ self.assertEqual(expected, new_text)
+
+ def check_invalid_inputs(self, inputs):
+ for text in inputs:
+ with self.assertRaises(dns.exception.SyntaxError):
+ dns.rdata.from_text('IN', 'SVCB', text)
+
+ def test_svcb_general_invalid(self):
+ invalid_inputs = (
+ # Duplicate keys
+ "1 . alpn=h2 alpn=h3",
+ "1 . alpn=h2 key1=h3",
+ # Quoted keys
+ "1 . \"alpn=h2\"",
+ # Invalid space
+ "1 . alpn= h2",
+ "1 . alpn =h2",
+ "1 . alpn = h2",
+ "1 . alpn= \"h2\"",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_mandatory(self):
+ valid_inputs = (
+ "1 . mandatory=\"alpn,no-default-alpn\" alpn=\"h2\" no-default-alpn",
+ "1 . mandatory=alpn,no-default-alpn alpn=h2 no-default-alpn",
+ "1 . mandatory=key1,key2 alpn=h2 no-default-alpn",
+ "1 . mandatory=alpn,no-default-alpn key1=\\002h2 key2=\"\"",
+ "1 . mandatory=alpn,no-default-alpn key1=\\002h2 key2",
+ "1 . key0=\\000\\001\\000\\002 alpn=h2 no-default-alpn",
+ "1 . alpn=h2 no-default-alpn mandatory=alpn,no-default-alpn",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ # unknown key
+ "1 . mandatory=foo",
+ # key 0
+ "1 . mandatory=key0",
+ "1 . mandatory=key0,alpn",
+ # missing key
+ "1 . mandatory=alpn",
+ # duplicate
+ "1 . mandatory=alpn,alpn alpn=h2",
+ # invalid escaping
+ "1 . mandatory=\\alpn alpn=h2",
+ # 0 in wire format
+ "1 . key0=\\000\\000",
+ # invalid length in wire format
+ "1 . key0=\\000",
+ # out of order in wire format
+ "1 . key0=\\000\\002\\000\\001 alpn=h2 no-default-alpn",
+ # leading zeros
+ "1 . mandatory=key1,key002 alpn=h2 no-default-alpn",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_alpn(self):
+ valid_inputs_two_items = (
+ "1 . alpn=\"h2,h3\"",
+ "1 . alpn=h2,h3",
+ "1 . alpn=h\\050,h3",
+ "1 . alpn=\"h\\050,h3\"",
+ "1 . alpn=\\h2,h3",
+ "1 . key1=\\002h2\\002h3",
+ )
+ self.check_valid_inputs(valid_inputs_two_items)
+
+ valid_inputs_one_item = (
+ "1 . alpn=\"h2\\,h3\"",
+ "1 . alpn=h2\\,h3",
+ "1 . alpn=h2\\044h3",
+ )
+ self.check_valid_inputs(valid_inputs_one_item)
+
+ invalid_inputs = (
+ "1 . alpn=h2,,h3",
+ "1 . alpn=01234567890abcdef01234567890abcdef01234567890abcdef"
+ "01234567890abcdef01234567890abcdef01234567890abcdef"
+ "01234567890abcdef01234567890abcdef01234567890abcdef"
+ "01234567890abcdef01234567890abcdef01234567890abcdef"
+ "01234567890abcdef01234567890abcdef01234567890abcdef"
+ "01234567890abcdef",
+ "1 . key1=\\000",
+ "1 . key1=\\002x",
+ "1 . alpn=\",h2,h3\"",
+ "1 . alpn=\"h2,h3,\"",
+ "1 . alpn",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_no_default_alpn(self):
+ valid_inputs = (
+ "1 . no-default-alpn",
+ "1 . no-default-alpn=\"\"",
+ "1 . key2",
+ "1 . key2=\"\"",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ "1 . no-default-alpn=foo",
+ "1 . no-default-alpn=",
+ "1 . key2=foo",
+ "1 . key2=",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_port(self):
+ valid_inputs = (
+ "1 . port=\"53\"",
+ "1 . port=53",
+ "1 . key3=\\000\\053",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ "1 . port=",
+ "1 . port=53x",
+ "1 . port=x53",
+ "1 . port=53,54",
+ "1 . port=53\\,54",
+ "1 . key3=\\000",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_echconfig(self):
+ valid_inputs = (
+ "1 . echconfig=\"Zm9vMA==\"",
+ "1 . echconfig=Zm9vMA==",
+ "1 . key5=foo0",
+ "1 . key5=\\102\\111\\111\\048",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ "1 . echconfig=",
+ "1 . echconfig=Zm9vMA",
+ "1 . key5=",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_ipv4hint(self):
+ valid_inputs = (
+ "1 . ipv4hint=\"0.0.0.0,1.1.1.1\"",
+ "1 . ipv4hint=0.0.0.0,1.1.1.1",
+ "1 . key4=\\000\\000\\000\\000\\001\\001\\001\\001",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ "1 . ipv4hint=",
+ "1 . ipv4hint=1234",
+ "1 . ipv4hint=1\\.2.3.4",
+ "1 . ipv4hint=1.2.3.4\\,2.3.4.5",
+ "1 . ipv4hint",
+ "1 . key4=",
+ "1 . key4=123",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_ipv6hint(self):
+ valid_inputs = (
+ "1 . ipv6hint=\"::4,1::\"",
+ "1 . ipv6hint=::4,1::",
+ "1 . key6=\\000\\000\\000\\000\\000\\000\\000\\000"
+ "\\000\\000\\000\\000\\000\\000\\000\\004"
+ "\\000\\001\\000\\000\\000\\000\\000\\000"
+ "\\000\\000\\000\\000\\000\\000\\000\\000",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ invalid_inputs = (
+ "1 . ipv6hint=",
+ "1 . ipv6hint=1234",
+ "1 . ipv6hint=1\\::2",
+ "1 . ipv6hint=::1\\,::2",
+ "1 . ipv6hint",
+ "1 . key6=",
+ "1 . key6=123",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_unknown(self):
+ valid_inputs_one_key = (
+ "1 . key23=\"key45\"",
+ "1 . key23=key45",
+ "1 . key23=key\\052\\053",
+ "1 . key23=\"key\\052\\053\"",
+ "1 . key23=\\107\\101\\121\\052\\053",
+ )
+ self.check_valid_inputs(valid_inputs_one_key)
+
+ valid_inputs_two_keys = (
+ "1 . key24 key48",
+ "1 . key24=\"\" key48",
+ )
+ self.check_valid_inputs(valid_inputs_two_keys)
+
+ invalid_inputs = (
+ "1 . key65536=foo",
+ "1 . key24= key48",
+ )
+ self.check_invalid_inputs(invalid_inputs)
+
+ def test_svcb_wire(self):
+ valid_inputs = (
+ "1 . mandatory=\"alpn,port\" alpn=\"h2\" port=\"257\"",
+ "\\# 24 0001 00 0000000400010003 00010003026832 000300020101",
+ )
+ self.check_valid_inputs(valid_inputs)
+
+ everything = \
+ "100 foo.com. mandatory=\"alpn,port\" alpn=\"h2,h3\" " \
+ " no-default-alpn port=\"12345\" echconfig=\"abcd\" " \
+ " ipv4hint=1.2.3.4,4.3.2.1 ipv6hint=1::2,3::4" \
+ " key12345=\"foo\""
+ rr = dns.rdata.from_text('IN', 'SVCB', everything)
+ rr2 = dns.rdata.from_text('IN', 'SVCB', rr.to_generic().to_text())
+ self.assertEqual(rr, rr2)
+
+ invalid_inputs = (
+ # As above, but the keys are out of order.
+ "\\# 24 0001 00 0000000400010003 000300020101 00010003026832",
+ # As above, but the mandatory keys don't match
+ "\\# 24 0001 00 0000000400010002 000300020101 00010003026832",
+ "\\# 24 0001 00 0000000400010004 000300020101 00010003026832",
+ )
+ self.check_invalid_inputs(invalid_inputs)