summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@twistedmatrix.com>2014-02-02 10:59:14 -0500
committerJean-Paul Calderone <exarkun@twistedmatrix.com>2014-02-02 10:59:14 -0500
commitf2bbc9cc0166a7f0db9a750b38e1a924dadb2108 (patch)
treeb4cd950d5db8f9e23939ad0eff14fcd1794a398f
parent0e26e2ccb9c2bb5564ba114499044385a83a5883 (diff)
downloadpyopenssl-f2bbc9cc0166a7f0db9a750b38e1a924dadb2108.tar.gz
Change the info callback test to at least assert that the connection argument is a Connection instance.
Fix the implementation to make the test pass.
-rw-r--r--OpenSSL/SSL.py2
-rw-r--r--OpenSSL/test/test_ssl.py20
2 files changed, 13 insertions, 9 deletions
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 81ec2e2..f21ad9d 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -699,7 +699,7 @@ class Context(object):
"""
@wraps(callback)
def wrapper(ssl, where, return_code):
- callback(self, where, return_code)
+ callback(Connection._reverse_mapping[ssl], where, return_code)
self._info_callback = _ffi.callback(
"void (*)(const SSL *, int, int)", wrapper)
_lib.SSL_CTX_set_info_callback(self._context, self._info_callback)
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index e3518c5..712009b 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -700,15 +700,19 @@ class ContextTests(TestCase, _LoopbackMixin):
serverSSL = Connection(context, server)
serverSSL.set_accept_state()
- while not called:
- for ssl in clientSSL, serverSSL:
- try:
- ssl.do_handshake()
- except WantReadError:
- pass
+ handshake(clientSSL, serverSSL)
- # Kind of lame. Just make sure it got called somehow.
- self.assertTrue(called)
+ # The callback must always be called with the connection it is a
+ # callback for as the first argument. It would probably be better to
+ # split this into separate tests for client and server side info
+ # callbacks. It would also be good to at least assert *something*
+ # about `where` and `ret`.
+ notConnections = [
+ conn for (conn, where, ret) in called
+ if not isinstance(conn, Connection)]
+ self.assertEqual(
+ [], notConnections,
+ "Some info callback arguments were not Connection instaces.")
def _load_verify_locations_test(self, *args):