diff options
Diffstat (limited to 'Lib/test/test_ftplib.py')
-rw-r--r-- | Lib/test/test_ftplib.py | 433 |
1 files changed, 394 insertions, 39 deletions
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index fa1079f788..9d2eab7fa3 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -1,17 +1,25 @@ """Test script for ftplib module.""" -# Modified by Giampaolo Rodola' to test FTP class and IPv6 environment +# Modified by Giampaolo Rodola' to test FTP class, IPv6 and TLS +# environment import ftplib -import threading import asyncore import asynchat import socket import io +import errno +import os +import time +try: + import ssl +except ImportError: + ssl = None from unittest import TestCase from test import support from test.support import HOST +threading = support.import_module('threading') # the dummy data returned by server over the data channel when # RETR, LIST and NLST commands are issued @@ -21,6 +29,7 @@ NLST_DATA = 'foo\r\nbar\r\n' class DummyDTPHandler(asynchat.async_chat): + dtp_conn_closed = False def __init__(self, conn, baseclass): asynchat.async_chat.__init__(self, conn) @@ -31,15 +40,25 @@ class DummyDTPHandler(asynchat.async_chat): self.baseclass.last_received_data += self.recv(1024).decode('ascii') def handle_close(self): - self.baseclass.push('226 transfer complete') - self.close() + # XXX: this method can be called many times in a row for a single + # connection, including in clear-text (non-TLS) mode. + # (behaviour witnessed with test_data_connection) + if not self.dtp_conn_closed: + self.baseclass.push('226 transfer complete') + self.close() + self.dtp_conn_closed = True def push(self, what): super(DummyDTPHandler, self).push(what.encode('ascii')) + def handle_error(self): + raise + class DummyFTPHandler(asynchat.async_chat): + dtp_handler = DummyDTPHandler + def __init__(self, conn): asynchat.async_chat.__init__(self, conn) self.set_terminator(b"\r\n") @@ -48,6 +67,7 @@ class DummyFTPHandler(asynchat.async_chat): self.last_received_cmd = None self.last_received_data = '' self.next_response = '' + self.rest = None self.push('220 welcome') def collect_incoming_data(self, data): @@ -83,41 +103,44 @@ class DummyFTPHandler(asynchat.async_chat): ip = '%d.%d.%d.%d' %tuple(addr[:4]) port = (addr[4] * 256) + addr[5] s = socket.create_connection((ip, port), timeout=10) - self.dtp = DummyDTPHandler(s, baseclass=self) + self.dtp = self.dtp_handler(s, baseclass=self) self.push('200 active data connection established') def cmd_pasv(self, arg): - sock = socket.socket() - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen(5) - sock.settimeout(10) - ip, port = sock.getsockname()[:2] - ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 - self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2)) - conn, addr = sock.accept() - self.dtp = DummyDTPHandler(conn, baseclass=self) + with socket.socket() as sock: + sock.bind((self.socket.getsockname()[0], 0)) + sock.listen(5) + sock.settimeout(10) + ip, port = sock.getsockname()[:2] + ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 + self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2)) + conn, addr = sock.accept() + self.dtp = self.dtp_handler(conn, baseclass=self) def cmd_eprt(self, arg): af, ip, port = arg.split(arg[0])[1:-1] port = int(port) s = socket.create_connection((ip, port), timeout=10) - self.dtp = DummyDTPHandler(s, baseclass=self) + self.dtp = self.dtp_handler(s, baseclass=self) self.push('200 active data connection established') def cmd_epsv(self, arg): - sock = socket.socket(socket.AF_INET6) - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen(5) - sock.settimeout(10) - port = sock.getsockname()[1] - self.push('229 entering extended passive mode (|||%d|)' %port) - conn, addr = sock.accept() - self.dtp = DummyDTPHandler(conn, baseclass=self) + with socket.socket(socket.AF_INET6) as sock: + sock.bind((self.socket.getsockname()[0], 0)) + sock.listen(5) + sock.settimeout(10) + port = sock.getsockname()[1] + self.push('229 entering extended passive mode (|||%d|)' %port) + conn, addr = sock.accept() + self.dtp = self.dtp_handler(conn, baseclass=self) def cmd_echo(self, arg): # sends back the received string (used by the test suite) self.push(arg) + def cmd_noop(self, arg): + self.push('200 noop ok') + def cmd_user(self, arg): self.push('331 username ok') @@ -161,10 +184,19 @@ class DummyFTPHandler(asynchat.async_chat): def cmd_stor(self, arg): self.push('125 stor ok') + def cmd_rest(self, arg): + self.rest = arg + self.push('350 rest ok') + def cmd_retr(self, arg): self.push('125 retr ok') - self.dtp.push(RETR_DATA) + if self.rest is not None: + offset = int(self.rest) + else: + offset = 0 + self.dtp.push(RETR_DATA[offset:]) self.dtp.close_when_done() + self.rest = None def cmd_list(self, arg): self.push('125 list ok') @@ -190,6 +222,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): self.active = False self.active_lock = threading.Lock() self.host, self.port = self.socket.getsockname()[:2] + self.handler_instance = None def start(self): assert not self.active @@ -211,10 +244,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): self.active = False self.join() - def handle_accept(self): - conn, addr = self.accept() - self.handler = self.handler(conn) - self.close() + def handle_accepted(self, conn, addr): + self.handler_instance = self.handler(conn) def handle_connect(self): self.close() @@ -227,6 +258,154 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): raise +if ssl is not None: + + CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem") + + class SSLConnection(asyncore.dispatcher): + """An asyncore.dispatcher subclass supporting TLS/SSL.""" + + _ssl_accepting = False + _ssl_closing = False + + def secure_connection(self): + self.del_channel() + socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False, + certfile=CERTFILE, server_side=True, + do_handshake_on_connect=False, + ssl_version=ssl.PROTOCOL_SSLv23) + self.set_socket(socket) + self._ssl_accepting = True + + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + raise + except socket.error as err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self._ssl_accepting = False + + def _do_ssl_shutdown(self): + self._ssl_closing = True + try: + self.socket = self.socket.unwrap() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + except socket.error as err: + # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return + # from OpenSSL's SSL_shutdown(), corresponding to a + # closed socket condition. See also: + # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html + pass + self._ssl_closing = False + super(SSLConnection, self).close() + + def handle_read_event(self): + if self._ssl_accepting: + self._do_ssl_handshake() + elif self._ssl_closing: + self._do_ssl_shutdown() + else: + super(SSLConnection, self).handle_read_event() + + def handle_write_event(self): + if self._ssl_accepting: + self._do_ssl_handshake() + elif self._ssl_closing: + self._do_ssl_shutdown() + else: + super(SSLConnection, self).handle_write_event() + + def send(self, data): + try: + return super(SSLConnection, self).send(data) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN, + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return 0 + raise + + def recv(self, buffer_size): + try: + return super(SSLConnection, self).recv(buffer_size) + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return b'' + if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): + self.handle_close() + return b'' + raise + + def handle_error(self): + raise + + def close(self): + if (isinstance(self.socket, ssl.SSLSocket) and + self.socket._sslobj is not None): + self._do_ssl_shutdown() + else: + super(SSLConnection, self).close() + + + class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler): + """A DummyDTPHandler subclass supporting TLS/SSL.""" + + def __init__(self, conn, baseclass): + DummyDTPHandler.__init__(self, conn, baseclass) + if self.baseclass.secure_data_channel: + self.secure_connection() + + + class DummyTLS_FTPHandler(SSLConnection, DummyFTPHandler): + """A DummyFTPHandler subclass supporting TLS/SSL.""" + + dtp_handler = DummyTLS_DTPHandler + + def __init__(self, conn): + DummyFTPHandler.__init__(self, conn) + self.secure_data_channel = False + + def cmd_auth(self, line): + """Set up secure control channel.""" + self.push('234 AUTH TLS successful') + self.secure_connection() + + def cmd_pbsz(self, line): + """Negotiate size of buffer for secure data transfer. + For TLS/SSL the only valid value for the parameter is '0'. + Any other value is accepted but ignored. + """ + self.push('200 PBSZ=0 successful.') + + def cmd_prot(self, line): + """Setup un/secure data channel.""" + arg = line.upper() + if arg == 'C': + self.push('200 Protection set to Clear') + self.secure_data_channel = False + elif arg == 'P': + self.push('200 Protection set to Private') + self.secure_data_channel = True + else: + self.push("502 Unrecognized PROT type (use C or P).") + + + class DummyTLS_FTPServer(DummyFTPServer): + handler = DummyTLS_FTPHandler + + class TestFTPClass(TestCase): def setUp(self): @@ -285,12 +464,12 @@ class TestFTPClass(TestCase): def test_rename(self): self.client.rename('a', 'b') - self.server.handler.next_response = '200' + self.server.handler_instance.next_response = '200' self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b') def test_delete(self): self.client.delete('foo') - self.server.handler.next_response = '199' + self.server.handler_instance.next_response = '199' self.assertRaises(ftplib.error_reply, self.client.delete, 'foo') def test_size(self): @@ -319,6 +498,17 @@ class TestFTPClass(TestCase): self.client.retrbinary('retr', callback) self.assertEqual(''.join(received), RETR_DATA) + def test_retrbinary_rest(self): + def callback(data): + received.append(data.decode('ascii')) + for rest in (0, 10, 20): + received = [] + self.client.retrbinary('retr', callback, rest=rest) + self.assertEqual(''.join(received), RETR_DATA[rest:], + msg='rest test case %d %d %d' % (rest, + len(''.join(received)), + len(RETR_DATA[rest:]))) + def test_retrlines(self): received = [] self.client.retrlines('retr', received.append) @@ -327,17 +517,24 @@ class TestFTPClass(TestCase): def test_storbinary(self): f = io.BytesIO(RETR_DATA.encode('ascii')) self.client.storbinary('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) self.client.storbinary('stor', f, callback=lambda x: flag.append(None)) self.assertTrue(flag) + def test_storbinary_rest(self): + f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) + for r in (30, '30'): + f.seek(0) + self.client.storbinary('stor', f, rest=r) + self.assertEqual(self.server.handler_instance.rest, str(r)) + def test_storlines(self): f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) self.client.storlines('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) @@ -354,16 +551,74 @@ class TestFTPClass(TestCase): self.assertEqual(''.join(l), LIST_DATA.replace('\r\n', '')) def test_makeport(self): - self.client.makeport() - # IPv4 is in use, just make sure send_eprt has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'port') + with self.client.makeport(): + # IPv4 is in use, just make sure send_eprt has not been used + self.assertEqual(self.server.handler_instance.last_received_cmd, + 'port') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 10) conn.close() # IPv4 is in use, just make sure send_epsv has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'pasv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') + + def test_with_statement(self): + self.client.quit() + + def is_client_connected(): + if self.client.sock is None: + return False + try: + self.client.sendcmd('noop') + except (socket.error, EOFError): + return False + return True + + # base test + with ftplib.FTP(timeout=10) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.assertTrue(is_client_connected()) + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) + + # QUIT sent inside the with block + with ftplib.FTP(timeout=10) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.client.quit() + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) + + # force a wrong response code to be sent on QUIT: error_perm + # is expected and the connection is supposed to be closed + try: + with ftplib.FTP(timeout=10) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.server.handler_instance.next_response = '550 error on quit' + except ftplib.error_perm as err: + self.assertEqual(str(err), '550 error on quit') + else: + self.fail('Exception not raised') + # needed to give the threaded server some time to set the attribute + # which otherwise would still be == 'noop' + time.sleep(0.1) + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) + + def test_parse257(self): + self.assertEqual(ftplib.parse257('257 "/foo/bar"'), '/foo/bar') + self.assertEqual(ftplib.parse257('257 "/foo/bar" created'), '/foo/bar') + self.assertEqual(ftplib.parse257('257 ""'), '') + self.assertEqual(ftplib.parse257('257 "" created'), '') + self.assertRaises(ftplib.error_reply, ftplib.parse257, '250 "/foo/bar"') + # The 257 response is supposed to include the directory + # name and in case it contains embedded double-quotes + # they must be doubled (see RFC-959, chapter 7, appendix 2). + self.assertEqual(ftplib.parse257('257 "/foo/b""ar"'), '/foo/b"ar') + self.assertEqual(ftplib.parse257('257 "/foo/b""ar" created'), '/foo/b"ar') class TestIPv6Environment(TestCase): @@ -382,14 +637,15 @@ class TestIPv6Environment(TestCase): self.assertEqual(self.client.af, socket.AF_INET6) def test_makeport(self): - self.client.makeport() - self.assertEqual(self.server.handler.last_received_cmd, 'eprt') + with self.client.makeport(): + self.assertEqual(self.server.handler_instance.last_received_cmd, + 'eprt') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 10) conn.close() - self.assertEqual(self.server.handler.last_received_cmd, 'epsv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv') def test_transfer(self): def retr(): @@ -404,6 +660,100 @@ class TestIPv6Environment(TestCase): retr() +class TestTLS_FTPClassMixin(TestFTPClass): + """Repeat TestFTPClass tests starting the TLS layer for both control + and data connections first. + """ + + def setUp(self): + self.server = DummyTLS_FTPServer((HOST, 0)) + self.server.start() + self.client = ftplib.FTP_TLS(timeout=10) + self.client.connect(self.server.host, self.server.port) + # enable TLS + self.client.auth() + self.client.prot_p() + + +class TestTLS_FTPClass(TestCase): + """Specific TLS_FTP class tests.""" + + def setUp(self): + self.server = DummyTLS_FTPServer((HOST, 0)) + self.server.start() + self.client = ftplib.FTP_TLS(timeout=10) + self.client.connect(self.server.host, self.server.port) + + def tearDown(self): + self.client.close() + self.server.stop() + + def test_control_connection(self): + self.assertNotIsInstance(self.client.sock, ssl.SSLSocket) + self.client.auth() + self.assertIsInstance(self.client.sock, ssl.SSLSocket) + + def test_data_connection(self): + # clear text + with self.client.transfercmd('list') as sock: + self.assertNotIsInstance(sock, ssl.SSLSocket) + self.assertEqual(self.client.voidresp(), "226 transfer complete") + + # secured, after PROT P + self.client.prot_p() + with self.client.transfercmd('list') as sock: + self.assertIsInstance(sock, ssl.SSLSocket) + self.assertEqual(self.client.voidresp(), "226 transfer complete") + + # PROT C is issued, the connection must be in cleartext again + self.client.prot_c() + with self.client.transfercmd('list') as sock: + self.assertNotIsInstance(sock, ssl.SSLSocket) + self.assertEqual(self.client.voidresp(), "226 transfer complete") + + def test_login(self): + # login() is supposed to implicitly secure the control connection + self.assertNotIsInstance(self.client.sock, ssl.SSLSocket) + self.client.login() + self.assertIsInstance(self.client.sock, ssl.SSLSocket) + # make sure that AUTH TLS doesn't get issued again + self.client.login() + + def test_auth_issued_twice(self): + self.client.auth() + self.assertRaises(ValueError, self.client.auth) + + def test_auth_ssl(self): + try: + self.client.ssl_version = ssl.PROTOCOL_SSLv3 + self.client.auth() + self.assertRaises(ValueError, self.client.auth) + finally: + self.client.ssl_version = ssl.PROTOCOL_TLSv1 + + def test_context(self): + self.client.quit() + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE, + context=ctx) + self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + context=ctx) + self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + keyfile=CERTFILE, context=ctx) + + self.client = ftplib.FTP_TLS(context=ctx, timeout=10) + self.client.connect(self.server.host, self.server.port) + self.assertNotIsInstance(self.client.sock, ssl.SSLSocket) + self.client.auth() + self.assertIs(self.client.sock.context, ctx) + self.assertIsInstance(self.client.sock, ssl.SSLSocket) + + self.client.prot_p() + with self.client.transfercmd('list') as sock: + self.assertIs(sock.context, ctx) + self.assertIsInstance(sock, ssl.SSLSocket) + + class TestTimeouts(TestCase): def setUp(self): @@ -419,6 +769,7 @@ class TestTimeouts(TestCase): def tearDown(self): self.evt.wait() + self.sock.close() def server(self, evt, serv): # This method sets the evt 3 times: @@ -505,6 +856,10 @@ def test_main(): pass else: tests.append(TestIPv6Environment) + + if ssl is not None: + tests.extend([TestTLS_FTPClassMixin, TestTLS_FTPClass]) + thread_info = support.threading_setup() try: support.run_unittest(*tests) |