diff options
author | Nobuaki Sukegawa <nsuke@apache.org> | 2016-09-04 18:49:23 +0900 |
---|---|---|
committer | Nobuaki Sukegawa <nsuke@apache.org> | 2016-09-04 18:49:23 +0900 |
commit | d2b4f248368be36ff24c5a54fa4f8cfb86b7ab36 (patch) | |
tree | a5ca568e5bd4b6222b9055f88979e63ef173c354 | |
parent | d4eecda6b2f8b3b27a191605a054aa3bf79a4684 (diff) | |
download | thrift-d2b4f248368be36ff24c5a54fa4f8cfb86b7ab36.tar.gz |
THRIFT-3917 Check backports.ssl_match_hostname module version
This closes #1076
-rw-r--r-- | lib/py/src/transport/sslcompat.py | 43 | ||||
-rw-r--r-- | lib/py/test/test_sslsocket.py | 16 |
2 files changed, 39 insertions, 20 deletions
diff --git a/lib/py/src/transport/sslcompat.py b/lib/py/src/transport/sslcompat.py index 19cfacae6..7bf5e0667 100644 --- a/lib/py/src/transport/sslcompat.py +++ b/lib/py/src/transport/sslcompat.py @@ -17,10 +17,13 @@ # under the License. # +import logging import sys from thrift.transport.TTransport import TTransportException +logger = logging.getLogger(__name__) + def legacy_validate_callback(self, cert, hostname): """legacy method to validate the peer's SSL certificate, and to check @@ -61,20 +64,36 @@ def legacy_validate_callback(self, cert, hostname): % (hostname, cert)) -try: - import ipaddress # noqa - _match_has_ipaddress = True -except ImportError: - _match_has_ipaddress = False +def _optional_dependencies(): + try: + import ipaddress # noqa + logger.debug('ipaddress module is available') + ipaddr = True + except ImportError: + logger.warn('ipaddress module is unavailable') + ipaddr = False -try: - from backports.ssl_match_hostname import match_hostname - _match_hostname = match_hostname -except ImportError: if sys.hexversion < 0x030500F0: - _match_has_ipaddress = False + try: + from backports.ssl_match_hostname import match_hostname, __version__ as ver + ver = list(map(int, ver.split('.'))) + logger.debug('backports.ssl_match_hostname module is available') + match = match_hostname + if ver[0] * 10 + ver[1] >= 35: + return ipaddr, match + else: + logger.warn('backports.ssl_match_hostname module is too old') + ipaddr = False + except ImportError: + logger.warn('backports.ssl_match_hostname is unavailable') + ipaddr = False try: from ssl import match_hostname - _match_hostname = match_hostname + logger.debug('ssl.match_hostname is available') + match = match_hostname except ImportError: - _match_hostname = legacy_validate_callback + logger.warn('using legacy validation callback') + match = legacy_validate_callback + return ipaddr, match + +_match_has_ipaddress, _match_hostname = _optional_dependencies() diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py index 93d34d073..3e4b266c3 100644 --- a/lib/py/test/test_sslsocket.py +++ b/lib/py/test/test_sslsocket.py @@ -30,8 +30,6 @@ import warnings from contextlib import contextmanager import _import_local_thrift # noqa -from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket -from thrift.transport.TTransport import TTransportException SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) @@ -133,16 +131,16 @@ class TSSLSocketTest(unittest.TestCase): def _assert_connection_failure(self, server, path=None, **client_args): logging.disable(logging.CRITICAL) - with self._connectable_client(server, True, path=path, **client_args) as (acc, client): - try: + try: + with self._connectable_client(server, True, path=path, **client_args) as (acc, client): # We need to wait for a connection failure, but not too long. 20ms is a tunable # compromise between test speed and stability client.setTimeout(20) with self._assert_raises(TTransportException): client.open() self.assertTrue(acc.client is None) - finally: - logging.disable(logging.NOTSET) + finally: + logging.disable(logging.NOTSET) def _assert_raises(self, exc): if sys.hexversion >= 0x020700F0: @@ -334,6 +332,8 @@ class TSSLSocketTest(unittest.TestCase): self._assert_connection_success(server, ssl_context=client_context) if __name__ == '__main__': - # import logging - # logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.WARN) + from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket + from thrift.transport.TTransport import TTransportException + unittest.main() |