diff options
author | Bob Halley <halley@dnspython.org> | 2020-06-13 11:40:54 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2020-06-13 11:40:54 -0700 |
commit | 92b5b2d4a660a509aba0b818cf1889eb6d76c09e (patch) | |
tree | 19fa9ad78c521a2e41237a54637cc40c01006690 | |
parent | dfff63e1d4cefc8af15abe37e9b8cce39951ac72 (diff) | |
download | dnspython-async.tar.gz |
Change parameter order of low_level_address_tuple; add test coverage.async
-rw-r--r-- | dns/_curio_backend.py | 4 | ||||
-rw-r--r-- | dns/_trio_backend.py | 4 | ||||
-rw-r--r-- | dns/asyncquery.py | 2 | ||||
-rw-r--r-- | dns/inet.py | 15 | ||||
-rw-r--r-- | tests/test_ntoaaton.py | 15 |
5 files changed, 31 insertions, 9 deletions
diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index 836273b..d5eba68 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -77,14 +77,14 @@ class Backend(dns._asyncbackend.Backend): s = curio.socket.socket(af, socktype, proto) try: if source: - s.bind(_lltuple(af, source)) + s.bind(_lltuple(source, af)) except Exception: await s.close() raise return DatagramSocket(s) elif socktype == socket.SOCK_STREAM: if source: - source_addr = (_lltuple(af, source)) + source_addr = _lltuple(source, af) else: source_addr = None async with _maybe_timeout(timeout): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 418639c..cfb0e1d 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -81,10 +81,10 @@ class Backend(dns._asyncbackend.Backend): stream = None try: if source: - await s.bind(_lltuple(af, source)) + await s.bind(_lltuple(source, af)) if socktype == socket.SOCK_STREAM: with _maybe_timeout(timeout): - await s.connect(_lltuple(af, destination)) + await s.connect(_lltuple(destination, af)) except Exception: s.close() raise diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 38141fe..709c246 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -199,7 +199,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) - destination = _lltuple(af, (where, port)) + destination = _lltuple((where, port), af) await send_udp(s, wire, destination, expiration) (r, received_time) = await receive_udp(s, destination, expiration, ignore_unexpected, diff --git a/dns/inet.py b/dns/inet.py index 7960e9f..71782ac 100644 --- a/dns/inet.py +++ b/dns/inet.py @@ -141,15 +141,22 @@ def is_address(text): return False -def low_level_address_tuple(af, high_tuple): - """Given an address family and a "high-level" address tuple, i.e. +def low_level_address_tuple(high_tuple, af=None): + """Given a "high-level" address tuple, i.e. an (address, port) return the appropriate "low-level" address tuple suitable for use in socket calls. + + If an *af* other than ``None`` is provided, it is assumed the + address in the high-level tuple is valid and has that af. If af + is ``None``, then af_for_address will be called. + """ address, port = high_tuple - if af == dns.inet.AF_INET: + if af is None: + af = af_for_address(address) + if af == AF_INET: return (address, port) - elif af == dns.inet.AF_INET6: + elif af == AF_INET6: ai_flags = socket.AI_NUMERICHOST ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) return tup diff --git a/tests/test_ntoaaton.py b/tests/test_ntoaaton.py index 36107e1..3a72891 100644 --- a/tests/test_ntoaaton.py +++ b/tests/test_ntoaaton.py @@ -17,6 +17,7 @@ import unittest import binascii +import socket import dns.exception import dns.ipv4 @@ -274,5 +275,19 @@ class NtoAAtoNTestCase(unittest.TestCase): ('2001:db8:0:1:1:1:1:q1', False)]: self.assertEqual(dns.inet.is_address(t), e) + def test_low_level_address_tuple(self): + t = dns.inet.low_level_address_tuple(('1.2.3.4', 53)) + self.assertEqual(t, ('1.2.3.4', 53)) + t = dns.inet.low_level_address_tuple(('2600::1', 53)) + self.assertEqual(t, ('2600::1', 53, 0, 0)) + t = dns.inet.low_level_address_tuple(('1.2.3.4', 53), socket.AF_INET) + self.assertEqual(t, ('1.2.3.4', 53)) + t = dns.inet.low_level_address_tuple(('2600::1', 53), socket.AF_INET6) + self.assertEqual(t, ('2600::1', 53, 0, 0)) + def bad(): + bogus = socket.AF_INET + socket.AF_INET6 + 1 + t = dns.inet.low_level_address_tuple(('2600::1', 53), bogus) + self.assertRaises(NotImplementedError, bad) + if __name__ == '__main__': unittest.main() |