summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Forcier <jeff@bitprophet.org>2023-05-04 14:46:23 -0400
committerJeff Forcier <jeff@bitprophet.org>2023-05-05 12:27:20 -0400
commitc29274a9424b1e7ac9f3f1f5e2d714b7d47e7979 (patch)
tree0a6d61429ca41048ffdeef309fad7c3869a977c0
parente22c5ea330814801d8487dc3da347f987bafe5ec (diff)
downloadparamiko-c29274a9424b1e7ac9f3f1f5e2d714b7d47e7979.tar.gz
Modernize auth tests to use shared server manager
-rw-r--r--tests/_util.py20
-rw-r--r--tests/test_auth.py228
2 files changed, 65 insertions, 183 deletions
diff --git a/tests/_util.py b/tests/_util.py
index 2bfe314d..eaf6aac4 100644
--- a/tests/_util.py
+++ b/tests/_util.py
@@ -346,6 +346,8 @@ def server(
pubkeys=None,
catch_error=False,
transport_factory=None,
+ defer=False,
+ skip_verify=False,
):
"""
SSH server contextmanager for testing.
@@ -368,6 +370,13 @@ def server(
Necessary for connection_time exception testing.
:param transport_factory:
Like the same-named param in SSHClient: which Transport class to use.
+ :param bool defer:
+ Whether to defer authentication during connecting.
+
+ This is really just shorthand for ``connect={}`` which would do roughly
+ the same thing. Also: this implies skip_verify=True automatically!
+ :param bool skip_verify:
+ Whether NOT to do the default "make sure auth passed" check.
"""
if init is None:
init = {}
@@ -376,7 +385,12 @@ def server(
if client_init is None:
client_init = {}
if connect is None:
- connect = dict(username="slowdive", password="pygmalion")
+ # No auth at all please
+ if defer:
+ connect = dict()
+ # Default username based auth
+ else:
+ connect = dict(username="slowdive", password="pygmalion")
socks = LoopSocket()
sockc = LoopSocket()
sockc.link(socks)
@@ -417,6 +431,10 @@ def server(
yield (tc, ts, err) if catch_error else (tc, ts)
+ if not (catch_error or skip_verify):
+ assert ts.is_authenticated()
+ assert tc.is_authenticated()
+
tc.close()
ts.close()
socks.close()
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 02df8c12..70ee0c36 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -20,183 +20,62 @@
Some unit tests for authenticating over a Transport.
"""
-import sys
-import threading
import unittest
-from time import sleep
+from pytest import raises
from paramiko import (
- Transport,
- ServerInterface,
- RSAKey,
DSSKey,
BadAuthenticationType,
- InteractiveQuery,
AuthenticationException,
)
-from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
-from paramiko.util import u
-from ._loop import LoopSocket
-from ._util import _support, slow
-
-
-_pwd = u("\u2022")
-
-
-class NullServer(ServerInterface):
- paranoid_did_password = False
- paranoid_did_public_key = False
- paranoid_key = DSSKey.from_private_key_file(_support("dss.key"))
-
- def get_allowed_auths(self, username):
- if username == "slowdive":
- return "publickey,password"
- if username == "paranoid":
- if (
- not self.paranoid_did_password
- and not self.paranoid_did_public_key
- ):
- return "publickey,password"
- elif self.paranoid_did_password:
- return "publickey"
- else:
- return "password"
- if username == "commie":
- return "keyboard-interactive"
- if username == "utf8":
- return "password"
- if username == "non-utf8":
- return "password"
- return "publickey"
-
- def check_auth_password(self, username, password):
- if (username == "slowdive") and (password == "pygmalion"):
- return AUTH_SUCCESSFUL
- if (username == "paranoid") and (password == "paranoid"):
- # 2-part auth (even openssh doesn't support this)
- self.paranoid_did_password = True
- if self.paranoid_did_public_key:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- if (username == "utf8") and (password == _pwd):
- return AUTH_SUCCESSFUL
- if (username == "non-utf8") and (password == "\xff"):
- return AUTH_SUCCESSFUL
- if username == "bad-server":
- raise Exception("Ack!")
- if username == "unresponsive-server":
- sleep(5)
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_publickey(self, username, key):
- if (username == "paranoid") and (key == self.paranoid_key):
- # 2-part auth
- self.paranoid_did_public_key = True
- if self.paranoid_did_password:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_interactive(self, username, submethods):
- if username == "commie":
- self.username = username
- return InteractiveQuery(
- "password", "Please enter a password.", ("Password", False)
- )
- return AUTH_FAILED
-
- def check_auth_interactive_response(self, responses):
- if self.username == "commie":
- if (len(responses) == 1) and (responses[0] == "cat"):
- return AUTH_SUCCESSFUL
- return AUTH_FAILED
+from ._util import _support, server, unicodey
class AuthTest(unittest.TestCase):
- def setUp(self):
- self.socks = LoopSocket()
- self.sockc = LoopSocket()
- self.sockc.link(self.socks)
- self.tc = Transport(self.sockc)
- self.ts = Transport(self.socks)
-
- def tearDown(self):
- self.tc.close()
- self.ts.close()
- self.socks.close()
- self.sockc.close()
-
- def start_server(self):
- host_key = RSAKey.from_private_key_file(_support("rsa.key"))
- self.public_host_key = RSAKey(data=host_key.asbytes())
- self.ts.add_server_key(host_key)
- self.event = threading.Event()
- self.server = NullServer()
- self.assertTrue(not self.event.is_set())
- self.ts.start_server(self.event, self.server)
-
- def verify_finished(self):
- self.event.wait(1.0)
- self.assertTrue(self.event.is_set())
- self.assertTrue(self.ts.is_active())
-
def test_bad_auth_type(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
"""
- self.start_server()
- try:
- self.tc.connect(
- hostkey=self.public_host_key,
- username="unknown",
- password="error",
- )
- self.assertTrue(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertEqual(BadAuthenticationType, etype)
- self.assertEqual(["publickey"], evalue.allowed_types)
+ # Server won't allow password auth for this user, so should fail
+ # and return just publickey allowed types
+ with server(
+ connect=dict(username="unknown", password="error"),
+ catch_error=True,
+ ) as (_, _, err):
+ assert isinstance(err, BadAuthenticationType)
+ assert err.allowed_types == ["publickey"]
def test_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- try:
- self.tc.auth_password(username="slowdive", password="error")
- self.assertTrue(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
- self.tc.auth_password(username="slowdive", password="pygmalion")
- self.verify_finished()
+ # NOTE: Transport.connect doesn't do any auth upfront if no userauth
+ # related kwargs given.
+ with server(defer=True) as (tc, ts):
+ # Auth once, badly
+ with raises(AuthenticationException):
+ tc.auth_password(username="slowdive", password="error")
+ # And again, correctly
+ tc.auth_password(username="slowdive", password="pygmalion")
def test_multipart_auth(self):
"""
verify that multipart auth works.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password(
- username="paranoid", password="paranoid"
- )
- self.assertEqual(["publickey"], remain)
- key = DSSKey.from_private_key_file(_support("dss.key"))
- remain = self.tc.auth_publickey(username="paranoid", key=key)
- self.assertEqual([], remain)
- self.verify_finished()
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password(
+ username="paranoid", password="paranoid"
+ ) == ["publickey"]
+ key = DSSKey.from_private_key_file(_support("dss.key"))
+ assert tc.auth_publickey(username="paranoid", key=key) == []
def test_interactive_auth(self):
"""
verify keyboard-interactive auth works.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
def handler(title, instructions, prompts):
self.got_title = title
@@ -204,69 +83,54 @@ class AuthTest(unittest.TestCase):
self.got_prompts = prompts
return ["cat"]
- remain = self.tc.auth_interactive("commie", handler)
- self.assertEqual(self.got_title, "password")
- self.assertEqual(self.got_prompts, [("Password", False)])
- self.assertEqual([], remain)
- self.verify_finished()
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_interactive("commie", handler) == []
+ assert self.got_title == "password"
+ assert self.got_prompts == [("Password", False)]
def test_interactive_auth_fallback(self):
"""
verify that a password auth attempt will fallback to "interactive"
if password auth isn't supported but interactive is.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("commie", "cat")
- self.assertEqual([], remain)
- self.verify_finished()
+ with server(defer=True) as (tc, ts):
+ # This username results in an allowed_auth of just kbd-int,
+ # and has a configured interactive->response on the server.
+ assert tc.auth_password("commie", "cat") == []
def test_auth_utf8(self):
"""
verify that utf-8 encoding happens in authentication.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("utf8", _pwd)
- self.assertEqual([], remain)
- self.verify_finished()
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password("utf8", unicodey) == []
def test_auth_non_utf8(self):
"""
verify that non-utf-8 encoded passwords can be used for broken
servers.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- remain = self.tc.auth_password("non-utf8", "\xff")
- self.assertEqual([], remain)
- self.verify_finished()
+ with server(defer=True) as (tc, ts):
+ assert tc.auth_password("non-utf8", "\xff") == []
def test_auth_gets_disconnected(self):
"""
verify that we catch a server disconnecting during auth, and report
it as an auth failure.
"""
- self.start_server()
- self.tc.connect(hostkey=self.public_host_key)
- try:
- self.tc.auth_password("bad-server", "hello")
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
+ with server(defer=True, skip_verify=True) as (tc, ts), raises(
+ AuthenticationException
+ ):
+ tc.auth_password("bad-server", "hello")
- @slow
def test_auth_non_responsive(self):
"""
verify that authentication times out if server takes to long to
respond (or never responds).
"""
- self.tc.auth_timeout = 1 # 1 second, to speed up test
- self.start_server()
- self.tc.connect()
- try:
- self.tc.auth_password("unresponsive-server", "hello")
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertTrue(issubclass(etype, AuthenticationException))
- self.assertTrue("Authentication timeout" in str(evalue))
+ with server(defer=True, skip_verify=True) as (tc, ts), raises(
+ AuthenticationException
+ ) as info:
+ tc.auth_timeout = 1 # 1 second, to speed up test
+ tc.auth_password("unresponsive-server", "hello")
+ assert "Authentication timeout" in str(info.value)