summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-06-13 11:40:54 -0700
committerBob Halley <halley@dnspython.org>2020-06-13 11:40:54 -0700
commit92b5b2d4a660a509aba0b818cf1889eb6d76c09e (patch)
tree19fa9ad78c521a2e41237a54637cc40c01006690
parentdfff63e1d4cefc8af15abe37e9b8cce39951ac72 (diff)
downloaddnspython-async.tar.gz
Change parameter order of low_level_address_tuple; add test coverage.async
-rw-r--r--dns/_curio_backend.py4
-rw-r--r--dns/_trio_backend.py4
-rw-r--r--dns/asyncquery.py2
-rw-r--r--dns/inet.py15
-rw-r--r--tests/test_ntoaaton.py15
5 files changed, 31 insertions, 9 deletions
diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py
index 836273b..d5eba68 100644
--- a/dns/_curio_backend.py
+++ b/dns/_curio_backend.py
@@ -77,14 +77,14 @@ class Backend(dns._asyncbackend.Backend):
s = curio.socket.socket(af, socktype, proto)
try:
if source:
- s.bind(_lltuple(af, source))
+ s.bind(_lltuple(source, af))
except Exception:
await s.close()
raise
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
if source:
- source_addr = (_lltuple(af, source))
+ source_addr = _lltuple(source, af)
else:
source_addr = None
async with _maybe_timeout(timeout):
diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py
index 418639c..cfb0e1d 100644
--- a/dns/_trio_backend.py
+++ b/dns/_trio_backend.py
@@ -81,10 +81,10 @@ class Backend(dns._asyncbackend.Backend):
stream = None
try:
if source:
- await s.bind(_lltuple(af, source))
+ await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM:
with _maybe_timeout(timeout):
- await s.connect(_lltuple(af, destination))
+ await s.connect(_lltuple(destination, af))
except Exception:
s.close()
raise
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
index 38141fe..709c246 100644
--- a/dns/asyncquery.py
+++ b/dns/asyncquery.py
@@ -199,7 +199,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
- destination = _lltuple(af, (where, port))
+ destination = _lltuple((where, port), af)
await send_udp(s, wire, destination, expiration)
(r, received_time) = await receive_udp(s, destination, expiration,
ignore_unexpected,
diff --git a/dns/inet.py b/dns/inet.py
index 7960e9f..71782ac 100644
--- a/dns/inet.py
+++ b/dns/inet.py
@@ -141,15 +141,22 @@ def is_address(text):
return False
-def low_level_address_tuple(af, high_tuple):
- """Given an address family and a "high-level" address tuple, i.e.
+def low_level_address_tuple(high_tuple, af=None):
+ """Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
+
+ If an *af* other than ``None`` is provided, it is assumed the
+ address in the high-level tuple is valid and has that af. If af
+ is ``None``, then af_for_address will be called.
+
"""
address, port = high_tuple
- if af == dns.inet.AF_INET:
+ if af is None:
+ af = af_for_address(address)
+ if af == AF_INET:
return (address, port)
- elif af == dns.inet.AF_INET6:
+ elif af == AF_INET6:
ai_flags = socket.AI_NUMERICHOST
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup
diff --git a/tests/test_ntoaaton.py b/tests/test_ntoaaton.py
index 36107e1..3a72891 100644
--- a/tests/test_ntoaaton.py
+++ b/tests/test_ntoaaton.py
@@ -17,6 +17,7 @@
import unittest
import binascii
+import socket
import dns.exception
import dns.ipv4
@@ -274,5 +275,19 @@ class NtoAAtoNTestCase(unittest.TestCase):
('2001:db8:0:1:1:1:1:q1', False)]:
self.assertEqual(dns.inet.is_address(t), e)
+ def test_low_level_address_tuple(self):
+ t = dns.inet.low_level_address_tuple(('1.2.3.4', 53))
+ self.assertEqual(t, ('1.2.3.4', 53))
+ t = dns.inet.low_level_address_tuple(('2600::1', 53))
+ self.assertEqual(t, ('2600::1', 53, 0, 0))
+ t = dns.inet.low_level_address_tuple(('1.2.3.4', 53), socket.AF_INET)
+ self.assertEqual(t, ('1.2.3.4', 53))
+ t = dns.inet.low_level_address_tuple(('2600::1', 53), socket.AF_INET6)
+ self.assertEqual(t, ('2600::1', 53, 0, 0))
+ def bad():
+ bogus = socket.AF_INET + socket.AF_INET6 + 1
+ t = dns.inet.low_level_address_tuple(('2600::1', 53), bogus)
+ self.assertRaises(NotImplementedError, bad)
+
if __name__ == '__main__':
unittest.main()