#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Tests for the SSL client/server sockets in thrift.transport.TSSLSocket.

Each test creates a TSSLServerSocket (on port 0, or on a unix socket path),
runs a ServerAcceptor thread that accepts one connection, reads 5 bytes and
replies b"there", and checks whether a TSSLSocket client configured with
particular certificates / ciphers / protocol versions can or cannot complete
that exchange.
"""

import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager

import _import_local_thrift  # noqa

SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Three directory levels above this script; the key material lives in
# <ROOT_DIR>/test/keys.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
# NOTE(review): judging by the names, client.crt presumably lacks the IP
# entry that client_v3.crt has; test_client_cert expects a server trusting
# CLIENT_CA to reject it -- confirm against the key-generation scripts.
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')

TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'


class ServerAcceptor(threading.Thread):
    """Daemon thread that serves a single client on ``server``.

    The thread calls ``server.listen()``, records the bound port, accepts one
    connection, reads 5 bytes from it and writes back b"there".  The test
    thread synchronizes with it through ``await_listening()``, ``port`` and
    ``client``, each of which blocks until the corresponding step finished.
    """

    def __init__(self, server, expect_failure=False):
        """
        :param server: server socket that has not started listening yet
            (a TSSLServerSocket in these tests).
        :param expect_failure: when True, an exception raised while accepting
            or talking to the client is logged but not re-raised.
        """
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure

        # Name the thread after the test that created it, so errors logged
        # from run() identify that test.  The direct callers are the
        # _connectable_client generator and the contextlib __enter__ that
        # drives it, so a fixed stack depth would name every thread after
        # that plumbing; search outward for the nearest test method instead.
        # context=0 avoids reading source lines for every frame.
        stack = inspect.stack(0)
        names = [record[3] for record in stack]
        # Frame records reference live frames (including this one); drop
        # them promptly so no reference cycle outlives __init__.
        del stack
        test_names = [name for name in names if name.startswith('test')]
        if test_names:
            self.name = test_names[0]
        elif len(names) > 2:
            self.name = names[2]

    def run(self):
        """Listen, publish the bound port, then serve one hello/there exchange."""
        try:
            self._server.listen()
        except Exception:
            # Release everyone blocked in await_listening()/port/client so the
            # test fails with this error instead of hanging on a dead thread.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()
            raise
        self._listening.set()

        try:
            address = self._server.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), but in each case port is in the
            # second slot.  AF_UNIX sockets report a path string, which has
            # no port; indexing it would yield a single character.
            if isinstance(address, tuple) and len(address) > 1:
                self._port = address[1]
        finally:
            self._port_bound.set()

        try:
            self._client = self._server.accept()
            if self._client:
                self._client.read(5)  # hello
                self._client.write(b"there")
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise
        finally:
            self._client_accepted.set()

    def await_listening(self):
        """Block until run() has called listen() on the server."""
        self._listening.wait()

    @property
    def port(self):
        """Bound TCP port; blocks until known.  None for non-tuple addresses."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport; blocks until the accept attempt ended."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the server socket."""
        if self._client:
            self._client.close()
        self._server.close()


# Python 2.6 compat
class AssertRaises(object):
    """Minimal stand-in for ``assertRaises`` used as a context manager.

    Suppresses an exception of type ``expected`` (or a subclass); raises
    ``Exception('fail')`` if no exception or a different one occurred.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if not exc_type or not issubclass(exc_type, self._expected):
            raise Exception('fail')
        return True


class TSSLSocketTest(unittest.TestCase):
    """Client/server SSL handshake tests for TSSLSocket and TSSLServerSocket."""

    def _server_socket(self, **kwargs):
        """Create a TSSLServerSocket on port 0 with the given options."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start a ServerAcceptor for ``server`` and yield ``(acceptor, client)``.

        ``client`` is an unopened TSSLSocket aimed at the acceptor: at
        localhost and the acceptor's port, or at the unix socket ``path``
        when one is given.  The acceptor is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a client built from ``client_args`` cannot talk to ``server``."""
        # Server-side errors are expected here; keep them out of the log.
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                try:
                    # We need to wait for a connection failure, but not too long. 20ms is a tunable
                    # compromise between test speed and stability
                    client.setTimeout(20)
                    with self._assert_raises(TTransportException):
                        client.open()
                        client.write(b"hello")
                        client.read(5)  # b"there"
                finally:
                    # Mirror _assert_connection_success: release the client
                    # socket even though the exchange is expected to fail.
                    client.close()
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an assertRaises context manager (AssertRaises before 2.7.0)."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that a client built from ``client_args`` completes hello/there."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                self.assertFalse(client.isOpen())
                client.open()
                self.assertTrue(client.isOpen())
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # Reserve a unique filesystem path, then remove the file so the
        # server can bind a unix socket at that path.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The path only exists if the server got as far as binding it; an
            # unconditional unlink would replace the real failure with an
            # unrelated "file not found" error.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        # Mutual TLS: the server requires a client cert signed by CLIENT_CA,
        # and the client verifies the server against SERVER_CERT.
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)


if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    # These imports bind the module-level names (TSSLSocket, TSSLServerSocket,
    # _match_has_ipaddress, TTransportException) that the tests above use, so
    # the tests only have them when this file is executed as a script.
    from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
    from thrift.transport.TTransport import TTransportException

    unittest.main()