diff options
| author | Brian Wellington <bwelling@xbill.org> | 2020-05-20 18:48:31 -0700 |
|---|---|---|
| committer | Brian Wellington <bwelling@xbill.org> | 2020-05-20 18:48:31 -0700 |
| commit | 0fa0d197f9cd978d16c7d2ff73f6319173d7ff45 (patch) | |
| tree | 3a103be98badb589a55d5a67e693044deadd883a /dns/query.py | |
| parent | a32bac9aca0bc5303700878f5b699fe69ce04193 (diff) | |
| download | dnspython-0fa0d197f9cd978d16c7d2ff73f6319173d7ff45.tar.gz | |
Use context managers in the query methods.
Diffstat (limited to 'dns/query.py')
| -rw-r--r-- | dns/query.py | 89 |
1 files changed, 33 insertions, 56 deletions
diff --git a/dns/query.py b/dns/query.py index 080a66d..a21bd65 100644 --- a/dns/query.py +++ b/dns/query.py @@ -463,10 +463,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_DGRAM, 0) - received_time = None - sent_time = None - try: + with socket_factory(af, socket.SOCK_DGRAM, 0) as s: expiration = _compute_expiration(timeout) s.setblocking(0) if source is not None: @@ -475,16 +472,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, (r, received_time) = receive_udp(s, destination, expiration, ignore_unexpected, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if sent_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - sent_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - sent_time + if not q.is_response(r): + raise BadResponse + return r def _net_read(sock, count, expiration): @@ -637,10 +628,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_STREAM, 0) - begin_time = None - received_time = None - try: + with socket_factory(af, socket.SOCK_STREAM, 0) as s: expiration = _compute_expiration(timeout) s.setblocking(0) begin_time = time.time() @@ -650,16 +638,21 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if begin_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - begin_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + + +def _tls_handshake(s, expiration): + while True: + try: + s.do_handshake() + return + except ssl.SSLWantReadError: + _wait_for_readable(s, expiration) + except ssl.SSLWantWriteError: + _wait_for_writable(s, expiration) def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0, @@ -708,43 +701,27 @@ def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_STREAM, 0) - begin_time = None - received_time = None - try: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0), + do_handshake_on_connect=False, + server_hostname=server_hostname) as s: expiration = _compute_expiration(timeout) s.setblocking(0) begin_time = time.time() if source is not None: s.bind(source) _connect(s, destination, expiration) - if ssl_context is None: - ssl_context = ssl.create_default_context() - if server_hostname is None: - ssl_context.check_hostname = False - s = ssl_context.wrap_socket(s, do_handshake_on_connect=False, - server_hostname=server_hostname) - while True: - try: - s.do_handshake() - break - except ssl.SSLWantReadError: - _wait_for_readable(s, expiration) - except ssl.SSLWantWriteError: - _wait_for_writable(s, expiration) + _tls_handshake(s, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if begin_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - begin_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, |
