diff options
author | Alex Chan <alex@alexwlchan.net> | 2016-11-05 13:01:51 +0000 |
---|---|---|
committer | Hynek Schlawack <hs@ox.cx> | 2016-11-05 14:01:51 +0100 |
commit | 1ca9e3a7c6d7ac20ee3b2c547bf60a46e26a401b (patch) | |
tree | e2d535f8f0db5e7eccfcd985ae1ffb4424d684bb | |
parent | 67666e2fcdaf8ccd4c4694bcc2d921a7ef7db02e (diff) | |
download | pyopenssl-1ca9e3a7c6d7ac20ee3b2c547bf60a46e26a401b.tar.gz |
Convert ServerNameCallbackTests to use pytest-style tests (#565)
* Convert ServerNameCallbackTests to use pytest-style tests
As well as pytest-ifying up the tests, remove a few redundant tests
and tidy up docstrings as per feedback in #563.
Addresses #340.
* Remove a stray ':py:obj:' in test docstring
* Remove _LoopbackMixin from TestServerNameCallback
Per @hynek feedback. This test class only depended on one method,
which can be broken out as a separate function anyway -- I'll
gradually disassemble the loopback as I pytest-ify other tests.
* Re-wrap a few comments
-rw-r--r-- | tests/test_ssl.py | 126 |
1 files changed, 59 insertions, 67 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 2b7e276..332f5bc 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -221,7 +221,7 @@ def _create_certificate_chain(): return [(cakey, cacert), (ikey, icert), (skey, scert)] -class _LoopbackMixin: +class _LoopbackMixin(object): """ Helper mixin which defines methods for creating a connected socket pair and for forcing two connected SSL sockets to talk to each other via memory @@ -257,50 +257,7 @@ class _LoopbackMixin: return server, client def _interactInMemory(self, client_conn, server_conn): - """ - Try to read application bytes from each of the two :py:obj:`Connection` - objects. Copy bytes back and forth between their send/receive buffers - for as long as there is anything to copy. When there is nothing more - to copy, return :py:obj:`None`. If one of them actually manages to - deliver some application bytes, return a two-tuple of the connection - from which the bytes were read and the bytes themselves. - """ - wrote = True - while wrote: - # Loop until neither side has anything to say - wrote = False - - # Copy stuff from each side's send buffer to the other side's - # receive buffer. - for (read, write) in [(client_conn, server_conn), - (server_conn, client_conn)]: - - # Give the side a chance to generate some more bytes, or - # succeed. - try: - data = read.recv(2 ** 16) - except WantReadError: - # It didn't succeed, so we'll hope it generated some - # output. - pass - else: - # It did succeed, so we'll stop now and let the caller deal - # with it. - return (read, data) - - while True: - # Keep copying as long as there's more stuff there. - try: - dirty = read.bio_read(4096) - except WantReadError: - # Okay, nothing more waiting to be sent. Stop - # processing this send buffer. - break - else: - # Keep track of the fact that someone generated some - # output. - wrote = True - write.bio_write(dirty) + return interact_in_memory(client_conn, server_conn) def _handshakeInMemory(self, client_conn, server_conn): """ @@ -319,6 +276,51 @@ class _LoopbackMixin: self._interactInMemory(client_conn, server_conn) +def interact_in_memory(client_conn, server_conn): + """ + Try to read application bytes from each of the two `Connection` objects. + Copy bytes back and forth between their send/receive buffers for as long + as there is anything to copy. When there is nothing more to copy, + return `None`. If one of them actually manages to deliver some application + bytes, return a two-tuple of the connection from which the bytes were read + and the bytes themselves. + """ + wrote = True + while wrote: + # Loop until neither side has anything to say + wrote = False + + # Copy stuff from each side's send buffer to the other side's + # receive buffer. + for (read, write) in [(client_conn, server_conn), + (server_conn, client_conn)]: + + # Give the side a chance to generate some more bytes, or succeed. + try: + data = read.recv(2 ** 16) + except WantReadError: + # It didn't succeed, so we'll hope it generated some output. + pass + else: + # It did succeed, so we'll stop now and let the caller deal + # with it. + return (read, data) + + while True: + # Keep copying as long as there's more stuff there. + try: + dirty = read.bio_read(4096) + except WantReadError: + # Okay, nothing more waiting to be sent. Stop + # processing this send buffer. + break + else: + # Keep track of the fact that someone generated some + # output. + wrote = True + write.bio_write(dirty) + + class VersionTests(TestCase): """ Tests for version information exposed by @@ -1579,24 +1581,14 @@ class ContextTests(TestCase, _LoopbackMixin): self.assertIsInstance(store, X509Store) -class ServerNameCallbackTests(TestCase, _LoopbackMixin): +class TestServerNameCallback(object): """ - Tests for :py:obj:`Context.set_tlsext_servername_callback` and its - interaction with :py:obj:`Connection`. + Tests for `Context.set_tlsext_servername_callback` and its + interaction with `Connection`. """ - def test_wrong_args(self): - """ - :py:obj:`Context.set_tlsext_servername_callback` raises - :py:obj:`TypeError` if called with other than one argument. - """ - context = Context(TLSv1_METHOD) - self.assertRaises(TypeError, context.set_tlsext_servername_callback) - self.assertRaises( - TypeError, context.set_tlsext_servername_callback, 1, 2) - def test_old_callback_forgotten(self): """ - If :py:obj:`Context.set_tlsext_servername_callback` is used to specify + If `Context.set_tlsext_servername_callback` is used to specify a new callback, the one it replaces is dereferenced. """ def callback(connection): @@ -1629,8 +1621,8 @@ class ServerNameCallbackTests(TestCase, _LoopbackMixin): def test_no_servername(self): """ When a client specifies no server name, the callback passed to - :py:obj:`Context.set_tlsext_servername_callback` is invoked and the - result of :py:obj:`Connection.get_servername` is :py:obj:`None`. + `Context.set_tlsext_servername_callback` is invoked and the + result of `Connection.get_servername` is `None`. """ args = [] @@ -1656,15 +1648,15 @@ class ServerNameCallbackTests(TestCase, _LoopbackMixin): client = Connection(Context(TLSv1_METHOD), None) client.set_connect_state() - self._interactInMemory(server, client) + interact_in_memory(server, client) - self.assertEqual([(server, None)], args) + assert args == [(server, None)] def test_servername(self): """ When a client specifies a server name in its hello message, the - callback passed to :py:obj:`Contexts.set_tlsext_servername_callback` is - invoked and the result of :py:obj:`Connection.get_servername` is that + callback passed to `Contexts.set_tlsext_servername_callback` is + invoked and the result of `Connection.get_servername` is that server name. """ args = [] @@ -1687,9 +1679,9 @@ class ServerNameCallbackTests(TestCase, _LoopbackMixin): client.set_connect_state() client.set_tlsext_host_name(b"foo1.example.com") - self._interactInMemory(server, client) + interact_in_memory(server, client) - self.assertEqual([(server, b"foo1.example.com")], args) + assert args == [(server, b"foo1.example.com")] class NextProtoNegotiationTests(TestCase, _LoopbackMixin): |