diff options
author | Victor Stinner <victor.stinner@gmail.com> | 2015-01-29 02:52:57 +0100 |
---|---|---|
committer | Victor Stinner <victor.stinner@gmail.com> | 2015-01-29 02:52:57 +0100 |
commit | 4bbfa97905d6bb79b33d90979c9ae35ecf7bd8f5 (patch) | |
tree | de38f996d29d8ce285847f25dbb9f777fcb5c634 | |
parent | 1ac041be4f0aa79bc7533d17c16f531928180ecd (diff) | |
download | trollius-4bbfa97905d6bb79b33d90979c9ae35ecf7bd8f5.tar.gz |
Fix _SelectorSslTransport.close()
Don't call protocol.connection_lost() if protocol.connection_made() was not
called yet: if the SSL handshake failed or is still in progress.
The close() method can be called if the creation of the connection is
cancelled, by a timeout for example.
-rw-r--r-- | asyncio/selector_events.py | 7 | ||||
-rw-r--r-- | tests/test_selector_events.py | 15 |
2 files changed, 20 insertions, 2 deletions
diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d046eb2..3195f62 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -479,6 +479,7 @@ class _SelectorTransport(transports._FlowControlMixin, self._sock = sock self._sock_fd = sock.fileno() self._protocol = protocol + self._protocol_connected = True self._server = server self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. @@ -555,7 +556,8 @@ class _SelectorTransport(transports._FlowControlMixin, def _call_connection_lost(self, exc): try: - self._protocol.connection_lost(exc) + if self._protocol_connected: + self._protocol.connection_lost(exc) finally: self._sock.close() self._sock = None @@ -718,6 +720,8 @@ class _SelectorSslTransport(_SelectorTransport): sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) super().__init__(loop, sslsock, protocol, extra, server) + # the protocol connection is only made after the SSL handshake + self._protocol_connected = False self._server_hostname = server_hostname self._waiter = waiter @@ -797,6 +801,7 @@ class _SelectorSslTransport(_SelectorTransport): self._read_wants_write = False self._write_wants_read = False self._loop.add_reader(self._sock_fd, self._read_ready) + self._protocol_connected = True self._loop.call_soon(self._protocol.connection_made, self) # only wake up the waiter when connection_made() has been called self._loop.call_soon(self._wakeup_waiter) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 5152616..f64e40d 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1427,7 +1427,7 @@ class SelectorSslTransportTests(test_utils.TestCase): self.assertFalse(tr.can_write_eof()) self.assertRaises(NotImplementedError, tr.write_eof) - def test_close(self): + def check_close(self): tr = self._make_one() tr.close() @@ -1439,6 +1439,19 @@ class SelectorSslTransportTests(test_utils.TestCase): self.assertEqual(tr._conn_lost, 1) self.assertEqual(1, self.loop.remove_reader_count[1]) + test_utils.run_briefly(self.loop) + + def test_close(self): + self.check_close() + self.assertTrue(self.protocol.connection_made.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_close_not_connected(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.check_close() + self.assertFalse(self.protocol.connection_made.called) + self.assertFalse(self.protocol.connection_lost.called) + @unittest.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): self.ssl_transport(server_hostname='localhost') |