summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-07-08 15:47:54 -0700
committerGitHub <noreply@github.com>2020-07-08 15:47:54 -0700
commit6070cffcadf72508c19e60cd6d4d24601c1ea7aa (patch)
treeae18b61e6f6158c19c2cbef5a2b0366a6fb658c1
parentbc329ed22eba8dc4dd6b2605d9ba6bd789c2026c (diff)
parent5250399a9aeecab9dbf40a65164faf7290a08f5b (diff)
downloaddnspython-6070cffcadf72508c19e60cd6d4d24601c1ea7aa.tar.gz
Merge pull request #533 from bwelling/receive_queries
Receive queries
-rw-r--r--dns/_asyncio_backend.py6
-rw-r--r--dns/_curio_backend.py6
-rw-r--r--dns/_trio_backend.py9
-rw-r--r--dns/asyncquery.py36
-rw-r--r--dns/query.py44
-rw-r--r--tests/test_async.py23
-rw-r--r--tests/test_query.py12
7 files changed, 104 insertions, 32 deletions
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
index ba7c2e7..3af34ff 100644
--- a/dns/_asyncio_backend.py
+++ b/dns/_asyncio_backend.py
@@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getpeername(self):
return self.transport.get_extra_info('peername')
+ async def getsockname(self):
+ return self.transport.get_extra_info('sockname')
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, af, reader, writer):
@@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
async def getpeername(self):
return self.writer.get_extra_info('peername')
+ async def getsockname(self):
+ return self.writer.get_extra_info('sockname')
+
class Backend(dns._asyncbackend.Backend):
def name(self):
diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py
index dca966d..300e1b8 100644
--- a/dns/_curio_backend.py
+++ b/dns/_curio_backend.py
@@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
@@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class Backend(dns._asyncbackend.Backend):
def name(self):
diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py
index 0f1378f..92ea879 100644
--- a/dns/_trio_backend.py
+++ b/dns/_trio_backend.py
@@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getpeername(self):
return self.socket.getpeername()
+ async def getsockname(self):
+ return self.socket.getsockname()
+
class StreamSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, stream, tls=False):
@@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
else:
return self.stream.socket.getpeername()
+ async def getsockname(self):
+ if self.tls:
+ return self.stream.transport_stream.socket.getsockname()
+ else:
+ return self.stream.socket.getsockname()
+
class Backend(dns._asyncbackend.Backend):
def name(self):
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
index 4afe7bc..b792648 100644
--- a/dns/asyncquery.py
+++ b/dns/asyncquery.py
@@ -30,8 +30,7 @@ import dns.rcode
import dns.rdataclass
import dns.rdatatype
-from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \
- BadResponse, ssl
+from dns.query import _compute_times, _matches_destination, BadResponse, ssl
# for brevity
@@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None):
return (n, sent_time)
-async def receive_udp(sock, destination, expiration=None,
+async def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
@@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None,
*sock*, a ``dns.asyncbackend.DatagramSocket``.
*destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
+ of the socket, specifying where the message is expected to arrive from.
+ When receiving a response, this would be where the associated query was
+ sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
@@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None,
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
- Returns a ``(dns.message.Message, float)`` tuple of the received message
- and the received time.
+ Returns a ``(dns.message.Message, float, tuple)`` tuple of the received
+ message, the received time, and the address where the message arrived from.
"""
wire = b''
while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
+ if _matches_destination(sock.family, from_address, destination,
+ ignore_unexpected):
break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
- return (r, received_time)
+ return (r, received_time, from_address)
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
@@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
stuple = _source_tuple(af, source, source_port)
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
await send_udp(s, wire, destination, expiration)
- (r, received_time) = await receive_udp(s, destination, expiration,
- ignore_unexpected,
- one_rr_per_rrset,
- q.keyring, q.mac,
- ignore_trailing,
- raise_on_truncation)
+ (r, received_time, _) = await receive_udp(s, destination, expiration,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing,
+ raise_on_truncation)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
diff --git a/dns/query.py b/dns/query.py
index 13c8246..7df565d 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2):
return n1 == n2 and a1[1:] == a2[1:]
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+ # Check that from_address is appropriate for a response to a query
+ # sent to destination.
+ if not destination:
+ return True
+ if _addresses_equal(af, from_address, destination) or \
+ (dns.inet.is_multicast(destination[0]) and
+ from_address[1:] == destination[1:]):
+ return True
+ elif ignore_unexpected:
+ return False
+ raise UnexpectedSource(f'got a response from {from_address} instead of '
+ f'{destination}')
+
+
def _destination_and_source(where, port, source, source_port,
where_must_be_address=True):
# Apply defaults and compute destination and source tuples
@@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None):
return (n, sent_time)
-def receive_udp(sock, destination, expiration=None,
+def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
@@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None,
*sock*, a ``socket``.
*destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
+ of the socket, specifying where the message is expected to arrive from.
+ When receiving a response, this would be where the associated query was
+ sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
@@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None,
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
- Returns a ``(dns.message.Message, float)`` tuple of the received message
- and the received time.
+ If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
+ tuple of the received message and the received time.
+
+ If *destination* is ``None``, returns a
+ ``(dns.message.Message, float, tuple)``
+ tuple of the received message, the received time, and the address where
+ the message arrived from.
"""
wire = b''
while 1:
_wait_for_readable(sock, expiration)
(wire, from_address) = sock.recvfrom(65535)
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
+ if _matches_destination(sock.family, from_address, destination,
+ ignore_unexpected):
break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
- return (r, received_time)
+ if destination:
+ return (r, received_time)
+ else:
+ return (r, received_time, from_address)
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
diff --git a/tests/test_async.py b/tests/test_async.py
index 2d25434..db108c8 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -277,6 +277,8 @@ class AsyncTests(unittest.TestCase):
socket.SOCK_STREAM, 0,
None,
(address, 53)) as s:
+ # for basic coverage
+ await s.getsockname()
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.tcp(q, address, sock=s)
response = self.async_run(run)
@@ -315,6 +317,8 @@ class AsyncTests(unittest.TestCase):
None,
(address, 853), None,
ssl_context, None) as s:
+ # for basic coverage
+ await s.getsockname()
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.tls(q, '8.8.8.8', sock=s)
response = self.async_run(run)
@@ -343,6 +347,25 @@ class AsyncTests(unittest.TestCase):
(_, tcp) = self.async_run(run)
self.assertFalse(tcp)
+ def testUDPReceiveQuery(self):
+ async def run():
+ async with await self.backend.make_socket(
+ socket.AF_INET, socket.SOCK_DGRAM,
+ source=('127.0.0.1', 0)) as listener:
+ listener_address = await listener.getsockname()
+ async with await self.backend.make_socket(
+ socket.AF_INET, socket.SOCK_DGRAM,
+ source=('127.0.0.1', 0)) as sender:
+ sender_address = await sender.getsockname()
+ q = dns.message.make_query('dns.google', dns.rdatatype.A)
+ await dns.asyncquery.send_udp(sender, q, listener_address)
+ expiration = time.time() + 2
+ (_, _, recv_address) = await dns.asyncquery.receive_udp(
+ listener, expiration=expiration)
+ return (sender_address, recv_address)
+ (sender_address, recv_address) = self.async_run(run)
+ self.assertEqual(sender_address, recv_address)
+
def testUDPReceiveTimeout(self):
async def arun():
async with await self.backend.make_socket(socket.AF_INET,
diff --git a/tests/test_query.py b/tests/test_query.py
index f1ec55c..498128d 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -191,6 +191,18 @@ class QueryTests(unittest.TestCase):
(_, tcp) = dns.query.udp_with_fallback(q, address)
self.assertFalse(tcp)
+ def testUDPReceiveQuery(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
+ listener.bind(('127.0.0.1', 0))
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
+ sender.bind(('127.0.0.1', 0))
+ q = dns.message.make_query('dns.google', dns.rdatatype.A)
+ dns.query.send_udp(sender, q, listener.getsockname())
+ expiration = time.time() + 2
+ (q, _, addr) = dns.query.receive_udp(listener,
+ expiration=expiration)
+ self.assertEqual(addr, sender.getsockname())
+
# for brevity
_d_and_s = dns.query._destination_and_source