summaryrefslogtreecommitdiff
path: root/dns/query.py
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-05-20 18:48:31 -0700
committerBrian Wellington <bwelling@xbill.org>2020-05-20 18:48:31 -0700
commit0fa0d197f9cd978d16c7d2ff73f6319173d7ff45 (patch)
tree3a103be98badb589a55d5a67e693044deadd883a /dns/query.py
parenta32bac9aca0bc5303700878f5b699fe69ce04193 (diff)
downloaddnspython-0fa0d197f9cd978d16c7d2ff73f6319173d7ff45.tar.gz
Use context managers in the query methods.
Diffstat (limited to 'dns/query.py')
-rw-r--r--dns/query.py89
1 files changed, 33 insertions, 56 deletions
diff --git a/dns/query.py b/dns/query.py
index 080a66d..a21bd65 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -463,10 +463,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_DGRAM, 0)
- received_time = None
- sent_time = None
- try:
+ with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
if source is not None:
@@ -475,16 +472,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
(r, received_time) = receive_udp(s, destination, expiration,
ignore_unexpected, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if sent_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - sent_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - sent_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
def _net_read(sock, count, expiration):
@@ -637,10 +628,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- begin_time = None
- received_time = None
- try:
+ with socket_factory(af, socket.SOCK_STREAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
@@ -650,16 +638,21 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if begin_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - begin_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
+def _tls_handshake(s, expiration):
+ while True:
+ try:
+ s.do_handshake()
+ return
+ except ssl.SSLWantReadError:
+ _wait_for_readable(s, expiration)
+ except ssl.SSLWantWriteError:
+ _wait_for_writable(s, expiration)
def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
@@ -708,43 +701,27 @@ def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- begin_time = None
- received_time = None
- try:
+ if ssl_context is None:
+ ssl_context = ssl.create_default_context()
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
+ do_handshake_on_connect=False,
+ server_hostname=server_hostname) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
if source is not None:
s.bind(source)
_connect(s, destination, expiration)
- if ssl_context is None:
- ssl_context = ssl.create_default_context()
- if server_hostname is None:
- ssl_context.check_hostname = False
- s = ssl_context.wrap_socket(s, do_handshake_on_connect=False,
- server_hostname=server_hostname)
- while True:
- try:
- s.do_handshake()
- break
- except ssl.SSLWantReadError:
- _wait_for_readable(s, expiration)
- except ssl.SSLWantWriteError:
- _wait_for_writable(s, expiration)
+ _tls_handshake(s, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if begin_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - begin_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,