diff options
author | Jeff Forcier <jeff@bitprophet.org> | 2017-09-11 12:10:01 -0700 |
---|---|---|
committer | Jeff Forcier <jeff@bitprophet.org> | 2017-09-11 12:10:01 -0700 |
commit | 27e7ba97c5713832bd4ce23df83ba9909d20c384 (patch) | |
tree | ad39a9b56bcf87fd0b69ba0979b5297ec93d8cef | |
parent | 24eb8cca184c9474577636afe7a99dbd9da6e814 (diff) | |
download | paramiko-27e7ba97c5713832bd4ce23df83ba9909d20c384.tar.gz |
Overhaul all appropriate lock use to use contextmanagers
Squeaky clean! Remotely possible a few threading bugs
got fixed here too, but seems unlikely (reasonably sure extraneous
lock releases doesn't hurt anything)
-rw-r--r-- | paramiko/auth_handler.py | 35 | ||||
-rw-r--r-- | paramiko/buffered_pipe.py | 36 | ||||
-rw-r--r-- | paramiko/channel.py | 60 | ||||
-rw-r--r-- | paramiko/packet.py | 5 | ||||
-rw-r--r-- | paramiko/sftp_client.py | 10 | ||||
-rw-r--r-- | paramiko/transport.py | 110 | ||||
-rw-r--r-- | paramiko/util.py | 5 | ||||
-rw-r--r-- | sites/www/changelog.rst | 2 | ||||
-rw-r--r-- | tests/loop.py | 20 |
9 files changed, 58 insertions, 225 deletions
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index d8275ee0..804ca414 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -95,84 +95,63 @@ class AuthHandler (object): return self.username def auth_none(self, username, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'none' self.username = username self._request_auth() - finally: - self.transport.lock.release() def auth_publickey(self, username, key, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'publickey' self.username = username self.private_key = key self._request_auth() - finally: - self.transport.lock.release() def auth_pkcs11(self, username, pkcs11session, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'publickey' self.username = username self.pkcs11session = pkcs11session self._request_auth() - finally: - self.transport.lock.release() def auth_password(self, username, password, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'password' self.username = username self.password = password self._request_auth() - finally: - self.transport.lock.release() def auth_interactive(self, username, handler, event, submethods=''): """ response_list = handler(title, instructions, prompt_list) """ - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'keyboard-interactive' self.username = username self.interactive_handler = handler self.submethods = submethods self._request_auth() - finally: - self.transport.lock.release() def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'gssapi-with-mic' self.username = username self.gss_host = gss_host self.gss_deleg_creds = gss_deleg_creds self._request_auth() - finally: - self.transport.lock.release() def auth_gssapi_keyex(self, username, event): - self.transport.lock.acquire() - try: + with self.transport.lock: self.auth_event = event self.auth_method = 'gssapi-keyex' self.username = username self._request_auth() - finally: - self.transport.lock.release() def abort(self): if self.auth_event is not None: diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py index d9f5149d..83299108 100644 --- a/paramiko/buffered_pipe.py +++ b/paramiko/buffered_pipe.py @@ -70,8 +70,7 @@ class BufferedPipe (object): :param threading.Event event: the event to set/clear """ - self._lock.acquire() - try: + with self._lock: self._event = event # Make sure the event starts in `set` state if we appear to already # be closed; otherwise, if we start in `clear` state & are closed, @@ -82,8 +81,6 @@ class BufferedPipe (object): event.set() else: event.clear() - finally: - self._lock.release() def feed(self, data): """ @@ -92,14 +89,11 @@ class BufferedPipe (object): :param data: the data to add, as a ``str`` or ``bytes`` """ - self._lock.acquire() - try: + with self._lock: if self._event is not None: self._event.set() self._buffer_frombytes(b(data)) self._cv.notifyAll() - finally: - self._lock.release() def read_ready(self): """ @@ -111,13 +105,10 @@ class BufferedPipe (object): ``True`` if a `read` call would immediately return at least one byte; ``False`` otherwise. """ - self._lock.acquire() - try: + with self._lock: if len(self._buffer) == 0: return False return True - finally: - self._lock.release() def read(self, nbytes, timeout=None): """ @@ -141,8 +132,7 @@ class BufferedPipe (object): before that timeout """ out = bytes() - self._lock.acquire() - try: + with self._lock: if len(self._buffer) == 0: if self._closed: return out @@ -168,9 +158,6 @@ class BufferedPipe (object): else: out = self._buffer_tobytes(nbytes) del self._buffer[:nbytes] - finally: - self._lock.release() - return out def empty(self): @@ -181,29 +168,23 @@ class BufferedPipe (object): any data that was in the buffer prior to clearing it out, as a `str` """ - self._lock.acquire() - try: + with self._lock: out = self._buffer_tobytes() del self._buffer[:] if (self._event is not None) and not self._closed: self._event.clear() return out - finally: - self._lock.release() def close(self): """ Close this pipe object. Future calls to `read` after the buffer has been emptied will return immediately with an empty string. """ - self._lock.acquire() - try: + with self._lock: self._closed = True self._cv.notifyAll() if self._event is not None: self._event.set() - finally: - self._lock.release() def __len__(self): """ @@ -211,8 +192,5 @@ class BufferedPipe (object): :return: number (`int`) of bytes buffered """ - self._lock.acquire() - try: + with self._lock: return len(self._buffer) - finally: - self._lock.release() diff --git a/paramiko/channel.py b/paramiko/channel.py index c6016a0e..ceb120ea 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -554,15 +554,12 @@ class Channel (ClosingContextManager): .. versionadded:: 1.1 """ data = bytes() - self.lock.acquire() - try: + with self.lock: old = self.combine_stderr self.combine_stderr = combine if combine and not old: # copy old stderr buffer into primary buffer data = self.in_stderr_buffer.empty() - finally: - self.lock.release() if len(data) > 0: self._feed(data) return old @@ -635,8 +632,7 @@ class Channel (ClosingContextManager): is flushed). Channels are automatically closed when their `.Transport` is closed or when they are garbage collected. """ - self.lock.acquire() - try: + with self.lock: # only close the pipe when the user explicitly closes the channel. # otherwise they will get unpleasant surprises. (and do it before # checking self.closed, since the remote host may have already @@ -648,8 +644,6 @@ class Channel (ClosingContextManager): if not self.active or self.closed: return msgs = self._close_internal() - finally: - self.lock.release() for m in msgs: if m is not None: self.transport._send_user_message(m) @@ -756,13 +750,10 @@ class Channel (ClosingContextManager): ``True`` if a `send` call on this channel would immediately succeed or fail """ - self.lock.acquire() - try: + with self.lock: if self.closed or self.eof_sent: return True return self.out_window_size > 0 - finally: - self.lock.release() def send(self, s): """ @@ -897,8 +888,7 @@ class Channel (ClosingContextManager): .. warning:: This method causes channel reads to be slightly less efficient. """ - self.lock.acquire() - try: + with self.lock: if self._pipe is not None: return self._pipe.fileno() # create the pipe and feed in any existing data @@ -907,8 +897,6 @@ class Channel (ClosingContextManager): self.in_buffer.set_event(p1) self.in_stderr_buffer.set_event(p2) return self._pipe.fileno() - finally: - self.lock.release() def shutdown(self, how): """ @@ -925,11 +913,8 @@ class Channel (ClosingContextManager): # feign "read" shutdown self.eof_received = 1 if (how == 1) or (how == 2): - self.lock.acquire() - try: + with self.lock: m = self._send_eof() - finally: - self.lock.release() if m is not None: self.transport._send_user_message(m) @@ -994,11 +979,8 @@ class Channel (ClosingContextManager): return def _request_failed(self, m): - self.lock.acquire() - try: + with self.lock: msgs = self._close_internal() - finally: - self.lock.release() for m in msgs: if m is not None: self.transport._send_user_message(m) @@ -1027,14 +1009,11 @@ class Channel (ClosingContextManager): def _window_adjust(self, m): nbytes = m.get_int() - self.lock.acquire() - try: + with self.lock: if self.ultra_debug: self._log(DEBUG, 'window up %d' % nbytes) self.out_window_size += nbytes self.out_buffer_cv.notifyAll() - finally: - self.lock.release() def _handle_request(self, m): key = m.get_text() @@ -1134,25 +1113,19 @@ class Channel (ClosingContextManager): self.transport._send_user_message(m) def _handle_eof(self, m): - self.lock.acquire() - try: + with self.lock: if not self.eof_received: self.eof_received = True self.in_buffer.close() self.in_stderr_buffer.close() if self._pipe is not None: self._pipe.set_forever() - finally: - self.lock.release() self._log(DEBUG, 'EOF received (%s)', self._name) def _handle_close(self, m): - self.lock.acquire() - try: + with self.lock: msgs = self._close_internal() self.transport._unlink_channel(self.chanid) - finally: - self.lock.release() for m in msgs: if m is not None: self.transport._send_user_message(m) @@ -1161,8 +1134,7 @@ class Channel (ClosingContextManager): def _send(self, s, m): size = len(s) - self.lock.acquire() - try: + with self.lock: if self.closed: # this doesn't seem useful, but it is the documented behavior # of Socket @@ -1172,8 +1144,6 @@ class Channel (ClosingContextManager): # eof or similar return 0 m.add_string(s[:size]) - finally: - self.lock.release() # Note: We release self.lock before calling _send_user_message. # Otherwise, we can deadlock during re-keying. self.transport._send_user_message(m) @@ -1237,16 +1207,12 @@ class Channel (ClosingContextManager): # still signal the close! if self.closed: return - self.lock.acquire() - try: + with self.lock: self._set_closed() self.transport._unlink_channel(self.chanid) - finally: - self.lock.release() def _check_add_window(self, n): - self.lock.acquire() - try: + with self.lock: if self.closed or self.eof_received or not self.active: return 0 if self.ultra_debug: @@ -1259,8 +1225,6 @@ class Channel (ClosingContextManager): out = self.in_window_sofar self.in_window_sofar = 0 return out - finally: - self.lock.release() def _wait_for_send_window(self, size): """ diff --git a/paramiko/packet.py b/paramiko/packet.py index 95a26c6e..391559c8 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -370,8 +370,7 @@ class Packetizer (object): else: cmd_name = '$%x' % cmd orig_len = len(data) - self.__write_lock.acquire() - try: + with self.__write_lock: if self.__compress_engine_out is not None: data = self.__compress_engine_out(data) packet = self._build_packet(data) @@ -409,8 +408,6 @@ class Packetizer (object): self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() - finally: - self.__write_lock.release() def read_message(self): """ diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 51c02365..f7b093ab 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -775,8 +775,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): def _async_request(self, fileobj, t, *arg): # this method may be called from other threads (prefetch) - self._lock.acquire() - try: + with self._lock: msg = Message() msg.add_int(self.request_number) for item in arg: @@ -793,8 +792,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager): num = self.request_number self._expecting[num] = fileobj self.request_number += 1 - finally: - self._lock.release() self._send_packet(t, msg) return num @@ -806,8 +803,7 @@ class SFTPClient(BaseSFTP, ClosingContextManager): raise SSHException('Server connection dropped: %s' % str(e)) msg = Message(data) num = msg.get_int() - self._lock.acquire() - try: + with self._lock: if num not in self._expecting: # might be response for a file that was closed before # responses came back @@ -818,8 +814,6 @@ class SFTPClient(BaseSFTP, ClosingContextManager): continue fileobj = self._expecting[num] del self._expecting[num] - finally: - self._lock.release() if num == waitfor: # synchronous if t == CMD_STATUS: diff --git a/paramiko/transport.py b/paramiko/transport.py index d782415b..500b1dc2 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -808,8 +808,7 @@ class Transport(threading.Thread, ClosingContextManager): if not self.active: raise SSHException('SSH session not active') timeout = 3600 if timeout is None else timeout - self.lock.acquire() - try: + with self.lock: window_size = self._sanitize_window_size(window_size) max_packet_size = self._sanitize_packet_size(max_packet_size) chanid = self._next_channel() @@ -833,8 +832,6 @@ class Transport(threading.Thread, ClosingContextManager): self.channels_seen[chanid] = True chan._set_transport(self) chan._set_window(window_size, max_packet_size) - finally: - self.lock.release() self._send_user_message(m) start_ts = time.time() while True: @@ -1040,8 +1037,7 @@ class Transport(threading.Thread, ClosingContextManager): seconds to wait for a channel, or ``None`` to wait forever :return: a new `.Channel` opened by the client """ - self.lock.acquire() - try: + with self.lock: if len(self.server_accepts) > 0: chan = self.server_accepts.pop(0) else: @@ -1051,8 +1047,6 @@ class Transport(threading.Thread, ClosingContextManager): else: # timeout chan = None - finally: - self.lock.release() return chan def connect( @@ -1159,13 +1153,10 @@ class Transport(threading.Thread, ClosingContextManager): .. versionadded:: 1.1 """ - self.lock.acquire() - try: + with self.lock: e = self.saved_exception self.saved_exception = None return e - finally: - self.lock.release() def set_subsystem_handler(self, name, handler, *larg, **kwarg): """ @@ -1181,11 +1172,8 @@ class Transport(threading.Thread, ClosingContextManager): :param handler: subclass of `.SubsystemHandler` that handles this subsystem. """ - try: - self.lock.acquire() + with self.lock: self.subsystem_table[name] = (handler, larg, kwarg) - finally: - self.lock.release() def is_authenticated(self): """ @@ -1783,12 +1771,9 @@ class Transport(threading.Thread, ClosingContextManager): self._x11_handler = handler def _queue_incoming_channel(self, channel): - self.lock.acquire() - try: + with self.lock: self.server_accepts.append(channel) self.server_accept_cv.notify() - finally: - self.lock.release() def _sanitize_window_size(self, window_size): if window_size is None: @@ -1915,11 +1900,8 @@ class Transport(threading.Thread, ClosingContextManager): self.auth_handler.abort() for event in self.channel_events.values(): event.set() - try: - self.lock.acquire() + with self.lock: self.server_accept_cv.notify() - finally: - self.lock.release() self.sock.close() except: # Don't raise spurious 'NoneType has no attribute X' errors when we @@ -1945,11 +1927,8 @@ class Transport(threading.Thread, ClosingContextManager): def _negotiate_keys(self, m): # throws SSHException on anything unusual - self.clear_to_send_lock.acquire() - try: + with self.clear_to_send_lock: self.clear_to_send.clear() - finally: - self.clear_to_send_lock.release() if self.local_kex_init is None: # remote side wants to renegotiate self._send_kex_init() @@ -2004,11 +1983,8 @@ class Transport(threading.Thread, ClosingContextManager): announce to the other side that we'd like to negotiate keys, and what kind of key negotiation we support. """ - self.clear_to_send_lock.acquire() - try: + with self.clear_to_send_lock: self.clear_to_send.clear() - finally: - self.clear_to_send_lock.release() self.in_kex = True if self.server_mode: mp_required_prefix = 'diffie-hellman-group-exchange-sha' @@ -2324,11 +2300,8 @@ class Transport(threading.Thread, ClosingContextManager): # it's now okay to send data again (if this was a re-key) if not self.packetizer.need_rekey(): self.in_kex = False - self.clear_to_send_lock.acquire() - try: + with self.clear_to_send_lock: self.clear_to_send.set() - finally: - self.clear_to_send_lock.release() return def _parse_disconnect(self, m): @@ -2393,16 +2366,13 @@ class Transport(threading.Thread, ClosingContextManager): if chan is None: self._log(WARNING, 'Success for unrequested channel! [??]') return - self.lock.acquire() - try: + with self.lock: chan._set_remote_channel( server_chanid, server_window_size, server_max_packet_size) self._log(DEBUG, 'Secsh channel %d opened.' % chanid) if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] - finally: - self.lock.release() return def _parse_channel_open_failure(self, m): @@ -2416,16 +2386,13 @@ class Transport(threading.Thread, ClosingContextManager): 'Secsh channel %d open FAILED: %s: %s' % ( chanid, reason_str, reason_text) ) - self.lock.acquire() - try: + with self.lock: self.saved_exception = ChannelException(reason, reason_text) if chanid in self.channel_events: self._channels.delete(chanid) if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] - finally: - self.lock.release() return def _parse_channel_open(self, m): @@ -2439,11 +2406,8 @@ class Transport(threading.Thread, ClosingContextManager): self._forward_agent_handler is not None ): self._log(DEBUG, 'Incoming forward agent connection') - self.lock.acquire() - try: + with self.lock: my_chanid = self._next_channel() - finally: - self.lock.release() elif (kind == 'x11') and (self._x11_handler is not None): origin_addr = m.get_text() origin_port = m.get_int() @@ -2452,11 +2416,8 @@ class Transport(threading.Thread, ClosingContextManager): 'Incoming x11 connection from %s:%d' % ( origin_addr, origin_port) ) - self.lock.acquire() - try: + with self.lock: my_chanid = self._next_channel() - finally: - self.lock.release() elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): server_addr = m.get_text() server_port = m.get_int() @@ -2467,11 +2428,8 @@ class Transport(threading.Thread, ClosingContextManager): 'Incoming tcp forwarded connection from %s:%d' % ( origin_addr, origin_port) ) - self.lock.acquire() - try: + with self.lock: my_chanid = self._next_channel() - finally: - self.lock.release() elif not self.server_mode: self._log( DEBUG, @@ -2479,11 +2437,8 @@ class Transport(threading.Thread, ClosingContextManager): reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: - self.lock.acquire() - try: + with self.lock: my_chanid = self._next_channel() - finally: - self.lock.release() if kind == 'direct-tcpip': # handle direct-tcpip requests coming from the client dest_addr = m.get_text() @@ -2514,8 +2469,7 @@ class Transport(threading.Thread, ClosingContextManager): return chan = Channel(my_chanid) - self.lock.acquire() - try: + with self.lock: self._channels.put(my_chanid, chan) self.channels_seen[my_chanid] = True chan._set_transport(self) @@ -2523,8 +2477,6 @@ class Transport(threading.Thread, ClosingContextManager): self.default_window_size, self.default_max_packet_size) chan._set_remote_channel( chanid, initial_window_size, max_packet_size) - finally: - self.lock.release() m = Message() m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS) m.add_int(chanid) @@ -2554,13 +2506,10 @@ class Transport(threading.Thread, ClosingContextManager): self._log(DEBUG, 'Debug msg: {0}'.format(util.safe_string(msg))) def _get_subsystem_handler(self, name): - try: - self.lock.acquire() + with self.lock: if name not in self.subsystem_table: return None, [], {} return self.subsystem_table[name] - finally: - self.lock.release() _handler_table = { MSG_NEWKEYS: _parse_newkeys, @@ -2672,39 +2621,24 @@ class ChannelMap (object): self._lock = threading.Lock() def put(self, chanid, chan): - self._lock.acquire() - try: + with self._lock: self._map[chanid] = chan - finally: - self._lock.release() def get(self, chanid): - self._lock.acquire() - try: + with self._lock: return self._map.get(chanid, None) - finally: - self._lock.release() def delete(self, chanid): - self._lock.acquire() - try: + with self._lock: try: del self._map[chanid] except KeyError: pass - finally: - self._lock.release() def values(self): - self._lock.acquire() - try: + with self._lock: return list(self._map.values()) - finally: - self._lock.release() def __len__(self): - self._lock.acquire() - try: + with self._lock: return len(self._map) - finally: - self._lock.release() diff --git a/paramiko/util.py b/paramiko/util.py index de099c0c..54ed081f 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -230,12 +230,9 @@ def get_thread_id(): try: return _g_thread_ids[tid] except KeyError: - _g_thread_lock.acquire() - try: + with _g_thread_lock: _g_thread_counter += 1 ret = _g_thread_ids[tid] = _g_thread_counter - finally: - _g_thread_lock.release() return ret diff --git a/sites/www/changelog.rst b/sites/www/changelog.rst index 4f7303b4..d71b5c7b 100644 --- a/sites/www/changelog.rst +++ b/sites/www/changelog.rst @@ -2,6 +2,8 @@ Changelog ========= +* :support:`-` Update all uses of `threading.Lock` (and similar classes) to use + context manager syntax instead of ``try``/``finally``. * :support:`-` Refactor some common (mostly host-)key related logic into a new module, ``authentication.py``. Includes use of a new `SSHException <paramiko.ssh_exception.SSHException>` subclass, `UnknownKeyType diff --git a/tests/loop.py b/tests/loop.py index e805ad96..499d3c11 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -42,11 +42,8 @@ class LoopSocket (object): def close(self): self.__unlink() self._closed = True - try: - self.__lock.acquire() + with self.__lock: self.__in_buffer = bytes() - finally: - self.__lock.release() def send(self, data): data = asbytes(data) @@ -57,8 +54,7 @@ class LoopSocket (object): return len(data) def recv(self, n): - self.__lock.acquire() - try: + with self.__lock: if self.__mate is None: # EOF return bytes() @@ -69,8 +65,6 @@ class LoopSocket (object): out = self.__in_buffer[:n] self.__in_buffer = self.__in_buffer[n:] return out - finally: - self.__lock.release() def settimeout(self, n): self.__timeout = n @@ -80,22 +74,16 @@ class LoopSocket (object): self.__mate.__mate = self def __feed(self, data): - self.__lock.acquire() - try: + with self.__lock: self.__in_buffer += data self.__cv.notifyAll() - finally: - self.__lock.release() def __unlink(self): m = None - self.__lock.acquire() - try: + with self.__lock: if self.__mate is not None: m = self.__mate self.__mate = None - finally: - self.__lock.release() if m is not None: m.__unlink() |