summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Chan <alex@alexwlchan.net>2017-01-30 07:13:30 +0000
committerHynek Schlawack <hs@ox.cx>2017-01-30 08:13:30 +0100
commit1c0cb66f81d747b6349f0e132e369f52a0024efe (patch)
tree362766c8947dd4c210ba387002ecc4347d251288
parent7f3914b478e8b4fcd6ed0e68a272649bbb1c627d (diff)
downloadpyopenssl-1c0cb66f81d747b6349f0e132e369f52a0024efe.tar.gz
Convert the rest of TestConnection to be pytest-style (#594)
-rw-r--r--tests/test_ssl.py663
1 files changed, 286 insertions, 377 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index e0a720b..14b2310 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -229,33 +229,13 @@ class _LoopbackMixin(object):
BIOs.
"""
def _loopbackClientFactory(self, socket):
- client = Connection(Context(TLSv1_METHOD), socket)
- client.set_connect_state()
- return client
+ return loopback_client_factory(socket)
def _loopbackServerFactory(self, socket):
- ctx = Context(TLSv1_METHOD)
- ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
- ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
- server = Connection(ctx, socket)
- server.set_accept_state()
- return server
+ return loopback_server_factory(socket)
def _loopback(self, serverFactory=None, clientFactory=None):
- if serverFactory is None:
- serverFactory = self._loopbackServerFactory
- if clientFactory is None:
- clientFactory = self._loopbackClientFactory
-
- (server, client) = socket_pair()
- server = serverFactory(server)
- client = clientFactory(client)
-
- handshake(client, server)
-
- server.setblocking(True)
- client.setblocking(True)
- return server, client
+ return loopback(serverFactory, clientFactory)
def _interactInMemory(self, client_conn, server_conn):
return interact_in_memory(client_conn, server_conn)
@@ -264,6 +244,42 @@ class _LoopbackMixin(object):
return handshake_in_memory(client_conn, server_conn)
+def loopback_client_factory(socket):
+ client = Connection(Context(TLSv1_METHOD), socket)
+ client.set_connect_state()
+ return client
+
+
+def loopback_server_factory(socket):
+ ctx = Context(TLSv1_METHOD)
+ ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+ server = Connection(ctx, socket)
+ server.set_accept_state()
+ return server
+
+
+def loopback(server_factory=None, client_factory=None):
+ """
+ Create a connected socket pair and force two connected SSL sockets
+ to talk to each other via memory BIOs.
+ """
+ if server_factory is None:
+ server_factory = loopback_server_factory
+ if client_factory is None:
+ client_factory = loopback_client_factory
+
+ (server, client) = socket_pair()
+ server = server_factory(server)
+ client = client_factory(client)
+
+ handshake(client, server)
+
+ server.setblocking(True)
+ client.setblocking(True)
+ return server, client
+
+
def interact_in_memory(client_conn, server_conn):
"""
Try to read application bytes from each of the two `Connection` objects.
@@ -1956,9 +1972,9 @@ class TestSession(object):
assert isinstance(new_session, Session)
-class ConnectionTests(TestCase, _LoopbackMixin):
+class TestConnection(object):
"""
- Unit tests for :class:`OpenSSL.SSL.Connection`.
+ Unit tests for `OpenSSL.SSL.Connection`.
"""
# XXX get_peer_certificate -> None
# XXX sock_shutdown
@@ -1976,57 +1992,47 @@ class ConnectionTests(TestCase, _LoopbackMixin):
def test_type(self):
"""
- :py:obj:`Connection` and :py:obj:`ConnectionType` refer to the same
- type object and can be used to create instances of that type.
+ `Connection` and `ConnectionType` refer to the same type object and
+ can be used to create instances of that type.
"""
- self.assertIdentical(Connection, ConnectionType)
+ assert Connection is ConnectionType
ctx = Context(TLSv1_METHOD)
- self.assertConsistentType(Connection, 'Connection', ctx, None)
+ assert is_consistent_type(Connection, 'Connection', ctx, None)
def test_get_context(self):
"""
- :py:obj:`Connection.get_context` returns the :py:obj:`Context` instance
- used to construct the :py:obj:`Connection` instance.
+ `Connection.get_context` returns the `Context` instance used to
+ construct the `Connection` instance.
"""
context = Context(TLSv1_METHOD)
connection = Connection(context, None)
- self.assertIdentical(connection.get_context(), context)
-
- def test_get_context_wrong_args(self):
- """
- :py:obj:`Connection.get_context` raises :py:obj:`TypeError` if called
- with any arguments.
- """
- connection = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, connection.get_context, None)
+ assert connection.get_context() is context
def test_set_context_wrong_args(self):
"""
- :py:obj:`Connection.set_context` raises :py:obj:`TypeError` if called
- with a non-:py:obj:`Context` instance argument or with any number of
- arguments other than 1.
+ `Connection.set_context` raises `TypeError` if called with a
+ non-`Context` instance argument,
"""
ctx = Context(TLSv1_METHOD)
connection = Connection(ctx, None)
- self.assertRaises(TypeError, connection.set_context)
- self.assertRaises(TypeError, connection.set_context, object())
- self.assertRaises(TypeError, connection.set_context, "hello")
- self.assertRaises(TypeError, connection.set_context, 1)
- self.assertRaises(TypeError, connection.set_context, 1, 2)
- self.assertRaises(
- TypeError, connection.set_context, Context(TLSv1_METHOD), 2)
- self.assertIdentical(ctx, connection.get_context())
+ with pytest.raises(TypeError):
+ connection.set_context(object())
+ with pytest.raises(TypeError):
+ connection.set_context("hello")
+ with pytest.raises(TypeError):
+ connection.set_context(1)
+ assert ctx is connection.get_context()
def test_set_context(self):
"""
- :py:obj:`Connection.set_context` specifies a new :py:obj:`Context`
- instance to be used for the connection.
+ `Connection.set_context` specifies a new `Context` instance to be
+ used for the connection.
"""
original = Context(SSLv23_METHOD)
replacement = Context(TLSv1_METHOD)
connection = Connection(original, None)
connection.set_context(replacement)
- self.assertIdentical(replacement, connection.get_context())
+ assert replacement is connection.get_context()
# Lose our references to the contexts, just in case the Connection
# isn't properly managing its own contributions to their reference
# counts.
@@ -2035,88 +2041,52 @@ class ConnectionTests(TestCase, _LoopbackMixin):
def test_set_tlsext_host_name_wrong_args(self):
"""
- If :py:obj:`Connection.set_tlsext_host_name` is called with a non-byte
- string argument or a byte string with an embedded NUL or other than one
- argument, :py:obj:`TypeError` is raised.
+ If `Connection.set_tlsext_host_name` is called with a non-byte string
+ argument or a byte string with an embedded NUL, `TypeError` is raised.
"""
conn = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, conn.set_tlsext_host_name)
- self.assertRaises(TypeError, conn.set_tlsext_host_name, object())
- self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456)
- self.assertRaises(
- TypeError, conn.set_tlsext_host_name, b"with\0null")
+ with pytest.raises(TypeError):
+ conn.set_tlsext_host_name(object())
+ with pytest.raises(TypeError):
+ conn.set_tlsext_host_name(b"with\0null")
if PY3:
# On Python 3.x, don't accidentally implicitly convert from text.
- self.assertRaises(
- TypeError,
- conn.set_tlsext_host_name, b"example.com".decode("ascii"))
-
- def test_get_servername_wrong_args(self):
- """
- :py:obj:`Connection.get_servername` raises :py:obj:`TypeError` if
- called with any arguments.
- """
- connection = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, connection.get_servername, object())
- self.assertRaises(TypeError, connection.get_servername, 1)
- self.assertRaises(TypeError, connection.get_servername, "hello")
+ with pytest.raises(TypeError):
+ conn.set_tlsext_host_name(b"example.com".decode("ascii"))
def test_pending(self):
"""
- :py:obj:`Connection.pending` returns the number of bytes available for
+ `Connection.pending` returns the number of bytes available for
immediate read.
"""
connection = Connection(Context(TLSv1_METHOD), None)
- self.assertEquals(connection.pending(), 0)
-
- def test_pending_wrong_args(self):
- """
- :py:obj:`Connection.pending` raises :py:obj:`TypeError` if called with
- any arguments.
- """
- connection = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, connection.pending, None)
+ assert connection.pending() == 0
def test_peek(self):
"""
- :py:obj:`Connection.recv` peeks into the connection if
- :py:obj:`socket.MSG_PEEK` is passed.
+ `Connection.recv` peeks into the connection if `socket.MSG_PEEK`
+ is passed.
"""
- server, client = self._loopback()
+ server, client = loopback()
server.send(b'xy')
- self.assertEqual(client.recv(2, MSG_PEEK), b'xy')
- self.assertEqual(client.recv(2, MSG_PEEK), b'xy')
- self.assertEqual(client.recv(2), b'xy')
+ assert client.recv(2, MSG_PEEK) == b'xy'
+ assert client.recv(2, MSG_PEEK) == b'xy'
+ assert client.recv(2) == b'xy'
def test_connect_wrong_args(self):
"""
- :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with
- a non-address argument or with the wrong number of arguments.
+ `Connection.connect` raises `TypeError` if called with
+ a non-address argument.
"""
connection = Connection(Context(TLSv1_METHOD), socket())
- self.assertRaises(TypeError, connection.connect, None)
- self.assertRaises(TypeError, connection.connect)
- self.assertRaises(
- TypeError, connection.connect, ("127.0.0.1", 1), None
- )
-
- def test_connection_undefined_attr(self):
- """
- :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with
- a non-address argument or with the wrong number of arguments.
- """
-
- def attr_access_test(connection):
- return connection.an_attribute_which_is_not_defined
-
- connection = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(AttributeError, attr_access_test, connection)
+ with pytest.raises(TypeError):
+ connection.connect(None)
def test_connect_refused(self):
"""
- :py:obj:`Connection.connect` raises :py:obj:`socket.error` if the
- underlying socket connect method raises it.
+ `Connection.connect` raises `socket.error` if the underlying socket
+ connect method raises it.
"""
client = socket()
context = Context(TLSv1_METHOD)
@@ -2131,8 +2101,7 @@ class ConnectionTests(TestCase, _LoopbackMixin):
def test_connect(self):
"""
- :py:obj:`Connection.connect` establishes a connection to the specified
- address.
+ `Connection.connect` establishes a connection to the specified address.
"""
port = socket()
port.bind(('', 0))
@@ -2148,8 +2117,8 @@ class ConnectionTests(TestCase, _LoopbackMixin):
)
def test_connect_ex(self):
"""
- If there is a connection error, :py:obj:`Connection.connect_ex`
- returns the errno instead of raising an exception.
+ If there is a connection error, `Connection.connect_ex` returns the
+ errno instead of raising an exception.
"""
port = socket()
port.bind(('', 0))
@@ -2159,22 +2128,13 @@ class ConnectionTests(TestCase, _LoopbackMixin):
clientSSL.setblocking(False)
result = clientSSL.connect_ex(port.getsockname())
expected = (EINPROGRESS, EWOULDBLOCK)
- self.assertTrue(
- result in expected, "%r not in %r" % (result, expected))
-
- def test_accept_wrong_args(self):
- """
- :py:obj:`Connection.accept` raises :py:obj:`TypeError` if called with
- any arguments.
- """
- connection = Connection(Context(TLSv1_METHOD), socket())
- self.assertRaises(TypeError, connection.accept, None)
+ assert result in expected
def test_accept(self):
"""
- :py:obj:`Connection.accept` accepts a pending connection attempt and
- returns a tuple of a new :py:obj:`Connection` (the accepted client) and
- the address the connection originated from.
+ `Connection.accept` accepts a pending connection attempt and returns a
+ tuple of a new `Connection` (the accepted client) and the address the
+ connection originated from.
"""
ctx = Context(TLSv1_METHOD)
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
@@ -2192,58 +2152,53 @@ class ConnectionTests(TestCase, _LoopbackMixin):
serverSSL, address = portSSL.accept()
- self.assertTrue(isinstance(serverSSL, Connection))
- self.assertIdentical(serverSSL.get_context(), ctx)
- self.assertEquals(address, clientSSL.getsockname())
+ assert isinstance(serverSSL, Connection)
+ assert serverSSL.get_context() is ctx
+ assert address == clientSSL.getsockname()
def test_shutdown_wrong_args(self):
"""
- :py:obj:`Connection.shutdown` raises :py:obj:`TypeError` if called with
- the wrong number of arguments or with arguments other than integers.
+ `Connection.set_shutdown` raises `TypeError` if called with arguments
+ other than integers.
"""
connection = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, connection.shutdown, None)
- self.assertRaises(TypeError, connection.get_shutdown, None)
- self.assertRaises(TypeError, connection.set_shutdown)
- self.assertRaises(TypeError, connection.set_shutdown, None)
- self.assertRaises(TypeError, connection.set_shutdown, 0, 1)
+ with pytest.raises(TypeError):
+ connection.set_shutdown(None)
def test_shutdown(self):
"""
- :py:obj:`Connection.shutdown` performs an SSL-level connection
- shutdown.
+ `Connection.shutdown` performs an SSL-level connection shutdown.
"""
- server, client = self._loopback()
- self.assertFalse(server.shutdown())
- self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN)
- self.assertRaises(ZeroReturnError, client.recv, 1024)
- self.assertEquals(client.get_shutdown(), RECEIVED_SHUTDOWN)
+ server, client = loopback()
+ assert not server.shutdown()
+ assert server.get_shutdown() == SENT_SHUTDOWN
+ with pytest.raises(ZeroReturnError):
+ client.recv(1024)
+ assert client.get_shutdown() == RECEIVED_SHUTDOWN
client.shutdown()
- self.assertEquals(
- client.get_shutdown(), SENT_SHUTDOWN | RECEIVED_SHUTDOWN
- )
- self.assertRaises(ZeroReturnError, server.recv, 1024)
- self.assertEquals(
- server.get_shutdown(), SENT_SHUTDOWN | RECEIVED_SHUTDOWN
- )
+ assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
+ with pytest.raises(ZeroReturnError):
+ server.recv(1024)
+ assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
def test_shutdown_closed(self):
"""
- If the underlying socket is closed, :py:obj:`Connection.shutdown`
- propagates the write error from the low level write call.
+ If the underlying socket is closed, `Connection.shutdown` propagates
+ the write error from the low level write call.
"""
- server, client = self._loopback()
+ server, client = loopback()
server.sock_shutdown(2)
- exc = self.assertRaises(SysCallError, server.shutdown)
- if platform == "win32":
- self.assertEqual(exc.args[0], ESHUTDOWN)
- else:
- self.assertEqual(exc.args[0], EPIPE)
+ with pytest.raises(SysCallError) as exc:
+ server.shutdown()
+ if platform == "win32":
+ assert exc.value.args[0] == ESHUTDOWN
+ else:
+ assert exc.value.args[0] == EPIPE
def test_shutdown_truncated(self):
"""
- If the underlying connection is truncated, :obj:`Connection.shutdown`
- raises an :obj:`Error`.
+ If the underlying connection is truncated, `Connection.shutdown`
+ raises an `Error`.
"""
server_ctx = Context(TLSv1_METHOD)
client_ctx = Context(TLSv1_METHOD)
@@ -2253,39 +2208,41 @@ class ConnectionTests(TestCase, _LoopbackMixin):
load_certificate(FILETYPE_PEM, server_cert_pem))
server = Connection(server_ctx, None)
client = Connection(client_ctx, None)
- self._handshakeInMemory(client, server)
- self.assertEqual(server.shutdown(), False)
- self.assertRaises(WantReadError, server.shutdown)
+ handshake_in_memory(client, server)
+ assert not server.shutdown()
+ with pytest.raises(WantReadError):
+ server.shutdown()
server.bio_shutdown()
- self.assertRaises(Error, server.shutdown)
+ with pytest.raises(Error):
+ server.shutdown()
def test_set_shutdown(self):
"""
- :py:obj:`Connection.set_shutdown` sets the state of the SSL connection
+ `Connection.set_shutdown` sets the state of the SSL connection
shutdown process.
"""
connection = Connection(Context(TLSv1_METHOD), socket())
connection.set_shutdown(RECEIVED_SHUTDOWN)
- self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN)
+ assert connection.get_shutdown() == RECEIVED_SHUTDOWN
@skip_if_py3
def test_set_shutdown_long(self):
"""
- On Python 2 :py:obj:`Connection.set_shutdown` accepts an argument
- of type :py:obj:`long` as well as :py:obj:`int`.
+ On Python 2 `Connection.set_shutdown` accepts an argument
+ of type `long` as well as `int`.
"""
connection = Connection(Context(TLSv1_METHOD), socket())
connection.set_shutdown(long(RECEIVED_SHUTDOWN))
- self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN)
+ assert connection.get_shutdown() == RECEIVED_SHUTDOWN
def test_state_string(self):
"""
- :meth:`Connection.state_string` verbosely describes the current
- state of the :class:`Connection`.
+ `Connection.state_string` verbosely describes the current state of
+ the `Connection`.
"""
server, client = socket_pair()
- server = self._loopbackServerFactory(server)
- client = self._loopbackClientFactory(client)
+ server = loopback_server_factory(server)
+ client = loopback_client_factory(client)
assert server.get_state_string() in [
b"before/accept initialization", b"before SSL initialization"
@@ -2294,22 +2251,11 @@ class ConnectionTests(TestCase, _LoopbackMixin):
b"before/connect initialization", b"before SSL initialization"
]
- def test_app_data_wrong_args(self):
- """
- :py:obj:`Connection.set_app_data` raises :py:obj:`TypeError` if called
- with other than one argument. :py:obj:`Connection.get_app_data` raises
- :py:obj:`TypeError` if called with any arguments.
- """
- conn = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, conn.get_app_data, None)
- self.assertRaises(TypeError, conn.set_app_data)
- self.assertRaises(TypeError, conn.set_app_data, None, None)
-
def test_app_data(self):
"""
Any object can be set as app data by passing it to
- :py:obj:`Connection.set_app_data` and later retrieved with
- :py:obj:`Connection.get_app_data`.
+ `Connection.set_app_data` and later retrieved with
+ `Connection.get_app_data`.
"""
conn = Connection(Context(TLSv1_METHOD), None)
assert None is conn.get_app_data()
@@ -2319,26 +2265,16 @@ class ConnectionTests(TestCase, _LoopbackMixin):
def test_makefile(self):
"""
- :py:obj:`Connection.makefile` is not implemented and calling that
- method raises :py:obj:`NotImplementedError`.
+ `Connection.makefile` is not implemented and calling that
+ method raises `NotImplementedError`.
"""
conn = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(NotImplementedError, conn.makefile)
-
- def test_get_peer_cert_chain_wrong_args(self):
- """
- :py:obj:`Connection.get_peer_cert_chain` raises :py:obj:`TypeError` if
- called with any arguments.
- """
- conn = Connection(Context(TLSv1_METHOD), None)
- self.assertRaises(TypeError, conn.get_peer_cert_chain, 1)
- self.assertRaises(TypeError, conn.get_peer_cert_chain, "foo")
- self.assertRaises(TypeError, conn.get_peer_cert_chain, object())
- self.assertRaises(TypeError, conn.get_peer_cert_chain, [])
+ with pytest.raises(NotImplementedError):
+ conn.makefile()
def test_get_peer_cert_chain(self):
"""
- :py:obj:`Connection.get_peer_cert_chain` returns a list of certificates
+ `Connection.get_peer_cert_chain` returns a list of certificates
which the connected server returned for the certification verification.
"""
chain = _create_certificate_chain()
@@ -2358,21 +2294,18 @@ class ConnectionTests(TestCase, _LoopbackMixin):
client = Connection(clientContext, None)
client.set_connect_state()
- self._interactInMemory(client, server)
+ interact_in_memory(client, server)
chain = client.get_peer_cert_chain()
- self.assertEqual(len(chain), 3)
- self.assertEqual(
- "Server Certificate", chain[0].get_subject().CN)
- self.assertEqual(
- "Intermediate Certificate", chain[1].get_subject().CN)
- self.assertEqual(
- "Authority Certificate", chain[2].get_subject().CN)
+ assert len(chain) == 3
+ assert "Server Certificate" == chain[0].get_subject().CN
+ assert "Intermediate Certificate" == chain[1].get_subject().CN
+ assert "Authority Certificate" == chain[2].get_subject().CN
def test_get_peer_cert_chain_none(self):
"""
- :py:obj:`Connection.get_peer_cert_chain` returns :py:obj:`None` if the
- peer sends no certificate chain.
+ `Connection.get_peer_cert_chain` returns `None` if the peer sends
+ no certificate chain.
"""
ctx = Context(TLSv1_METHOD)
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
@@ -2381,71 +2314,57 @@ class ConnectionTests(TestCase, _LoopbackMixin):
server.set_accept_state()
client = Connection(Context(TLSv1_METHOD), None)
client.set_connect_state()
- self._interactInMemory(client, server)
- self.assertIdentical(None, server.get_peer_cert_chain())
-
- def test_get_session_wrong_args(self):
- """
- :py:obj:`Connection.get_session` raises :py:obj:`TypeError` if called
- with any arguments.
- """
- ctx = Context(TLSv1_METHOD)
- server = Connection(ctx, None)
- self.assertRaises(TypeError, server.get_session, 123)
- self.assertRaises(TypeError, server.get_session, "hello")
- self.assertRaises(TypeError, server.get_session, object())
+ interact_in_memory(client, server)
+ assert None is server.get_peer_cert_chain()
def test_get_session_unconnected(self):
"""
- :py:obj:`Connection.get_session` returns :py:obj:`None` when used with
- an object which has not been connected.
+ `Connection.get_session` returns `None` when used with an object
+ which has not been connected.
"""
ctx = Context(TLSv1_METHOD)
server = Connection(ctx, None)
session = server.get_session()
- self.assertIdentical(None, session)
+ assert None is session
def test_server_get_session(self):
"""
- On the server side of a connection, :py:obj:`Connection.get_session`
- returns a :py:class:`Session` instance representing the SSL session for
- that connection.
+ On the server side of a connection, `Connection.get_session` returns a
+ `Session` instance representing the SSL session for that connection.
"""
- server, client = self._loopback()
+ server, client = loopback()
session = server.get_session()
- self.assertIsInstance(session, Session)
+ assert isinstance(session, Session)
def test_client_get_session(self):
"""
- On the client side of a connection, :py:obj:`Connection.get_session`
- returns a :py:class:`Session` instance representing the SSL session for
+ On the client side of a connection, `Connection.get_session`
+ returns a `Session` instance representing the SSL session for
that connection.
"""
- server, client = self._loopback()
+ server, client = loopback()
session = client.get_session()
- self.assertIsInstance(session, Session)
+ assert isinstance(session, Session)
def test_set_session_wrong_args(self):
"""
- If called with an object that is not an instance of
- :py:class:`Session`, or with other than one argument,
- :py:obj:`Connection.set_session` raises :py:obj:`TypeError`.
+ `Connection.set_session` raises `TypeError` if called with an object
+ that is not an instance of `Session`.
"""
ctx = Context(TLSv1_METHOD)
connection = Connection(ctx, None)
- self.assertRaises(TypeError, connection.set_session)
- self.assertRaises(TypeError, connection.set_session, 123)
- self.assertRaises(TypeError, connection.set_session, "hello")
- self.assertRaises(TypeError, connection.set_session, object())
- self.assertRaises(
- TypeError, connection.set_session, Session(), Session())
+ with pytest.raises(TypeError):
+ connection.set_session(123)
+ with pytest.raises(TypeError):
+ connection.set_session("hello")
+ with pytest.raises(TypeError):
+ connection.set_session(object())
def test_client_set_session(self):
"""
- :py:obj:`Connection.set_session`, when used prior to a connection being
- established, accepts a :py:class:`Session` instance and causes an
- attempt to re-use the session it represents when the SSL handshake is
- performed.
+ `Connection.set_session`, when used prior to a connection being
+ established, accepts a `Session` instance and causes an attempt to
+ re-use the session it represents when the SSL handshake is performed.
"""
key = load_privatekey(FILETYPE_PEM, server_key_pem)
cert = load_certificate(FILETYPE_PEM, server_cert_pem)
@@ -2459,17 +2378,17 @@ class ConnectionTests(TestCase, _LoopbackMixin):
server.set_accept_state()
return server
- originalServer, originalClient = self._loopback(
- serverFactory=makeServer)
+ originalServer, originalClient = loopback(
+ server_factory=makeServer)
originalSession = originalClient.get_session()
def makeClient(socket):
- client = self._loopbackClientFactory(socket)
+ client = loopback_client_factory(socket)
client.set_session(originalSession)
return client
- resumedServer, resumedClient = self._loopback(
- serverFactory=makeServer,
- clientFactory=makeClient)
+ resumedServer, resumedClient = loopback(
+ server_factory=makeServer,
+ client_factory=makeClient)
# This is a proxy: in general, we have no access to any unique
# identifier for the session (new enough versions of OpenSSL expose
@@ -2477,15 +2396,13 @@ class ConnectionTests(TestCase, _LoopbackMixin):
# Instead, exploit the fact that the master key is re-used if the
# session is re-used. As long as the master key for the two
# connections is the same, the session was re-used!
- self.assertEqual(
- originalServer.master_key(), resumedServer.master_key())
+ assert originalServer.master_key() == resumedServer.master_key()
def test_set_session_wrong_method(self):
"""
- If :py:obj:`Connection.set_session` is passed a :py:class:`Session`
- instance associated with a context using a different SSL method than
- the :py:obj:`Connection` is using, a :py:class:`OpenSSL.SSL.Error` is
- raised.
+ If `Connection.set_session` is passed a `Session` instance associated
+ with a context using a different SSL method than the `Connection`
+ is using, a `OpenSSL.SSL.Error` is raised.
"""
# Make this work on both OpenSSL 1.0.0, which doesn't support TLSv1.2
# and also on OpenSSL 1.1.0 which doesn't support SSLv3. (SSL_ST_INIT
@@ -2514,8 +2431,8 @@ class ConnectionTests(TestCase, _LoopbackMixin):
client.set_connect_state()
return client
- originalServer, originalClient = self._loopback(
- serverFactory=makeServer, clientFactory=makeOriginalClient)
+ originalServer, originalClient = loopback(
+ server_factory=makeServer, client_factory=makeOriginalClient)
originalSession = originalClient.get_session()
def makeClient(socket):
@@ -2525,14 +2442,13 @@ class ConnectionTests(TestCase, _LoopbackMixin):
client.set_session(originalSession)
return client
- self.assertRaises(
- Error,
- self._loopback, clientFactory=makeClient, serverFactory=makeServer)
+ with pytest.raises(Error):
+ loopback(client_factory=makeClient, server_factory=makeServer)
def test_wantWriteError(self):
"""
- :py:obj:`Connection` methods which generate output raise
- :py:obj:`OpenSSL.SSL.WantWriteError` if writing to the connection's BIO
+ `Connection` methods which generate output raise
+ `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO
fail indicating a should-write state.
"""
client_socket, server_socket = socket_pair()
@@ -2551,57 +2467,57 @@ class ConnectionTests(TestCase, _LoopbackMixin):
break
raise
else:
- self.fail(
+ pytest.fail(
"Failed to fill socket buffer, cannot test BIO want write")
ctx = Context(TLSv1_METHOD)
conn = Connection(ctx, client_socket)
# Client's speak first, so make it an SSL client
conn.set_connect_state()
- self.assertRaises(WantWriteError, conn.do_handshake)
+ with pytest.raises(WantWriteError):
+ conn.do_handshake()
# XXX want_read
def test_get_finished_before_connect(self):
"""
- :py:obj:`Connection.get_finished` returns :py:obj:`None` before TLS
- handshake is completed.
+ `Connection.get_finished` returns `None` before TLS handshake
+ is completed.
"""
ctx = Context(TLSv1_METHOD)
connection = Connection(ctx, None)
- self.assertEqual(connection.get_finished(), None)
+ assert connection.get_finished() is None
def test_get_peer_finished_before_connect(self):
"""
- :py:obj:`Connection.get_peer_finished` returns :py:obj:`None` before
- TLS handshake is completed.
+ `Connection.get_peer_finished` returns `None` before TLS handshake
+ is completed.
"""
ctx = Context(TLSv1_METHOD)
connection = Connection(ctx, None)
- self.assertEqual(connection.get_peer_finished(), None)
+ assert connection.get_peer_finished() is None
def test_get_finished(self):
"""
- :py:obj:`Connection.get_finished` method returns the TLS Finished
- message send from client, or server. Finished messages are send during
+ `Connection.get_finished` method returns the TLS Finished message send
+ from client, or server. Finished messages are send during
TLS handshake.
"""
+ server, client = loopback()
- server, client = self._loopback()
-
- self.assertNotEqual(server.get_finished(), None)
- self.assertTrue(len(server.get_finished()) > 0)
+ assert server.get_finished() is not None
+ assert len(server.get_finished()) > 0
def test_get_peer_finished(self):
"""
- :py:obj:`Connection.get_peer_finished` method returns the TLS Finished
+ `Connection.get_peer_finished` method returns the TLS Finished
message received from client, or server. Finished messages are send
during TLS handshake.
"""
- server, client = self._loopback()
+ server, client = loopback()
- self.assertNotEqual(server.get_peer_finished(), None)
- self.assertTrue(len(server.get_peer_finished()) > 0)
+ assert server.get_peer_finished() is not None
+ assert len(server.get_peer_finished()) > 0
def test_tls_finished_message_symmetry(self):
"""
@@ -2611,109 +2527,148 @@ class ConnectionTests(TestCase, _LoopbackMixin):
The TLS Finished message send by client must be the TLS Finished
message received by server.
"""
- server, client = self._loopback()
+ server, client = loopback()
- self.assertEqual(server.get_finished(), client.get_peer_finished())
- self.assertEqual(client.get_finished(), server.get_peer_finished())
+ assert server.get_finished() == client.get_peer_finished()
+ assert client.get_finished() == server.get_peer_finished()
def test_get_cipher_name_before_connect(self):
"""
- :py:obj:`Connection.get_cipher_name` returns :py:obj:`None` if no
- connection has been established.
+ `Connection.get_cipher_name` returns `None` if no connection
+ has been established.
"""
ctx = Context(TLSv1_METHOD)
conn = Connection(ctx, None)
- self.assertIdentical(conn.get_cipher_name(), None)
+ assert conn.get_cipher_name() is None
def test_get_cipher_name(self):
"""
- :py:obj:`Connection.get_cipher_name` returns a :py:class:`unicode`
- string giving the name of the currently used cipher.
+ `Connection.get_cipher_name` returns a `unicode` string giving the
+ name of the currently used cipher.
"""
- server, client = self._loopback()
+ server, client = loopback()
server_cipher_name, client_cipher_name = \
server.get_cipher_name(), client.get_cipher_name()
- self.assertIsInstance(server_cipher_name, text_type)
- self.assertIsInstance(client_cipher_name, text_type)
+ assert isinstance(server_cipher_name, text_type)
+ assert isinstance(client_cipher_name, text_type)
- self.assertEqual(server_cipher_name, client_cipher_name)
+ assert server_cipher_name == client_cipher_name
def test_get_cipher_version_before_connect(self):
"""
- :py:obj:`Connection.get_cipher_version` returns :py:obj:`None` if no
- connection has been established.
+ `Connection.get_cipher_version` returns `None` if no connection
+ has been established.
"""
ctx = Context(TLSv1_METHOD)
conn = Connection(ctx, None)
- self.assertIdentical(conn.get_cipher_version(), None)
+ assert conn.get_cipher_version() is None
def test_get_cipher_version(self):
"""
- :py:obj:`Connection.get_cipher_version` returns a :py:class:`unicode`
- string giving the protocol name of the currently used cipher.
+ `Connection.get_cipher_version` returns a `unicode` string giving
+ the protocol name of the currently used cipher.
"""
- server, client = self._loopback()
+ server, client = loopback()
server_cipher_version, client_cipher_version = \
server.get_cipher_version(), client.get_cipher_version()
- self.assertIsInstance(server_cipher_version, text_type)
- self.assertIsInstance(client_cipher_version, text_type)
+ assert isinstance(server_cipher_version, text_type)
+ assert isinstance(client_cipher_version, text_type)
- self.assertEqual(server_cipher_version, client_cipher_version)
+ assert server_cipher_version == client_cipher_version
def test_get_cipher_bits_before_connect(self):
"""
- :py:obj:`Connection.get_cipher_bits` returns :py:obj:`None` if no
- connection has been established.
+ `Connection.get_cipher_bits` returns `None` if no connection has
+ been established.
"""
ctx = Context(TLSv1_METHOD)
conn = Connection(ctx, None)
- self.assertIdentical(conn.get_cipher_bits(), None)
+ assert conn.get_cipher_bits() is None
def test_get_cipher_bits(self):
"""
- :py:obj:`Connection.get_cipher_bits` returns the number of secret bits
+ `Connection.get_cipher_bits` returns the number of secret bits
of the currently used cipher.
"""
- server, client = self._loopback()
+ server, client = loopback()
server_cipher_bits, client_cipher_bits = \
server.get_cipher_bits(), client.get_cipher_bits()
- self.assertIsInstance(server_cipher_bits, int)
- self.assertIsInstance(client_cipher_bits, int)
+ assert isinstance(server_cipher_bits, int)
+ assert isinstance(client_cipher_bits, int)
- self.assertEqual(server_cipher_bits, client_cipher_bits)
+ assert server_cipher_bits == client_cipher_bits
def test_get_protocol_version_name(self):
"""
- :py:obj:`Connection.get_protocol_version_name()` returns a string
- giving the protocol version of the current connection.
+ `Connection.get_protocol_version_name()` returns a string giving the
+ protocol version of the current connection.
"""
- server, client = self._loopback()
+ server, client = loopback()
client_protocol_version_name = client.get_protocol_version_name()
server_protocol_version_name = server.get_protocol_version_name()
- self.assertIsInstance(server_protocol_version_name, text_type)
- self.assertIsInstance(client_protocol_version_name, text_type)
+ assert isinstance(server_protocol_version_name, text_type)
+ assert isinstance(client_protocol_version_name, text_type)
- self.assertEqual(
- server_protocol_version_name, client_protocol_version_name
- )
+ assert server_protocol_version_name == client_protocol_version_name
def test_get_protocol_version(self):
"""
- :py:obj:`Connection.get_protocol_version()` returns an integer
+ `Connection.get_protocol_version()` returns an integer
giving the protocol version of the current connection.
"""
- server, client = self._loopback()
+ server, client = loopback()
client_protocol_version = client.get_protocol_version()
server_protocol_version = server.get_protocol_version()
- self.assertIsInstance(server_protocol_version, int)
- self.assertIsInstance(client_protocol_version, int)
+ assert isinstance(server_protocol_version, int)
+ assert isinstance(client_protocol_version, int)
- self.assertEqual(server_protocol_version, client_protocol_version)
+ assert server_protocol_version == client_protocol_version
+
+ def test_wantReadError(self):
+ """
+ `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are
+ no bytes available to be read from the BIO.
+ """
+ ctx = Context(TLSv1_METHOD)
+ conn = Connection(ctx, None)
+ with pytest.raises(WantReadError):
+ conn.bio_read(1024)
+
+ def test_buffer_size(self):
+ """
+ `Connection.bio_read` accepts an integer giving the maximum number
+ of bytes to read and return.
+ """
+ ctx = Context(TLSv1_METHOD)
+ conn = Connection(ctx, None)
+ conn.set_connect_state()
+ try:
+ conn.do_handshake()
+ except WantReadError:
+ pass
+ data = conn.bio_read(2)
+ assert 2 == len(data)
+
+ @skip_if_py3
+ def test_buffer_size_long(self):
+ """
+ On Python 2 `Connection.bio_read` accepts values of type `long` as
+ well as `int`.
+ """
+ ctx = Context(TLSv1_METHOD)
+ conn = Connection(ctx, None)
+ conn.set_connect_state()
+ try:
+ conn.do_handshake()
+ except WantReadError:
+ pass
+ data = conn.bio_read(long(2))
+ assert 2 == len(data)
class ConnectionGetCipherListTests(TestCase):
@@ -3618,52 +3573,6 @@ class MemoryBIOTests(TestCase, _LoopbackMixin):
self._check_client_ca_list(set_replaces_add_ca)
-class TestConnection(object):
- """
- Tests for `Connection.bio_read` and `Connection.bio_write`.
- """
- def test_wantReadError(self):
- """
- `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are
- no bytes available to be read from the BIO.
- """
- ctx = Context(TLSv1_METHOD)
- conn = Connection(ctx, None)
- with pytest.raises(WantReadError):
- conn.bio_read(1024)
-
- def test_buffer_size(self):
- """
- `Connection.bio_read` accepts an integer giving the maximum number
- of bytes to read and return.
- """
- ctx = Context(TLSv1_METHOD)
- conn = Connection(ctx, None)
- conn.set_connect_state()
- try:
- conn.do_handshake()
- except WantReadError:
- pass
- data = conn.bio_read(2)
- assert 2 == len(data)
-
- @skip_if_py3
- def test_buffer_size_long(self):
- """
- On Python 2 `Connection.bio_read` accepts values of type `long` as
- well as `int`.
- """
- ctx = Context(TLSv1_METHOD)
- conn = Connection(ctx, None)
- conn.set_connect_state()
- try:
- conn.do_handshake()
- except WantReadError:
- pass
- data = conn.bio_read(long(2))
- assert 2 == len(data)
-
-
class InfoConstantTests(TestCase):
"""
Tests for assorted constants exposed for use in info callbacks.