diff options
| author | Bob Halley <halley@dnspython.org> | 2020-07-08 15:47:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-08 15:47:54 -0700 |
| commit | 6070cffcadf72508c19e60cd6d4d24601c1ea7aa (patch) | |
| tree | ae18b61e6f6158c19c2cbef5a2b0366a6fb658c1 | |
| parent | bc329ed22eba8dc4dd6b2605d9ba6bd789c2026c (diff) | |
| parent | 5250399a9aeecab9dbf40a65164faf7290a08f5b (diff) | |
| download | dnspython-6070cffcadf72508c19e60cd6d4d24601c1ea7aa.tar.gz | |
Merge pull request #533 from bwelling/receive_queries
Receive queries
| -rw-r--r-- | dns/_asyncio_backend.py | 6 | ||||
| -rw-r--r-- | dns/_curio_backend.py | 6 | ||||
| -rw-r--r-- | dns/_trio_backend.py | 9 | ||||
| -rw-r--r-- | dns/asyncquery.py | 36 | ||||
| -rw-r--r-- | dns/query.py | 44 | ||||
| -rw-r--r-- | tests/test_async.py | 23 | ||||
| -rw-r--r-- | tests/test_query.py | 12 |
7 files changed, 104 insertions, 32 deletions
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index ba7c2e7..3af34ff 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.transport.get_extra_info('peername') + async def getsockname(self): + return self.transport.get_extra_info('sockname') + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, af, reader, writer): @@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.writer.get_extra_info('peername') + async def getsockname(self): + return self.writer.get_extra_info('sockname') + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index dca966d..300e1b8 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 0f1378f..92ea879 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, stream, tls=False): @@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): else: return self.stream.socket.getpeername() + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4afe7bc..b792648 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,8 +30,7 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \ - BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl # for brevity @@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination, expiration=None, +async def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None, *sock*, a ``dns.asyncbackend.DatagramSocket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received + message, the received time, and the address where the message arrived from. """ wire = b'' while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + return (r, received_time, from_address) async def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, @@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, stuple = _source_tuple(af, source, source_port) s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) await send_udp(s, wire, destination, expiration) - (r, received_time) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse diff --git a/dns/query.py b/dns/query.py index 13c8246..7df565d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. + if not destination: + return True + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + return True + elif ignore_unexpected: + return False + raise UnexpectedSource(f'got a response from {from_address} instead of ' + f'{destination}') + + def _destination_and_source(where, port, source, source_port, where_must_be_address=True): # Apply defaults and compute destination and source tuples @@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination, expiration=None, +def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None, *sock*, a ``socket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + If *destination* is not ``None``, returns a ``(dns.message.Message, float)`` + tuple of the received message and the received time. + + If *destination* is ``None``, returns a + ``(dns.message.Message, float, tuple)`` + tuple of the received message, the received time, and the address where + the message arrived from. """ wire = b'' while 1: _wait_for_readable(sock, expiration) (wire, from_address) = sock.recvfrom(65535) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, diff --git a/tests/test_async.py b/tests/test_async.py index 2d25434..db108c8 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -277,6 +277,8 @@ class AsyncTests(unittest.TestCase): socket.SOCK_STREAM, 0, None, (address, 53)) as s: + # for basic coverage + await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tcp(q, address, sock=s) response = self.async_run(run) @@ -315,6 +317,8 @@ class AsyncTests(unittest.TestCase): None, (address, 853), None, ssl_context, None) as s: + # for basic coverage + await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.tls(q, '8.8.8.8', sock=s) response = self.async_run(run) @@ -343,6 +347,25 @@ class AsyncTests(unittest.TestCase): (_, tcp) = self.async_run(run) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + async def run(): + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as listener: + listener_address = await listener.getsockname() + async with await self.backend.make_socket( + socket.AF_INET, socket.SOCK_DGRAM, + source=('127.0.0.1', 0)) as sender: + sender_address = await sender.getsockname() + q = dns.message.make_query('dns.google', dns.rdatatype.A) + await dns.asyncquery.send_udp(sender, q, listener_address) + expiration = time.time() + 2 + (_, _, recv_address) = await dns.asyncquery.receive_udp( + listener, expiration=expiration) + return (sender_address, recv_address) + (sender_address, recv_address) = self.async_run(run) + self.assertEqual(sender_address, recv_address) + def testUDPReceiveTimeout(self): async def arun(): async with await self.backend.make_socket(socket.AF_INET, diff --git a/tests/test_query.py b/tests/test_query.py index f1ec55c..498128d 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -191,6 +191,18 @@ class QueryTests(unittest.TestCase): (_, tcp) = dns.query.udp_with_fallback(q, address) self.assertFalse(tcp) + def testUDPReceiveQuery(self): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener: + listener.bind(('127.0.0.1', 0)) + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender: + sender.bind(('127.0.0.1', 0)) + q = dns.message.make_query('dns.google', dns.rdatatype.A) + dns.query.send_udp(sender, q, listener.getsockname()) + expiration = time.time() + 2 + (q, _, addr) = dns.query.receive_udp(listener, + expiration=expiration) + self.assertEqual(addr, sender.getsockname()) + # for brevity _d_and_s = dns.query._destination_and_source |
