diff options
author | Jeff Forcier <jeff@bitprophet.org> | 2023-05-04 14:46:23 -0400 |
---|---|---|
committer | Jeff Forcier <jeff@bitprophet.org> | 2023-05-05 12:27:20 -0400 |
commit | c29274a9424b1e7ac9f3f1f5e2d714b7d47e7979 (patch) | |
tree | 0a6d61429ca41048ffdeef309fad7c3869a977c0 | |
parent | e22c5ea330814801d8487dc3da347f987bafe5ec (diff) | |
download | paramiko-c29274a9424b1e7ac9f3f1f5e2d714b7d47e7979.tar.gz |
Modernize auth tests to use shared server manager
-rw-r--r-- | tests/_util.py | 20 | ||||
-rw-r--r-- | tests/test_auth.py | 228 |
2 files changed, 65 insertions, 183 deletions
diff --git a/tests/_util.py b/tests/_util.py index 2bfe314d..eaf6aac4 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -346,6 +346,8 @@ def server( pubkeys=None, catch_error=False, transport_factory=None, + defer=False, + skip_verify=False, ): """ SSH server contextmanager for testing. @@ -368,6 +370,13 @@ def server( Necessary for connection_time exception testing. :param transport_factory: Like the same-named param in SSHClient: which Transport class to use. + :param bool defer: + Whether to defer authentication during connecting. + + This is really just shorthand for ``connect={}`` which would do roughly + the same thing. Also: this implies skip_verify=True automatically! + :param bool skip_verify: + Whether NOT to do the default "make sure auth passed" check. """ if init is None: init = {} @@ -376,7 +385,12 @@ def server( if client_init is None: client_init = {} if connect is None: - connect = dict(username="slowdive", password="pygmalion") + # No auth at all please + if defer: + connect = dict() + # Default username based auth + else: + connect = dict(username="slowdive", password="pygmalion") socks = LoopSocket() sockc = LoopSocket() sockc.link(socks) @@ -417,6 +431,10 @@ def server( yield (tc, ts, err) if catch_error else (tc, ts) + if not (catch_error or skip_verify): + assert ts.is_authenticated() + assert tc.is_authenticated() + tc.close() ts.close() socks.close() diff --git a/tests/test_auth.py b/tests/test_auth.py index 02df8c12..70ee0c36 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -20,183 +20,62 @@ Some unit tests for authenticating over a Transport. """ -import sys -import threading import unittest -from time import sleep +from pytest import raises from paramiko import ( - Transport, - ServerInterface, - RSAKey, DSSKey, BadAuthenticationType, - InteractiveQuery, AuthenticationException, ) -from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL -from paramiko.util import u -from ._loop import LoopSocket -from ._util import _support, slow - - -_pwd = u("\u2022") - - -class NullServer(ServerInterface): - paranoid_did_password = False - paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support("dss.key")) - - def get_allowed_auths(self, username): - if username == "slowdive": - return "publickey,password" - if username == "paranoid": - if ( - not self.paranoid_did_password - and not self.paranoid_did_public_key - ): - return "publickey,password" - elif self.paranoid_did_password: - return "publickey" - else: - return "password" - if username == "commie": - return "keyboard-interactive" - if username == "utf8": - return "password" - if username == "non-utf8": - return "password" - return "publickey" - - def check_auth_password(self, username, password): - if (username == "slowdive") and (password == "pygmalion"): - return AUTH_SUCCESSFUL - if (username == "paranoid") and (password == "paranoid"): - # 2-part auth (even openssh doesn't support this) - self.paranoid_did_password = True - if self.paranoid_did_public_key: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - if (username == "utf8") and (password == _pwd): - return AUTH_SUCCESSFUL - if (username == "non-utf8") and (password == "\xff"): - return AUTH_SUCCESSFUL - if username == "bad-server": - raise Exception("Ack!") - if username == "unresponsive-server": - sleep(5) - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if (username == "paranoid") and (key == self.paranoid_key): - # 2-part auth - self.paranoid_did_public_key = True - if self.paranoid_did_password: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_interactive(self, username, submethods): - if username == "commie": - self.username = username - return InteractiveQuery( - "password", "Please enter a password.", ("Password", False) - ) - return AUTH_FAILED - - def check_auth_interactive_response(self, responses): - if self.username == "commie": - if (len(responses) == 1) and (responses[0] == "cat"): - return AUTH_SUCCESSFUL - return AUTH_FAILED +from ._util import _support, server, unicodey class AuthTest(unittest.TestCase): - def setUp(self): - self.socks = LoopSocket() - self.sockc = LoopSocket() - self.sockc.link(self.socks) - self.tc = Transport(self.sockc) - self.ts = Transport(self.socks) - - def tearDown(self): - self.tc.close() - self.ts.close() - self.socks.close() - self.sockc.close() - - def start_server(self): - host_key = RSAKey.from_private_key_file(_support("rsa.key")) - self.public_host_key = RSAKey(data=host_key.asbytes()) - self.ts.add_server_key(host_key) - self.event = threading.Event() - self.server = NullServer() - self.assertTrue(not self.event.is_set()) - self.ts.start_server(self.event, self.server) - - def verify_finished(self): - self.event.wait(1.0) - self.assertTrue(self.event.is_set()) - self.assertTrue(self.ts.is_active()) - def test_bad_auth_type(self): """ verify that we get the right exception when an unsupported auth type is requested. """ - self.start_server() - try: - self.tc.connect( - hostkey=self.public_host_key, - username="unknown", - password="error", - ) - self.assertTrue(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEqual(BadAuthenticationType, etype) - self.assertEqual(["publickey"], evalue.allowed_types) + # Server won't allow password auth for this user, so should fail + # and return just publickey allowed types + with server( + connect=dict(username="unknown", password="error"), + catch_error=True, + ) as (_, _, err): + assert isinstance(err, BadAuthenticationType) + assert err.allowed_types == ["publickey"] def test_bad_password(self): """ verify that a bad password gets the right exception, and that a retry with the right password works. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - try: - self.tc.auth_password(username="slowdive", password="error") - self.assertTrue(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) - self.tc.auth_password(username="slowdive", password="pygmalion") - self.verify_finished() + # NOTE: Transport.connect doesn't do any auth upfront if no userauth + # related kwargs given. + with server(defer=True) as (tc, ts): + # Auth once, badly + with raises(AuthenticationException): + tc.auth_password(username="slowdive", password="error") + # And again, correctly + tc.auth_password(username="slowdive", password="pygmalion") def test_multipart_auth(self): """ verify that multipart auth works. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password( - username="paranoid", password="paranoid" - ) - self.assertEqual(["publickey"], remain) - key = DSSKey.from_private_key_file(_support("dss.key")) - remain = self.tc.auth_publickey(username="paranoid", key=key) - self.assertEqual([], remain) - self.verify_finished() + with server(defer=True) as (tc, ts): + assert tc.auth_password( + username="paranoid", password="paranoid" + ) == ["publickey"] + key = DSSKey.from_private_key_file(_support("dss.key")) + assert tc.auth_publickey(username="paranoid", key=key) == [] def test_interactive_auth(self): """ verify keyboard-interactive auth works. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) def handler(title, instructions, prompts): self.got_title = title @@ -204,69 +83,54 @@ class AuthTest(unittest.TestCase): self.got_prompts = prompts return ["cat"] - remain = self.tc.auth_interactive("commie", handler) - self.assertEqual(self.got_title, "password") - self.assertEqual(self.got_prompts, [("Password", False)]) - self.assertEqual([], remain) - self.verify_finished() + with server(defer=True) as (tc, ts): + assert tc.auth_interactive("commie", handler) == [] + assert self.got_title == "password" + assert self.got_prompts == [("Password", False)] def test_interactive_auth_fallback(self): """ verify that a password auth attempt will fallback to "interactive" if password auth isn't supported but interactive is. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("commie", "cat") - self.assertEqual([], remain) - self.verify_finished() + with server(defer=True) as (tc, ts): + # This username results in an allowed_auth of just kbd-int, + # and has a configured interactive->response on the server. + assert tc.auth_password("commie", "cat") == [] def test_auth_utf8(self): """ verify that utf-8 encoding happens in authentication. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("utf8", _pwd) - self.assertEqual([], remain) - self.verify_finished() + with server(defer=True) as (tc, ts): + assert tc.auth_password("utf8", unicodey) == [] def test_auth_non_utf8(self): """ verify that non-utf-8 encoded passwords can be used for broken servers. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - remain = self.tc.auth_password("non-utf8", "\xff") - self.assertEqual([], remain) - self.verify_finished() + with server(defer=True) as (tc, ts): + assert tc.auth_password("non-utf8", "\xff") == [] def test_auth_gets_disconnected(self): """ verify that we catch a server disconnecting during auth, and report it as an auth failure. """ - self.start_server() - self.tc.connect(hostkey=self.public_host_key) - try: - self.tc.auth_password("bad-server", "hello") - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) + with server(defer=True, skip_verify=True) as (tc, ts), raises( + AuthenticationException + ): + tc.auth_password("bad-server", "hello") - @slow def test_auth_non_responsive(self): """ verify that authentication times out if server takes to long to respond (or never responds). """ - self.tc.auth_timeout = 1 # 1 second, to speed up test - self.start_server() - self.tc.connect() - try: - self.tc.auth_password("unresponsive-server", "hello") - except: - etype, evalue, etb = sys.exc_info() - self.assertTrue(issubclass(etype, AuthenticationException)) - self.assertTrue("Authentication timeout" in str(evalue)) + with server(defer=True, skip_verify=True) as (tc, ts), raises( + AuthenticationException + ) as info: + tc.auth_timeout = 1 # 1 second, to speed up test + tc.auth_password("unresponsive-server", "hello") + assert "Authentication timeout" in str(info.value) |