diff options
-rw-r--r-- | tests/test_ssl.py | 58 |
1 files changed, 38 insertions, 20 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ed911de..362da5c 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -10,9 +10,10 @@ import sys import uuid from gc import collect, get_referrers -from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN +from errno import ( + EAFNOSUPPORT, ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN) from sys import platform, getfilesystemencoding -from socket import MSG_PEEK, SHUT_RDWR, error, socket +from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs from os.path import join from weakref import ref @@ -101,6 +102,23 @@ V7H54LmltOT/hEh6QWsJqb6BQgH65bswvV/XkYGja8/T0GzvbaVzAgEC skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only") +def socket_any_family(): + try: + return socket(AF_INET) + except error as e: + if e.errno == EAFNOSUPPORT: + return socket(AF_INET6) + raise + + +def loopback_address(socket): + if socket.family == AF_INET: + return "127.0.0.1" + else: + assert socket.family == AF_INET6 + return "::1" + + def join_bytes_or_unicode(prefix, suffix): """ Join two path components of either ``bytes`` or ``unicode``. @@ -127,12 +145,12 @@ def socket_pair(): Establish and return a pair of network sockets connected to each other. """ # Connect a pair of sockets - port = socket() + port = socket_any_family() port.bind(('', 0)) port.listen(1) - client = socket() + client = socket(port.family) client.setblocking(False) - client.connect_ex(("127.0.0.1", port.getsockname()[1])) + client.connect_ex((loopback_address(port), port.getsockname()[1])) client.setblocking(True) server = port.accept()[0] @@ -1209,7 +1227,7 @@ class TestContext(object): VERIFY_PEER, lambda conn, cert, errno, depth, preverify_ok: preverify_ok) - client = socket() + client = socket_any_family() client.connect(("encrypted.google.com", 443)) clientSSL = Connection(context, client) clientSSL.set_connect_state() @@ -2237,7 +2255,7 @@ class TestConnection(object): `Connection.connect` raises `TypeError` if called with a non-address argument. """ - connection = Connection(Context(TLSv1_METHOD), socket()) + connection = Connection(Context(TLSv1_METHOD), socket_any_family()) with pytest.raises(TypeError): connection.connect(None) @@ -2246,13 +2264,13 @@ class TestConnection(object): `Connection.connect` raises `socket.error` if the underlying socket connect method raises it. """ - client = socket() + client = socket_any_family() context = Context(TLSv1_METHOD) clientSSL = Connection(context, client) # pytest.raises here doesn't work because of a bug in py.test on Python # 2.6: https://github.com/pytest-dev/pytest/issues/988 try: - clientSSL.connect(("127.0.0.1", 1)) + clientSSL.connect((loopback_address(client), 1)) except error as e: exc = e assert exc.args[0] == ECONNREFUSED @@ -2261,12 +2279,12 @@ class TestConnection(object): """ `Connection.connect` establishes a connection to the specified address. """ - port = socket() + port = socket_any_family() port.bind(('', 0)) port.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) - clientSSL.connect(('127.0.0.1', port.getsockname()[1])) + clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) + clientSSL.connect((loopback_address(port), port.getsockname()[1])) # XXX An assertion? Or something? @pytest.mark.skipif( @@ -2278,11 +2296,11 @@ class TestConnection(object): If there is a connection error, `Connection.connect_ex` returns the errno instead of raising an exception. """ - port = socket() + port = socket_any_family() port.bind(('', 0)) port.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) + clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) clientSSL.setblocking(False) result = clientSSL.connect_ex(port.getsockname()) expected = (EINPROGRESS, EWOULDBLOCK) @@ -2297,16 +2315,16 @@ class TestConnection(object): ctx = Context(TLSv1_METHOD) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) - port = socket() + port = socket_any_family() portSSL = Connection(ctx, port) portSSL.bind(('', 0)) portSSL.listen(3) - clientSSL = Connection(Context(TLSv1_METHOD), socket()) + clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family)) # Calling portSSL.getsockname() here to get the server IP address # sounds great, but frequently fails on Windows. - clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1])) + clientSSL.connect((loopback_address(port), portSSL.getsockname()[1])) serverSSL, address = portSSL.accept() @@ -2379,7 +2397,7 @@ class TestConnection(object): `Connection.set_shutdown` sets the state of the SSL connection shutdown process. """ - connection = Connection(Context(TLSv1_METHOD), socket()) + connection = Connection(Context(TLSv1_METHOD), socket_any_family()) connection.set_shutdown(RECEIVED_SHUTDOWN) assert connection.get_shutdown() == RECEIVED_SHUTDOWN @@ -2389,7 +2407,7 @@ class TestConnection(object): On Python 2 `Connection.set_shutdown` accepts an argument of type `long` as well as `int`. """ - connection = Connection(Context(TLSv1_METHOD), socket()) + connection = Connection(Context(TLSv1_METHOD), socket_any_family()) connection.set_shutdown(long(RECEIVED_SHUTDOWN)) assert connection.get_shutdown() == RECEIVED_SHUTDOWN @@ -3503,7 +3521,7 @@ class TestMemoryBIO(object): work on `OpenSSL.SSL.Connection`() that use sockets. """ context = Context(TLSv1_METHOD) - client = socket() + client = socket_any_family() clientSSL = Connection(context, client) with pytest.raises(TypeError): clientSSL.bio_read(100) |