From e22c5ea330814801d8487dc3da347f987bafe5ec Mon Sep 17 00:00:00 2001 From: Jeff Forcier Date: Thu, 4 May 2023 13:52:40 -0400 Subject: Start consolidating test server nonsense --- tests/_util.py | 245 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_transport.py | 198 ++++---------------------------------- 2 files changed, 264 insertions(+), 179 deletions(-) diff --git a/tests/_util.py b/tests/_util.py index 2f1c5ac2..2bfe314d 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -1,13 +1,29 @@ +from contextlib import contextmanager from os.path import dirname, realpath, join import builtins import os from pathlib import Path +import socket import struct import sys import unittest +from time import sleep +import threading import pytest +from paramiko import ( + ServerInterface, + RSAKey, + DSSKey, + AUTH_FAILED, + AUTH_PARTIALLY_SUCCESSFUL, + AUTH_SUCCESSFUL, + OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + InteractiveQuery, + Transport, +) from paramiko.ssh_gss import GSS_AUTH_AVAILABLE from cryptography.exceptions import UnsupportedAlgorithm, _Reasons @@ -17,6 +33,8 @@ from cryptography.hazmat.primitives.asymmetric import padding, rsa tests_dir = dirname(realpath(__file__)) +from ._loop import LoopSocket + def _support(filename): base = Path(tests_dir) @@ -176,3 +194,230 @@ def sha1_signing_unsupported(): requires_sha1_signing = unittest.skipIf( sha1_signing_unsupported(), "SHA-1 signing not supported" ) + +_disable_sha2 = dict( + disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) +_disable_sha2_pubkey = dict( + disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) +) +_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) + + +unicodey = "\u2022" + + +class TestServer(ServerInterface): + paranoid_did_password = False + paranoid_did_public_key = False + # TODO: make this ed25519 or something else modern? (_is_ this used??) + paranoid_key = DSSKey.from_private_key_file(_support("dss.key")) + + def __init__(self, allowed_keys=None): + self.allowed_keys = allowed_keys if allowed_keys is not None else [] + + def check_channel_request(self, kind, chanid): + if kind == "bogus": + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + return OPEN_SUCCEEDED + + def check_channel_exec_request(self, channel, command): + if command != b"yes": + return False + return True + + def check_channel_shell_request(self, channel): + return True + + def check_global_request(self, kind, msg): + self._global_request = kind + # NOTE: for w/e reason, older impl of this returned False always, even + # tho that's only supposed to occur if the request cannot be served. + # For now, leaving that the default unless test supplies specific + # 'acceptable' request kind + return kind == "acceptable" + + def check_channel_x11_request( + self, + channel, + single_connection, + auth_protocol, + auth_cookie, + screen_number, + ): + self._x11_single_connection = single_connection + self._x11_auth_protocol = auth_protocol + self._x11_auth_cookie = auth_cookie + self._x11_screen_number = screen_number + return True + + def check_port_forward_request(self, addr, port): + self._listen = socket.socket() + self._listen.bind(("127.0.0.1", 0)) + self._listen.listen(1) + return self._listen.getsockname()[1] + + def cancel_port_forward_request(self, addr, port): + self._listen.close() + self._listen = None + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + self._tcpip_dest = destination + return OPEN_SUCCEEDED + + def get_allowed_auths(self, username): + if username == "slowdive": + return "publickey,password" + if username == "paranoid": + if ( + not self.paranoid_did_password + and not self.paranoid_did_public_key + ): + return "publickey,password" + elif self.paranoid_did_password: + return "publickey" + else: + return "password" + if username == "commie": + return "keyboard-interactive" + if username == "utf8": + return "password" + if username == "non-utf8": + return "password" + return "publickey" + + def check_auth_password(self, username, password): + if (username == "slowdive") and (password == "pygmalion"): + return AUTH_SUCCESSFUL + if (username == "paranoid") and (password == "paranoid"): + # 2-part auth (even openssh doesn't support this) + self.paranoid_did_password = True + if self.paranoid_did_public_key: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + if (username == "utf8") and (password == unicodey): + return AUTH_SUCCESSFUL + if (username == "non-utf8") and (password == "\xff"): + return AUTH_SUCCESSFUL + if username == "bad-server": + raise Exception("Ack!") + if username == "unresponsive-server": + sleep(5) + return AUTH_SUCCESSFUL + return AUTH_FAILED + + def check_auth_publickey(self, username, key): + if (username == "paranoid") and (key == self.paranoid_key): + # 2-part auth + self.paranoid_did_public_key = True + if self.paranoid_did_password: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + # TODO: make sure all tests incidentally using this to pass, _without + # sending a username oops_, get updated somehow - probably via server() + # default always injecting a username + elif key in self.allowed_keys: + return AUTH_SUCCESSFUL + return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + if username == "commie": + self.username = username + return InteractiveQuery( + "password", "Please enter a password.", ("Password", False) + ) + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + if self.username == "commie": + if (len(responses) == 1) and (responses[0] == "cat"): + return AUTH_SUCCESSFUL + return AUTH_FAILED + + +@contextmanager +def server( + hostkey=None, + init=None, + server_init=None, + client_init=None, + connect=None, + pubkeys=None, + catch_error=False, + transport_factory=None, +): + """ + SSH server contextmanager for testing. + + :param hostkey: + Host key to use for the server; if None, loads + ``rsa.key``. + :param init: + Default `Transport` constructor kwargs to use for both sides. + :param server_init: + Extends and/or overrides ``init`` for server transport only. + :param client_init: + Extends and/or overrides ``init`` for client transport only. + :param connect: + Kwargs to use for ``connect()`` on the client. + :param pubkeys: + List of public keys for auth. + :param catch_error: + Whether to capture connection errors & yield from contextmanager. + Necessary for connection_time exception testing. + :param transport_factory: + Like the same-named param in SSHClient: which Transport class to use. + """ + if init is None: + init = {} + if server_init is None: + server_init = {} + if client_init is None: + client_init = {} + if connect is None: + connect = dict(username="slowdive", password="pygmalion") + socks = LoopSocket() + sockc = LoopSocket() + sockc.link(socks) + if transport_factory is None: + transport_factory = Transport + tc = transport_factory(sockc, **dict(init, **client_init)) + ts = transport_factory(socks, **dict(init, **server_init)) + + if hostkey is None: + hostkey = RSAKey.from_private_key_file(_support("rsa.key")) + ts.add_server_key(hostkey) + event = threading.Event() + server = TestServer(allowed_keys=pubkeys) + assert not event.is_set() + assert not ts.is_active() + assert tc.get_username() is None + assert ts.get_username() is None + assert not tc.is_authenticated() + assert not ts.is_authenticated() + + err = None + # Trap errors and yield instead of raising right away; otherwise callers + # cannot usefully deal with problems at connect time which stem from errors + # in the server side. + try: + ts.start_server(event, server) + tc.connect(**connect) + + event.wait(1.0) + assert event.is_set() + assert ts.is_active() + assert tc.is_active() + + except Exception as e: + if not catch_error: + raise + err = e + + yield (tc, ts, err) if catch_error else (tc, ts) + + tc.close() + ts.close() + socks.close() + sockc.close() diff --git a/tests/test_transport.py b/tests/test_transport.py index d7704af6..ee00830a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -22,8 +22,6 @@ Some unit tests for the ssh2 protocol in Transport. from binascii import hexlify -from contextlib import contextmanager -import pytest import select import socket import time @@ -35,18 +33,15 @@ from unittest.mock import Mock from paramiko import ( AuthHandler, ChannelException, - DSSKey, Packetizer, RSAKey, SSHException, AuthenticationException, IncompatiblePeer, SecurityOptions, - ServerInterface, Transport, ) -from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL -from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko import OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko.common import ( DEFAULT_MAX_PACKET_SIZE, DEFAULT_WINDOW_SIZE, @@ -61,7 +56,18 @@ from paramiko.common import ( ) from paramiko.message import Message -from ._util import needs_builtin, _support, requires_sha1_signing, slow +from ._util import ( + needs_builtin, + _support, + requires_sha1_signing, + slow, + server, + _disable_sha2, + _disable_sha1, + _disable_sha2_pubkey, + _disable_sha1_pubkey, + TestServer as NullServer, +) from ._loop import LoopSocket @@ -78,79 +84,6 @@ Maybe. """ -class NullServer(ServerInterface): - paranoid_did_password = False - paranoid_did_public_key = False - paranoid_key = DSSKey.from_private_key_file(_support("dss.key")) - - def __init__(self, allowed_keys=None): - self.allowed_keys = allowed_keys if allowed_keys is not None else [] - - def get_allowed_auths(self, username): - if username == "slowdive": - return "publickey,password" - return "publickey" - - def check_auth_password(self, username, password): - if (username == "slowdive") and (password == "pygmalion"): - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if key in self.allowed_keys: - return AUTH_SUCCESSFUL - return AUTH_FAILED - - def check_channel_request(self, kind, chanid): - if kind == "bogus": - return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED - return OPEN_SUCCEEDED - - def check_channel_exec_request(self, channel, command): - if command != b"yes": - return False - return True - - def check_channel_shell_request(self, channel): - return True - - def check_global_request(self, kind, msg): - self._global_request = kind - # NOTE: for w/e reason, older impl of this returned False always, even - # tho that's only supposed to occur if the request cannot be served. - # For now, leaving that the default unless test supplies specific - # 'acceptable' request kind - return kind == "acceptable" - - def check_channel_x11_request( - self, - channel, - single_connection, - auth_protocol, - auth_cookie, - screen_number, - ): - self._x11_single_connection = single_connection - self._x11_auth_protocol = auth_protocol - self._x11_auth_cookie = auth_cookie - self._x11_screen_number = screen_number - return True - - def check_port_forward_request(self, addr, port): - self._listen = socket.socket() - self._listen.bind(("127.0.0.1", 0)) - self._listen.listen(1) - return self._listen.getsockname()[1] - - def cancel_port_forward_request(self, addr, port): - self._listen.close() - self._listen = None - - def check_channel_direct_tcpip_request(self, chanid, origin, destination): - self._tcpip_dest = destination - return OPEN_SUCCEEDED - - class TransportTest(unittest.TestCase): def setUp(self): self.socks = LoopSocket() @@ -1190,103 +1123,6 @@ class AlgorithmDisablingTests(unittest.TestCase): assert "zlib" not in compressions -@contextmanager -def server( - hostkey=None, - init=None, - server_init=None, - client_init=None, - connect=None, - pubkeys=None, - catch_error=False, - transport_factory=None, -): - """ - SSH server contextmanager for testing. - - :param hostkey: - Host key to use for the server; if None, loads - ``rsa.key``. - :param init: - Default `Transport` constructor kwargs to use for both sides. - :param server_init: - Extends and/or overrides ``init`` for server transport only. - :param client_init: - Extends and/or overrides ``init`` for client transport only. - :param connect: - Kwargs to use for ``connect()`` on the client. - :param pubkeys: - List of public keys for auth. - :param catch_error: - Whether to capture connection errors & yield from contextmanager. - Necessary for connection_time exception testing. - :param transport_factory: - Like the same-named param in SSHClient: which Transport class to use. - """ - if init is None: - init = {} - if server_init is None: - server_init = {} - if client_init is None: - client_init = {} - if connect is None: - connect = dict(username="slowdive", password="pygmalion") - socks = LoopSocket() - sockc = LoopSocket() - sockc.link(socks) - if transport_factory is None: - transport_factory = Transport - tc = transport_factory(sockc, **dict(init, **client_init)) - ts = transport_factory(socks, **dict(init, **server_init)) - - if hostkey is None: - hostkey = RSAKey.from_private_key_file(_support("rsa.key")) - ts.add_server_key(hostkey) - event = threading.Event() - server = NullServer(allowed_keys=pubkeys) - assert not event.is_set() - assert not ts.is_active() - assert tc.get_username() is None - assert ts.get_username() is None - assert not tc.is_authenticated() - assert not ts.is_authenticated() - - err = None - # Trap errors and yield instead of raising right away; otherwise callers - # cannot usefully deal with problems at connect time which stem from errors - # in the server side. - try: - ts.start_server(event, server) - tc.connect(**connect) - - event.wait(1.0) - assert event.is_set() - assert ts.is_active() - assert tc.is_active() - - except Exception as e: - if not catch_error: - raise - err = e - - yield (tc, ts, err) if catch_error else (tc, ts) - - tc.close() - ts.close() - socks.close() - sockc.close() - - -_disable_sha2 = dict( - disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"])) -_disable_sha2_pubkey = dict( - disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"]) -) -_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"])) - - class TestSHA2SignatureKeyExchange(unittest.TestCase): # NOTE: these all rely on the default server() hostkey being RSA # NOTE: these rely on both sides being properly implemented re: agreed-upon @@ -1351,7 +1187,10 @@ class TestSHA2SignatureKeyExchange(unittest.TestCase): # the entire preferred-hostkeys structure when given an explicit key as # a client.) hostkey = RSAKey.from_private_key_file(_support("rsa.key")) - with server(hostkey=hostkey, connect=dict(hostkey=hostkey)) as (tc, _): + connect = dict( + hostkey=hostkey, username="slowdive", password="pygmalion" + ) + with server(hostkey=hostkey, connect=connect) as (tc, _): assert tc.host_key_type == "rsa-sha2-512" @@ -1442,7 +1281,7 @@ class TestSHA2SignaturePubkeys(unittest.TestCase): server_init = dict(_disable_sha2_pubkey, server_sig_algs=False) with server( pubkeys=[privkey], - connect=dict(pkey=privkey), + connect=dict(username="slowdive", pkey=privkey), server_init=server_init, catch_error=True, ) as (tc, ts, err): @@ -1455,6 +1294,7 @@ class TestSHA2SignaturePubkeys(unittest.TestCase): privkey = RSAKey.from_private_key_file(_support("rsa.key")) with server( pubkeys=[privkey], + # TODO: why is this passing without a username? connect=dict(pkey=privkey), init=dict( disabled_algorithms=dict(pubkeys=["ssh-rsa", "rsa-sha2-256"]) -- cgit v1.2.1