summaryrefslogtreecommitdiff
path: root/bzrlib/tests/test_smart_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'bzrlib/tests/test_smart_transport.py')
-rw-r--r--bzrlib/tests/test_smart_transport.py4299
1 files changed, 4299 insertions, 0 deletions
diff --git a/bzrlib/tests/test_smart_transport.py b/bzrlib/tests/test_smart_transport.py
new file mode 100644
index 0000000..a4e8a39
--- /dev/null
+++ b/bzrlib/tests/test_smart_transport.py
@@ -0,0 +1,4299 @@
+# Copyright (C) 2006-2011 Canonical Ltd
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+"""Tests for smart transport"""
+
+# all of this deals with byte strings so this is safe
+from cStringIO import StringIO
+import doctest
+import errno
+import os
+import socket
+import subprocess
+import sys
+import threading
+import time
+
+from testtools.matchers import DocTestMatches
+
+import bzrlib
+from bzrlib import (
+ bzrdir,
+ controldir,
+ debug,
+ errors,
+ osutils,
+ tests,
+ transport as _mod_transport,
+ urlutils,
+ )
+from bzrlib.smart import (
+ client,
+ medium,
+ message,
+ protocol,
+ request as _mod_request,
+ server as _mod_server,
+ vfs,
+)
+from bzrlib.tests import (
+ features,
+ test_smart,
+ test_server,
+ )
+from bzrlib.transport import (
+ http,
+ local,
+ memory,
+ remote,
+ ssh,
+ )
+
+
+def create_file_pipes():
+ r, w = os.pipe()
+ # These must be opened without buffering, or we get undefined results
+ rf = os.fdopen(r, 'rb', 0)
+ wf = os.fdopen(w, 'wb', 0)
+ return rf, wf
+
+
+def portable_socket_pair():
+ """Return a pair of TCP sockets connected to each other.
+
+ Unlike socket.socketpair, this should work on Windows.
+ """
+ listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(1)
+ client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_sock.connect(listen_sock.getsockname())
+ server_sock, addr = listen_sock.accept()
+ listen_sock.close()
+ return server_sock, client_sock
+
+
+class StringIOSSHVendor(object):
+ """A SSH vendor that uses StringIO to buffer writes and answer reads."""
+
+ def __init__(self, read_from, write_to):
+ self.read_from = read_from
+ self.write_to = write_to
+ self.calls = []
+
+ def connect_ssh(self, username, password, host, port, command):
+ self.calls.append(('connect_ssh', username, password, host, port,
+ command))
+ return StringIOSSHConnection(self)
+
+
+class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
+ """The first connection will be considered closed.
+
+ The second connection will succeed normally.
+ """
+
+ def __init__(self, read_from, write_to, fail_at_write=True):
+ super(FirstRejectedStringIOSSHVendor, self).__init__(read_from,
+ write_to)
+ self.fail_at_write = fail_at_write
+ self._first = True
+
+ def connect_ssh(self, username, password, host, port, command):
+ self.calls.append(('connect_ssh', username, password, host, port,
+ command))
+ if self._first:
+ self._first = False
+ return ClosedSSHConnection(self)
+ return StringIOSSHConnection(self)
+
+
+class StringIOSSHConnection(ssh.SSHConnection):
+ """A SSH connection that uses StringIO to buffer writes and answer reads."""
+
+ def __init__(self, vendor):
+ self.vendor = vendor
+
+ def close(self):
+ self.vendor.calls.append(('close', ))
+ self.vendor.read_from.close()
+ self.vendor.write_to.close()
+
+ def get_sock_or_pipes(self):
+ return 'pipes', (self.vendor.read_from, self.vendor.write_to)
+
+
+class ClosedSSHConnection(ssh.SSHConnection):
+ """An SSH connection that just has closed channels."""
+
+ def __init__(self, vendor):
+ self.vendor = vendor
+
+ def close(self):
+ self.vendor.calls.append(('close', ))
+
+ def get_sock_or_pipes(self):
+ # We create matching pipes, and then close the ssh side
+ bzr_read, ssh_write = create_file_pipes()
+ # We always fail when bzr goes to read
+ ssh_write.close()
+ if self.vendor.fail_at_write:
+ # If set, we'll also fail when bzr goes to write
+ ssh_read, bzr_write = create_file_pipes()
+ ssh_read.close()
+ else:
+ bzr_write = self.vendor.write_to
+ return 'pipes', (bzr_read, bzr_write)
+
+
+class _InvalidHostnameFeature(features.Feature):
+ """Does 'non_existent.invalid' fail to resolve?
+
+ RFC 2606 states that .invalid is reserved for invalid domain names, and
+ also underscores are not a valid character in domain names. Despite this,
+ it's possible a badly misconfigured name server might decide to always
+ return an address for any name, so this feature allows us to distinguish a
+ broken system from a broken test.
+ """
+
+ def _probe(self):
+ try:
+ socket.gethostbyname('non_existent.invalid')
+ except socket.gaierror:
+ # The host name failed to resolve. Good.
+ return True
+ else:
+ return False
+
+ def feature_name(self):
+ return 'invalid hostname'
+
+InvalidHostnameFeature = _InvalidHostnameFeature()
+
+
+class SmartClientMediumTests(tests.TestCase):
+ """Tests for SmartClientMedium.
+
+ We should create a test scenario for this: we need a server module that
+ construct the test-servers (like make_loopsocket_and_medium), and the list
+ of SmartClientMedium classes to test.
+ """
+
+ def make_loopsocket_and_medium(self):
+ """Create a loopback socket for testing, and a medium aimed at it."""
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ sock.listen(1)
+ port = sock.getsockname()[1]
+ client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
+ return sock, client_medium
+
+ def receive_bytes_on_server(self, sock, bytes):
+ """Accept a connection on sock and read 3 bytes.
+
+ The bytes are appended to the list bytes.
+
+ :return: a Thread which is running to do the accept and recv.
+ """
+ def _receive_bytes_on_server():
+ connection, address = sock.accept()
+ bytes.append(osutils.recv_all(connection, 3))
+ connection.close()
+ t = threading.Thread(target=_receive_bytes_on_server)
+ t.start()
+ return t
+
+ def test_construct_smart_simple_pipes_client_medium(self):
+ # the SimplePipes client medium takes two pipes:
+ # readable pipe, writeable pipe.
+ # Constructing one should just save these and do nothing.
+ # We test this by passing in None.
+ client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
+
+ def test_simple_pipes_client_request_type(self):
+ # SimplePipesClient should use SmartClientStreamMediumRequest's.
+ client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
+ request = client_medium.get_request()
+ self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
+
+ def test_simple_pipes_client_get_concurrent_requests(self):
+ # the simple_pipes client does not support pipelined requests:
+ # but it does support serial requests: we construct one after
+ # another is finished. This is a smoke test testing the integration
+ # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
+ # classes - as the sibling classes share this logic, they do not have
+ # explicit tests for this.
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = client_medium.get_request()
+ request.finished_writing()
+ request.finished_reading()
+ request2 = client_medium.get_request()
+ request2.finished_writing()
+ request2.finished_reading()
+
+ def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
+ # accept_bytes writes to the writeable pipe.
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ client_medium._accept_bytes('abc')
+ self.assertEqual('abc', output.getvalue())
+
+ def test_simple_pipes__accept_bytes_subprocess_closed(self):
+ # It is unfortunate that we have to use Popen for this. However,
+ # os.pipe() does not behave the same as subprocess.Popen().
+ # On Windows, if you use os.pipe() and close the write side,
+ # read.read() hangs. On Linux, read.read() returns the empty string.
+ p = subprocess.Popen([sys.executable, '-c',
+ 'import sys\n'
+ 'sys.stdout.write(sys.stdin.read(4))\n'
+ 'sys.stdout.close()\n'],
+ stdout=subprocess.PIPE, stdin=subprocess.PIPE)
+ client_medium = medium.SmartSimplePipesClientMedium(
+ p.stdout, p.stdin, 'base')
+ client_medium._accept_bytes('abc\n')
+ self.assertEqual('abc', client_medium._read_bytes(3))
+ p.wait()
+ # While writing to the underlying pipe,
+ # Windows py2.6.6 we get IOError(EINVAL)
+ # Lucid py2.6.5, we get IOError(EPIPE)
+ # In both cases, it should be wrapped to ConnectionReset
+ self.assertRaises(errors.ConnectionReset,
+ client_medium._accept_bytes, 'more')
+
+ def test_simple_pipes__accept_bytes_pipe_closed(self):
+ child_read, client_write = create_file_pipes()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, client_write, 'base')
+ client_medium._accept_bytes('abc\n')
+ self.assertEqual('abc\n', child_read.read(4))
+ # While writing to the underlying pipe,
+ # Windows py2.6.6 we get IOError(EINVAL)
+ # Lucid py2.6.5, we get IOError(EPIPE)
+ # In both cases, it should be wrapped to ConnectionReset
+ child_read.close()
+ self.assertRaises(errors.ConnectionReset,
+ client_medium._accept_bytes, 'more')
+
+ def test_simple_pipes__flush_pipe_closed(self):
+ child_read, client_write = create_file_pipes()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, client_write, 'base')
+ client_medium._accept_bytes('abc\n')
+ child_read.close()
+ # Even though the pipe is closed, flush on the write side seems to be a
+ # no-op, rather than a failure.
+ client_medium._flush()
+
+ def test_simple_pipes__flush_subprocess_closed(self):
+ p = subprocess.Popen([sys.executable, '-c',
+ 'import sys\n'
+ 'sys.stdout.write(sys.stdin.read(4))\n'
+ 'sys.stdout.close()\n'],
+ stdout=subprocess.PIPE, stdin=subprocess.PIPE)
+ client_medium = medium.SmartSimplePipesClientMedium(
+ p.stdout, p.stdin, 'base')
+ client_medium._accept_bytes('abc\n')
+ p.wait()
+ # Even though the child process is dead, flush seems to be a no-op.
+ client_medium._flush()
+
+ def test_simple_pipes__read_bytes_pipe_closed(self):
+ child_read, client_write = create_file_pipes()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ child_read, client_write, 'base')
+ client_medium._accept_bytes('abc\n')
+ client_write.close()
+ self.assertEqual('abc\n', client_medium._read_bytes(4))
+ self.assertEqual('', client_medium._read_bytes(4))
+
+ def test_simple_pipes__read_bytes_subprocess_closed(self):
+ p = subprocess.Popen([sys.executable, '-c',
+ 'import sys\n'
+ 'if sys.platform == "win32":\n'
+ ' import msvcrt, os\n'
+ ' msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
+ ' msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
+ 'sys.stdout.write(sys.stdin.read(4))\n'
+ 'sys.stdout.close()\n'],
+ stdout=subprocess.PIPE, stdin=subprocess.PIPE)
+ client_medium = medium.SmartSimplePipesClientMedium(
+ p.stdout, p.stdin, 'base')
+ client_medium._accept_bytes('abc\n')
+ p.wait()
+ self.assertEqual('abc\n', client_medium._read_bytes(4))
+ self.assertEqual('', client_medium._read_bytes(4))
+
+ def test_simple_pipes_client_disconnect_does_nothing(self):
+ # calling disconnect does nothing.
+ input = StringIO()
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ # send some bytes to ensure disconnecting after activity still does not
+ # close.
+ client_medium._accept_bytes('abc')
+ client_medium.disconnect()
+ self.assertFalse(input.closed)
+ self.assertFalse(output.closed)
+
+ def test_simple_pipes_client_accept_bytes_after_disconnect(self):
+ # calling disconnect on the client does not alter the pipe that
+ # accept_bytes writes to.
+ input = StringIO()
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ client_medium._accept_bytes('abc')
+ client_medium.disconnect()
+ client_medium._accept_bytes('abc')
+ self.assertFalse(input.closed)
+ self.assertFalse(output.closed)
+ self.assertEqual('abcabc', output.getvalue())
+
+ def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
+ # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
+ # does nothing.
+ client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
+ client_medium.disconnect()
+
+ def test_simple_pipes_client_can_always_read(self):
+ # SmartSimplePipesClientMedium is never disconnected, so read_bytes
+ # always tries to read from the underlying pipe.
+ input = StringIO('abcdef')
+ client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
+ self.assertEqual('abc', client_medium.read_bytes(3))
+ client_medium.disconnect()
+ self.assertEqual('def', client_medium.read_bytes(3))
+
+ def test_simple_pipes_client_supports__flush(self):
+ # invoking _flush on a SimplePipesClient should flush the output
+ # pipe. We test this by creating an output pipe that records
+ # flush calls made to it.
+ from StringIO import StringIO # get regular StringIO
+ input = StringIO()
+ output = StringIO()
+ flush_calls = []
+ def logging_flush(): flush_calls.append('flush')
+ output.flush = logging_flush
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ # this call is here to ensure we only flush once, not on every
+ # _accept_bytes call.
+ client_medium._accept_bytes('abc')
+ client_medium._flush()
+ client_medium.disconnect()
+ self.assertEqual(['flush'], flush_calls)
+
+ def test_construct_smart_ssh_client_medium(self):
+ # the SSH client medium takes:
+ # host, port, username, password, vendor
+ # Constructing one should just save these and do nothing.
+ # we test this by creating a empty bound socket and constructing
+ # a medium.
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ unopened_port = sock.getsockname()[1]
+ # having vendor be invalid means that if it tries to connect via the
+ # vendor it will blow up.
+ ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', ssh_params, "not a vendor")
+ sock.close()
+
+ def test_ssh_client_connects_on_first_use(self):
+ # The only thing that initiates a connection from the medium is giving
+ # it bytes.
+ output = StringIO()
+ vendor = StringIOSSHVendor(StringIO(), output)
+ ssh_params = medium.SSHParams(
+ 'a hostname', 'a port', 'a username', 'a password', 'bzr')
+ client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
+ client_medium._accept_bytes('abc')
+ self.assertEqual('abc', output.getvalue())
+ self.assertEqual([('connect_ssh', 'a username', 'a password',
+ 'a hostname', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
+ vendor.calls)
+
+ def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
+ # The only thing that initiates a connection from the medium is giving
+ # it bytes.
+ output = StringIO()
+ vendor = StringIOSSHVendor(StringIO(), output)
+ ssh_params = medium.SSHParams(
+ 'a hostname', 'a port', 'a username', 'a password',
+ bzr_remote_path='fugly')
+ client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
+ client_medium._accept_bytes('abc')
+ self.assertEqual('abc', output.getvalue())
+ self.assertEqual([('connect_ssh', 'a username', 'a password',
+ 'a hostname', 'a port',
+ ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
+ vendor.calls)
+
+ def test_ssh_client_disconnect_does_so(self):
+ # calling disconnect should disconnect both the read_from and write_to
+ # file-like object it from the ssh connection.
+ input = StringIO()
+ output = StringIO()
+ vendor = StringIOSSHVendor(input, output)
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams('a hostname'), vendor)
+ client_medium._accept_bytes('abc')
+ client_medium.disconnect()
+ self.assertTrue(input.closed)
+ self.assertTrue(output.closed)
+ self.assertEqual([
+ ('connect_ssh', None, None, 'a hostname', None,
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close', ),
+ ],
+ vendor.calls)
+
+ def test_ssh_client_disconnect_allows_reconnection(self):
+ # calling disconnect on the client terminates the connection, but should
+ # not prevent additional connections occuring.
+ # we test this by initiating a second connection after doing a
+ # disconnect.
+ input = StringIO()
+ output = StringIO()
+ vendor = StringIOSSHVendor(input, output)
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams('a hostname'), vendor)
+ client_medium._accept_bytes('abc')
+ client_medium.disconnect()
+ # the disconnect has closed output, so we need a new output for the
+ # new connection to write to.
+ input2 = StringIO()
+ output2 = StringIO()
+ vendor.read_from = input2
+ vendor.write_to = output2
+ client_medium._accept_bytes('abc')
+ client_medium.disconnect()
+ self.assertTrue(input.closed)
+ self.assertTrue(output.closed)
+ self.assertTrue(input2.closed)
+ self.assertTrue(output2.closed)
+ self.assertEqual([
+ ('connect_ssh', None, None, 'a hostname', None,
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close', ),
+ ('connect_ssh', None, None, 'a hostname', None,
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close', ),
+ ],
+ vendor.calls)
+
+ def test_ssh_client_repr(self):
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams("example.com", "4242", "username"))
+ self.assertEquals(
+ "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
+ repr(client_medium))
+
+ def test_ssh_client_repr_no_port(self):
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams("example.com", None, "username"))
+ self.assertEquals(
+ "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
+ repr(client_medium))
+
+ def test_ssh_client_repr_no_username(self):
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams("example.com", None, None))
+ self.assertEquals(
+ "SmartSSHClientMedium(bzr+ssh://example.com/)",
+ repr(client_medium))
+
+ def test_ssh_client_ignores_disconnect_when_not_connected(self):
+ # Doing a disconnect on a new (and thus unconnected) SSH medium
+ # does not fail. It's ok to disconnect an unconnected medium.
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams(None))
+ client_medium.disconnect()
+
+ def test_ssh_client_raises_on_read_when_not_connected(self):
+ # Doing a read on a new (and thus unconnected) SSH medium raises
+ # MediumNotConnected.
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams(None))
+ self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
+ 0)
+ self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
+ 1)
+
+ def test_ssh_client_supports__flush(self):
+ # invoking _flush on a SSHClientMedium should flush the output
+ # pipe. We test this by creating an output pipe that records
+ # flush calls made to it.
+ from StringIO import StringIO # get regular StringIO
+ input = StringIO()
+ output = StringIO()
+ flush_calls = []
+ def logging_flush(): flush_calls.append('flush')
+ output.flush = logging_flush
+ vendor = StringIOSSHVendor(input, output)
+ client_medium = medium.SmartSSHClientMedium(
+ 'base', medium.SSHParams('a hostname'), vendor=vendor)
+ # this call is here to ensure we only flush once, not on every
+ # _accept_bytes call.
+ client_medium._accept_bytes('abc')
+ client_medium._flush()
+ client_medium.disconnect()
+ self.assertEqual(['flush'], flush_calls)
+
+ def test_construct_smart_tcp_client_medium(self):
+ # the TCP client medium takes a host and a port. Constructing it won't
+ # connect to anything.
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ unopened_port = sock.getsockname()[1]
+ client_medium = medium.SmartTCPClientMedium(
+ '127.0.0.1', unopened_port, 'base')
+ sock.close()
+
+ def test_tcp_client_connects_on_first_use(self):
+ # The only thing that initiates a connection from the medium is giving
+ # it bytes.
+ sock, medium = self.make_loopsocket_and_medium()
+ bytes = []
+ t = self.receive_bytes_on_server(sock, bytes)
+ medium.accept_bytes('abc')
+ t.join()
+ sock.close()
+ self.assertEqual(['abc'], bytes)
+
+ def test_tcp_client_disconnect_does_so(self):
+ # calling disconnect on the client terminates the connection.
+ # we test this by forcing a short read during a socket.MSG_WAITALL
+ # call: write 2 bytes, try to read 3, and then the client disconnects.
+ sock, medium = self.make_loopsocket_and_medium()
+ bytes = []
+ t = self.receive_bytes_on_server(sock, bytes)
+ medium.accept_bytes('ab')
+ medium.disconnect()
+ t.join()
+ sock.close()
+ self.assertEqual(['ab'], bytes)
+ # now disconnect again: this should not do anything, if disconnection
+ # really did disconnect.
+ medium.disconnect()
+
+
+ def test_tcp_client_ignores_disconnect_when_not_connected(self):
+ # Doing a disconnect on a new (and thus unconnected) TCP medium
+ # does not fail. It's ok to disconnect an unconnected medium.
+ client_medium = medium.SmartTCPClientMedium(None, None, None)
+ client_medium.disconnect()
+
+ def test_tcp_client_raises_on_read_when_not_connected(self):
+ # Doing a read on a new (and thus unconnected) TCP medium raises
+ # MediumNotConnected.
+ client_medium = medium.SmartTCPClientMedium(None, None, None)
+ self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
+ self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
+
+ def test_tcp_client_supports__flush(self):
+ # invoking _flush on a TCPClientMedium should do something useful.
+ # RBC 20060922 not sure how to test/tell in this case.
+ sock, medium = self.make_loopsocket_and_medium()
+ bytes = []
+ t = self.receive_bytes_on_server(sock, bytes)
+ # try with nothing buffered
+ medium._flush()
+ medium._accept_bytes('ab')
+ # and with something sent.
+ medium._flush()
+ medium.disconnect()
+ t.join()
+ sock.close()
+ self.assertEqual(['ab'], bytes)
+ # now disconnect again : this should not do anything, if disconnection
+ # really did disconnect.
+ medium.disconnect()
+
+ def test_tcp_client_host_unknown_connection_error(self):
+ self.requireFeature(InvalidHostnameFeature)
+ client_medium = medium.SmartTCPClientMedium(
+ 'non_existent.invalid', 4155, 'base')
+ self.assertRaises(
+ errors.ConnectionError, client_medium._ensure_connection)
+
+
+class TestSmartClientStreamMediumRequest(tests.TestCase):
+ """Tests the for SmartClientStreamMediumRequest.
+
+ SmartClientStreamMediumRequest is a helper for the three stream based
+ mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
+ those three mediums implement the interface it expects.
+ """
+
+ def test_accept_bytes_after_finished_writing_errors(self):
+ # calling accept_bytes after calling finished_writing raises
+ # WritingCompleted to prevent bad assumptions on stream environments
+ # breaking the needs of message-based environments.
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ request.finished_writing()
+ self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
+
+ def test_accept_bytes(self):
+ # accept bytes should invoke _accept_bytes on the stream medium.
+ # we test this by using the SimplePipes medium - the most trivial one
+ # and checking that the pipes get the data.
+ input = StringIO()
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ request.accept_bytes('123')
+ request.finished_writing()
+ request.finished_reading()
+ self.assertEqual('', input.getvalue())
+ self.assertEqual('123', output.getvalue())
+
+ def test_construct_sets_stream_request(self):
+ # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
+ # the current request to the new SmartClientStreamMediumRequest
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ self.assertIs(client_medium._current_request, request)
+
+ def test_construct_while_another_request_active_throws(self):
+ # constructing a SmartClientStreamMediumRequest on a StreamMedium with
+ # a non-None _current_request raises TooManyConcurrentRequests.
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ client_medium._current_request = "a"
+ self.assertRaises(errors.TooManyConcurrentRequests,
+ medium.SmartClientStreamMediumRequest, client_medium)
+
+ def test_finished_read_clears_current_request(self):
+ # calling finished_reading clears the current request from the requests
+ # medium
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ request.finished_writing()
+ request.finished_reading()
+ self.assertEqual(None, client_medium._current_request)
+
+ def test_finished_read_before_finished_write_errors(self):
+ # calling finished_reading before calling finished_writing triggers a
+ # WritingNotComplete error.
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, None, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ self.assertRaises(errors.WritingNotComplete, request.finished_reading)
+
+ def test_read_bytes(self):
+ # read bytes should invoke _read_bytes on the stream medium.
+ # we test this by using the SimplePipes medium - the most trivial one
+ # and checking that the data is supplied. Its possible that a
+ # faulty implementation could poke at the pipe variables them selves,
+ # but we trust that this will be caught as it will break the integration
+ # smoke tests.
+ input = StringIO('321')
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ request.finished_writing()
+ self.assertEqual('321', request.read_bytes(3))
+ request.finished_reading()
+ self.assertEqual('', input.read())
+ self.assertEqual('', output.getvalue())
+
+ def test_read_bytes_before_finished_write_errors(self):
+ # calling read_bytes before calling finished_writing triggers a
+ # WritingNotComplete error because the Smart protocol is designed to be
+ # compatible with strict message based protocols like HTTP where the
+ # request cannot be submitted until the writing has completed.
+ client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
+
+ def test_read_bytes_after_finished_reading_errors(self):
+ # calling read_bytes after calling finished_reading raises
+ # ReadingCompleted to prevent bad assumptions on stream environments
+ # breaking the needs of message-based environments.
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = medium.SmartClientStreamMediumRequest(client_medium)
+ request.finished_writing()
+ request.finished_reading()
+ self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
+
+ def test_reset(self):
+ server_sock, client_sock = portable_socket_pair()
+ # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
+ # bzr where it exists.
+ client_medium = medium.SmartTCPClientMedium(None, None, None)
+ client_medium._socket = client_sock
+ client_medium._connected = True
+ req = client_medium.get_request()
+ self.assertRaises(errors.TooManyConcurrentRequests,
+ client_medium.get_request)
+ client_medium.reset()
+ # The stream should be reset, marked as disconnected, though ready for
+ # us to make a new request
+ self.assertFalse(client_medium._connected)
+ self.assertIs(None, client_medium._socket)
+ try:
+ self.assertEqual('', client_sock.recv(1))
+ except socket.error, e:
+ if e.errno not in (errno.EBADF,):
+ raise
+ req = client_medium.get_request()
+
+
+class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
+
+ def test_plausible_url(self):
+ self.assert_(self.get_url().startswith('bzr://'))
+
+ def test_probe_transport(self):
+ t = self.get_transport()
+ self.assertIsInstance(t, remote.RemoteTransport)
+
+ def test_get_medium_from_transport(self):
+ """Remote transport has a medium always, which it can return."""
+ t = self.get_transport()
+ client_medium = t.get_smart_medium()
+ self.assertIsInstance(client_medium, medium.SmartClientMedium)
+
+
+class ErrorRaisingProtocol(object):
+
+ def __init__(self, exception):
+ self.exception = exception
+
+ def next_read_size(self):
+ raise self.exception
+
+
+class SampleRequest(object):
+
+ def __init__(self, expected_bytes):
+ self.accepted_bytes = ''
+ self._finished_reading = False
+ self.expected_bytes = expected_bytes
+ self.unused_data = ''
+
+ def accept_bytes(self, bytes):
+ self.accepted_bytes += bytes
+ if self.accepted_bytes.startswith(self.expected_bytes):
+ self._finished_reading = True
+ self.unused_data = self.accepted_bytes[len(self.expected_bytes):]
+
+ def next_read_size(self):
+ if self._finished_reading:
+ return 0
+ else:
+ return 1
+
+
+class TestSmartServerStreamMedium(tests.TestCase):
+
+ def setUp(self):
+ super(TestSmartServerStreamMedium, self).setUp()
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+
+ def create_pipe_medium(self, to_server, from_server, transport,
+ timeout=4.0):
+ """Create a new SmartServerPipeStreamMedium."""
+ return medium.SmartServerPipeStreamMedium(to_server, from_server,
+ transport, timeout=timeout)
+
+ def create_pipe_context(self, to_server_bytes, transport):
+ """Create a SmartServerSocketStreamMedium.
+
+ This differes from create_pipe_medium, in that we initialize the
+ request that is sent to the server, and return the StringIO class that
+ will hold the response.
+ """
+ to_server = StringIO(to_server_bytes)
+ from_server = StringIO()
+ m = self.create_pipe_medium(to_server, from_server, transport)
+ return m, from_server
+
+ def create_socket_medium(self, server_sock, transport, timeout=4.0):
+ """Initialize a new medium.SmartServerSocketStreamMedium."""
+ return medium.SmartServerSocketStreamMedium(server_sock, transport,
+ timeout=timeout)
+
+ def create_socket_context(self, transport, timeout=4.0):
+ """Create a new SmartServerSocketStreamMedium with default context.
+
+ This will call portable_socket_pair and pass the server side to
+ create_socket_medium along with transport.
+ It then returns the client_sock and the server.
+ """
+ server_sock, client_sock = portable_socket_pair()
+ server = self.create_socket_medium(server_sock, transport,
+ timeout=timeout)
+ return server, client_sock
+
+ def test_smart_query_version(self):
+ """Feed a canned query version to a server"""
+ # wire-to-wire, using the whole stack
+ transport = local.LocalTransport(urlutils.local_path_to_url('/'))
+ server, from_server = self.create_pipe_context('hello\n', transport)
+ smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
+ from_server.write)
+ server._serve_one_request(smart_protocol)
+ self.assertEqual('ok\0012\n',
+ from_server.getvalue())
+
+ def test_response_to_canned_get(self):
+ transport = memory.MemoryTransport('memory:///')
+ transport.put_bytes('testfile', 'contents\nof\nfile\n')
+ server, from_server = self.create_pipe_context('get\001./testfile\n',
+ transport)
+ smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
+ from_server.write)
+ server._serve_one_request(smart_protocol)
+ self.assertEqual('ok\n'
+ '17\n'
+ 'contents\nof\nfile\n'
+ 'done\n',
+ from_server.getvalue())
+
+ def test_response_to_canned_get_of_utf8(self):
+ # wire-to-wire, using the whole stack, with a UTF-8 filename.
+ transport = memory.MemoryTransport('memory:///')
+ utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
+ # VFS requests use filenames, not raw UTF-8.
+ hpss_path = urlutils.escape(utf8_filename)
+ transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
+ server, from_server = self.create_pipe_context(
+ 'get\001' + hpss_path + '\n', transport)
+ smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
+ from_server.write)
+ server._serve_one_request(smart_protocol)
+ self.assertEqual('ok\n'
+ '17\n'
+ 'contents\nof\nfile\n'
+ 'done\n',
+ from_server.getvalue())
+
+ def test_pipe_like_stream_with_bulk_data(self):
+ sample_request_bytes = 'command\n9\nbulk datadone\n'
+ server, from_server = self.create_pipe_context(
+ sample_request_bytes, None)
+ sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ server._serve_one_request(sample_protocol)
+ self.assertEqual('', from_server.getvalue())
+ self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
+ self.assertFalse(server.finished)
+
+ def test_socket_stream_with_bulk_data(self):
+ sample_request_bytes = 'command\n9\nbulk datadone\n'
+ server, client_sock = self.create_socket_context(None)
+ sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ client_sock.sendall(sample_request_bytes)
+ server._serve_one_request(sample_protocol)
+ server._disconnect_client()
+ self.assertEqual('', client_sock.recv(1))
+ self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
+ self.assertFalse(server.finished)
+
+ def test_pipe_like_stream_shutdown_detection(self):
+ server, _ = self.create_pipe_context('', None)
+ server._serve_one_request(SampleRequest('x'))
+ self.assertTrue(server.finished)
+
+ def test_socket_stream_shutdown_detection(self):
+ server, client_sock = self.create_socket_context(None)
+ client_sock.close()
+ server._serve_one_request(SampleRequest('x'))
+ self.assertTrue(server.finished)
+
+ def test_socket_stream_incomplete_request(self):
+ """The medium should still construct the right protocol version even if
+ the initial read only reads part of the request.
+
+ Specifically, it should correctly read the protocol version line even
+ if the partial read doesn't end in a newline. An older, naive
+ implementation of _get_line in the server used to have a bug in that
+ case.
+ """
+ incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
+ rest_of_request_bytes = 'lo\n'
+ expected_response = (
+ protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
+ server, client_sock = self.create_socket_context(None)
+ client_sock.sendall(incomplete_request_bytes)
+ server_protocol = server._build_protocol()
+ client_sock.sendall(rest_of_request_bytes)
+ server._serve_one_request(server_protocol)
+ server._disconnect_client()
+ self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
+ "Not a version 2 response to 'hello' request.")
+ self.assertEqual('', client_sock.recv(1))
+
+ def test_pipe_stream_incomplete_request(self):
+ """The medium should still construct the right protocol version even if
+ the initial read only reads part of the request.
+
+ Specifically, it should correctly read the protocol version line even
+ if the partial read doesn't end in a newline. An older, naive
+ implementation of _get_line in the server used to have a bug in that
+ case.
+ """
+ incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
+ rest_of_request_bytes = 'lo\n'
+ expected_response = (
+ protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
+ # Make a pair of pipes, to and from the server
+ to_server, to_server_w = os.pipe()
+ from_server_r, from_server = os.pipe()
+ to_server = os.fdopen(to_server, 'r', 0)
+ to_server_w = os.fdopen(to_server_w, 'w', 0)
+ from_server_r = os.fdopen(from_server_r, 'r', 0)
+ from_server = os.fdopen(from_server, 'w', 0)
+ server = self.create_pipe_medium(to_server, from_server, None)
+ # Like test_socket_stream_incomplete_request, write an incomplete
+ # request (that does not end in '\n') and build a protocol from it.
+ to_server_w.write(incomplete_request_bytes)
+ server_protocol = server._build_protocol()
+ # Send the rest of the request, and finish serving it.
+ to_server_w.write(rest_of_request_bytes)
+ server._serve_one_request(server_protocol)
+ to_server_w.close()
+ from_server.close()
+ self.assertEqual(expected_response, from_server_r.read(),
+ "Not a version 2 response to 'hello' request.")
+ self.assertEqual('', from_server_r.read(1))
+ from_server_r.close()
+ to_server.close()
+
+ def test_pipe_like_stream_with_two_requests(self):
+ # If two requests are read in one go, then two calls to
+ # _serve_one_request should still process both of them as if they had
+ # been received separately.
+ sample_request_bytes = 'command\n'
+ server, from_server = self.create_pipe_context(
+ sample_request_bytes * 2, None)
+ first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ server._serve_one_request(first_protocol)
+ self.assertEqual(0, first_protocol.next_read_size())
+ self.assertEqual('', from_server.getvalue())
+ self.assertFalse(server.finished)
+ # Make a new protocol, call _serve_one_request with it to collect the
+ # second request.
+ second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ server._serve_one_request(second_protocol)
+ self.assertEqual('', from_server.getvalue())
+ self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
+ self.assertFalse(server.finished)
+
+ def test_socket_stream_with_two_requests(self):
+ # If two requests are read in one go, then two calls to
+ # _serve_one_request should still process both of them as if they had
+ # been received separately.
+ sample_request_bytes = 'command\n'
+ server, client_sock = self.create_socket_context(None)
+ first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ # Put two whole requests on the wire.
+ client_sock.sendall(sample_request_bytes * 2)
+ server._serve_one_request(first_protocol)
+ self.assertEqual(0, first_protocol.next_read_size())
+ self.assertFalse(server.finished)
+ # Make a new protocol, call _serve_one_request with it to collect the
+ # second request.
+ second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
+ stream_still_open = server._serve_one_request(second_protocol)
+ self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
+ self.assertFalse(server.finished)
+ server._disconnect_client()
+ self.assertEqual('', client_sock.recv(1))
+
+ def test_pipe_like_stream_error_handling(self):
+ # Use plain python StringIO so we can monkey-patch the close method to
+ # not discard the contents.
+ from StringIO import StringIO
+ to_server = StringIO('')
+ from_server = StringIO()
+ self.closed = False
+ def close():
+ self.closed = True
+ from_server.close = close
+ server = self.create_pipe_medium(
+ to_server, from_server, None)
+ fake_protocol = ErrorRaisingProtocol(Exception('boom'))
+ server._serve_one_request(fake_protocol)
+ self.assertEqual('', from_server.getvalue())
+ self.assertTrue(self.closed)
+ self.assertTrue(server.finished)
+
+ def test_socket_stream_error_handling(self):
+ server, client_sock = self.create_socket_context(None)
+ fake_protocol = ErrorRaisingProtocol(Exception('boom'))
+ server._serve_one_request(fake_protocol)
+ # recv should not block, because the other end of the socket has been
+ # closed.
+ self.assertEqual('', client_sock.recv(1))
+ self.assertTrue(server.finished)
+
+ def test_pipe_like_stream_keyboard_interrupt_handling(self):
+ server, from_server = self.create_pipe_context('', None)
+ fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
+ self.assertRaises(
+ KeyboardInterrupt, server._serve_one_request, fake_protocol)
+ self.assertEqual('', from_server.getvalue())
+
+ def test_socket_stream_keyboard_interrupt_handling(self):
+ server, client_sock = self.create_socket_context(None)
+ fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
+ self.assertRaises(
+ KeyboardInterrupt, server._serve_one_request, fake_protocol)
+ server._disconnect_client()
+ self.assertEqual('', client_sock.recv(1))
+
+ def build_protocol_pipe_like(self, bytes):
+ server, _ = self.create_pipe_context(bytes, None)
+ return server._build_protocol()
+
+ def build_protocol_socket(self, bytes):
+ server, client_sock = self.create_socket_context(None)
+ client_sock.sendall(bytes)
+ client_sock.close()
+ return server._build_protocol()
+
+ def assertProtocolOne(self, server_protocol):
+ # Use assertIs because assertIsInstance will wrongly pass
+ # SmartServerRequestProtocolTwo (because it subclasses
+ # SmartServerRequestProtocolOne).
+ self.assertIs(
+ type(server_protocol), protocol.SmartServerRequestProtocolOne)
+
+ def assertProtocolTwo(self, server_protocol):
+ self.assertIsInstance(
+ server_protocol, protocol.SmartServerRequestProtocolTwo)
+
+ def test_pipe_like_build_protocol_empty_bytes(self):
+ # Any empty request (i.e. no bytes) is detected as protocol version one.
+ server_protocol = self.build_protocol_pipe_like('')
+ self.assertProtocolOne(server_protocol)
+
+ def test_socket_like_build_protocol_empty_bytes(self):
+ # Any empty request (i.e. no bytes) is detected as protocol version one.
+ server_protocol = self.build_protocol_socket('')
+ self.assertProtocolOne(server_protocol)
+
+ def test_pipe_like_build_protocol_non_two(self):
+ # A request that doesn't start with "bzr request 2\n" is version one.
+ server_protocol = self.build_protocol_pipe_like('abc\n')
+ self.assertProtocolOne(server_protocol)
+
+ def test_socket_build_protocol_non_two(self):
+ # A request that doesn't start with "bzr request 2\n" is version one.
+ server_protocol = self.build_protocol_socket('abc\n')
+ self.assertProtocolOne(server_protocol)
+
+ def test_pipe_like_build_protocol_two(self):
+ # A request that starts with "bzr request 2\n" is version two.
+ server_protocol = self.build_protocol_pipe_like('bzr request 2\n')
+ self.assertProtocolTwo(server_protocol)
+
+ def test_socket_build_protocol_two(self):
+ # A request that starts with "bzr request 2\n" is version two.
+ server_protocol = self.build_protocol_socket('bzr request 2\n')
+ self.assertProtocolTwo(server_protocol)
+
+ def test__build_protocol_returns_if_stopping(self):
+ # _build_protocol should notice that we are stopping, and return
+ # without waiting for bytes from the client.
+ server, client_sock = self.create_socket_context(None)
+ server._stop_gracefully()
+ self.assertIs(None, server._build_protocol())
+
+ def test_socket_set_timeout(self):
+ server, _ = self.create_socket_context(None, timeout=1.23)
+ self.assertEqual(1.23, server._client_timeout)
+
+ def test_pipe_set_timeout(self):
+ server = self.create_pipe_medium(None, None, None,
+ timeout=1.23)
+ self.assertEqual(1.23, server._client_timeout)
+
+ def test_socket_wait_for_bytes_with_timeout_with_data(self):
+ server, client_sock = self.create_socket_context(None)
+ client_sock.sendall('data\n')
+ # This should not block or consume any actual content
+ self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
+ data = server.read_bytes(5)
+ self.assertEqual('data\n', data)
+
+ def test_socket_wait_for_bytes_with_timeout_no_data(self):
+ server, client_sock = self.create_socket_context(None)
+ # This should timeout quickly, reporting that there wasn't any data
+ self.assertRaises(errors.ConnectionTimeout,
+ server._wait_for_bytes_with_timeout, 0.01)
+ client_sock.close()
+ data = server.read_bytes(1)
+ self.assertEqual('', data)
+
+ def test_socket_wait_for_bytes_with_timeout_closed(self):
+ server, client_sock = self.create_socket_context(None)
+ # With the socket closed, this should return right away.
+ # It seems select.select() returns that you *can* read on the socket,
+ # even though it closed. Presumably as a way to tell it is closed?
+ # Testing shows that without sock.close() this times-out failing the
+ # test, but with it, it returns False immediately.
+ client_sock.close()
+ self.assertFalse(server._wait_for_bytes_with_timeout(10))
+ data = server.read_bytes(1)
+ self.assertEqual('', data)
+
+ def test_socket_wait_for_bytes_with_shutdown(self):
+ server, client_sock = self.create_socket_context(None)
+ t = time.time()
+ # Override the _timer functionality, so that time never increments,
+ # this way, we can be sure we stopped because of the flag, and not
+ # because of a timeout, etc.
+ server._timer = lambda: t
+ server._client_poll_timeout = 0.1
+ server._stop_gracefully()
+ server._wait_for_bytes_with_timeout(1.0)
+
+ def test_socket_serve_timeout_closes_socket(self):
+ server, client_sock = self.create_socket_context(None, timeout=0.1)
+ # This should timeout quickly, and then close the connection so that
+ # client_sock recv doesn't block.
+ server.serve()
+ self.assertEqual('', client_sock.recv(1))
+
+ def test_pipe_wait_for_bytes_with_timeout_with_data(self):
+ # We intentionally use a real pipe here, so that we can 'select' on it.
+ # You can't select() on a StringIO
+ (r_server, w_client) = os.pipe()
+ self.addCleanup(os.close, w_client)
+ with os.fdopen(r_server, 'rb') as rf_server:
+ server = self.create_pipe_medium(
+ rf_server, None, None)
+ os.write(w_client, 'data\n')
+ # This should not block or consume any actual content
+ server._wait_for_bytes_with_timeout(0.1)
+ data = server.read_bytes(5)
+ self.assertEqual('data\n', data)
+
+ def test_pipe_wait_for_bytes_with_timeout_no_data(self):
+ # We intentionally use a real pipe here, so that we can 'select' on it.
+ # You can't select() on a StringIO
+ (r_server, w_client) = os.pipe()
+ # We can't add an os.close cleanup here, because we need to control
+ # when the file handle gets closed ourselves.
+ with os.fdopen(r_server, 'rb') as rf_server:
+ server = self.create_pipe_medium(
+ rf_server, None, None)
+ if sys.platform == 'win32':
+ # Windows cannot select() on a pipe, so we just always return
+ server._wait_for_bytes_with_timeout(0.01)
+ else:
+ self.assertRaises(errors.ConnectionTimeout,
+ server._wait_for_bytes_with_timeout, 0.01)
+ os.close(w_client)
+ data = server.read_bytes(5)
+ self.assertEqual('', data)
+
+ def test_pipe_wait_for_bytes_no_fileno(self):
+ server, _ = self.create_pipe_context('', None)
+ # Our file doesn't support polling, so we should always just return
+ # 'you have data to consume.
+ server._wait_for_bytes_with_timeout(0.01)
+
+
+class TestGetProtocolFactoryForBytes(tests.TestCase):
+ """_get_protocol_factory_for_bytes identifies the protocol factory a server
+ should use to decode a given request. Any bytes not part of the version
+ marker string (and thus part of the actual request) are returned alongside
+ the protocol factory.
+ """
+
+ def test_version_three(self):
+ result = medium._get_protocol_factory_for_bytes(
+ 'bzr message 3 (bzr 1.6)\nextra bytes')
+ protocol_factory, remainder = result
+ self.assertEqual(
+ protocol.build_server_protocol_three, protocol_factory)
+ self.assertEqual('extra bytes', remainder)
+
+ def test_version_two(self):
+ result = medium._get_protocol_factory_for_bytes(
+ 'bzr request 2\nextra bytes')
+ protocol_factory, remainder = result
+ self.assertEqual(
+ protocol.SmartServerRequestProtocolTwo, protocol_factory)
+ self.assertEqual('extra bytes', remainder)
+
+ def test_version_one(self):
+ """Version one requests have no version markers."""
+ result = medium._get_protocol_factory_for_bytes('anything\n')
+ protocol_factory, remainder = result
+ self.assertEqual(
+ protocol.SmartServerRequestProtocolOne, protocol_factory)
+ self.assertEqual('anything\n', remainder)
+
+
+class TestSmartTCPServer(tests.TestCase):
+
+ def make_server(self):
+ """Create a SmartTCPServer that we can exercise.
+
+ Note: we don't use SmartTCPServer_for_testing because the testing
+ version overrides lots of functionality like 'serve', and we want to
+ test the raw service.
+
+ This will start the server in another thread, and wait for it to
+ indicate it has finished starting up.
+
+ :return: (server, server_thread)
+ """
+ t = _mod_transport.get_transport_from_url('memory:///')
+ server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
+ server._ACCEPT_TIMEOUT = 0.1
+ # We don't use 'localhost' because that might be an IPv6 address.
+ server.start_server('127.0.0.1', 0)
+ server_thread = threading.Thread(target=server.serve,
+ args=(self.id(),))
+ server_thread.start()
+ # Ensure this gets called at some point
+ self.addCleanup(server._stop_gracefully)
+ server._started.wait()
+ return server, server_thread
+
+ def ensure_client_disconnected(self, client_sock):
+ """Ensure that a socket is closed, discarding all errors."""
+ try:
+ client_sock.close()
+ except Exception:
+ pass
+
+ def connect_to_server(self, server):
+ """Create a client socket that can talk to the server."""
+ client_sock = socket.socket()
+ server_info = server._server_socket.getsockname()
+ client_sock.connect(server_info)
+ self.addCleanup(self.ensure_client_disconnected, client_sock)
+ return client_sock
+
+ def connect_to_server_and_hangup(self, server):
+ """Connect to the server, and then hang up.
+ That way it doesn't sit waiting for 'accept()' to timeout.
+ """
+ # If the server has already signaled that the socket is closed, we
+ # don't need to try to connect to it. Not being set, though, the server
+ # might still close the socket while we try to connect to it. So we
+ # still have to catch the exception.
+ if server._stopped.isSet():
+ return
+ try:
+ client_sock = self.connect_to_server(server)
+ client_sock.close()
+ except socket.error, e:
+ # If the server has hung up already, that is fine.
+ pass
+
+ def say_hello(self, client_sock):
+ """Send the 'hello' smart RPC, and expect the response."""
+ client_sock.send('hello\n')
+ self.assertEqual('ok\x012\n', client_sock.recv(5))
+
+ def shutdown_server_cleanly(self, server, server_thread):
+ server._stop_gracefully()
+ self.connect_to_server_and_hangup(server)
+ server._stopped.wait()
+ server._fully_stopped.wait()
+ server_thread.join()
+
+ def test_get_error_unexpected(self):
+ """Error reported by server with no specific representation"""
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ class FlakyTransport(object):
+ base = 'a_url'
+ def external_url(self):
+ return self.base
+ def get(self, path):
+ raise Exception("some random exception from inside server")
+
+ class FlakyServer(test_server.SmartTCPServer_for_testing):
+ def get_backing_transport(self, backing_transport_server):
+ return FlakyTransport()
+
+ smart_server = FlakyServer()
+ smart_server.start_server()
+ self.addCleanup(smart_server.stop_server)
+ t = remote.RemoteTCPTransport(smart_server.get_url())
+ self.addCleanup(t.disconnect)
+ err = self.assertRaises(errors.UnknownErrorFromSmartServer,
+ t.get, 'something')
+ self.assertContainsRe(str(err), 'some random exception')
+
+ def test_propagates_timeout(self):
+ server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
+ server_sock, client_sock = portable_socket_pair()
+ handler = server._make_handler(server_sock)
+ self.assertEqual(1.23, handler._client_timeout)
+
+ def test_serve_conn_tracks_connections(self):
+ server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
+ server_sock, client_sock = portable_socket_pair()
+ server.serve_conn(server_sock, '-%s' % (self.id(),))
+ self.assertEqual(1, len(server._active_connections))
+ # We still want to talk on the connection. Polling should indicate it
+ # is still active.
+ server._poll_active_connections()
+ self.assertEqual(1, len(server._active_connections))
+ # Closing the socket will end the active thread, and polling will
+ # notice and remove it from the active set.
+ client_sock.close()
+ server._poll_active_connections(0.1)
+ self.assertEqual(0, len(server._active_connections))
+
+ def test_serve_closes_out_finished_connections(self):
+ server, server_thread = self.make_server()
+ # The server is started, connect to it.
+ client_sock = self.connect_to_server(server)
+ # We send and receive on the connection, so that we know the
+ # server-side has seen the connect, and started handling the
+ # results.
+ self.say_hello(client_sock)
+ self.assertEqual(1, len(server._active_connections))
+ # Grab a handle to the thread that is processing our request
+ _, server_side_thread = server._active_connections[0]
+ # Close the connection, ask the server to stop, and wait for the
+ # server to stop, as well as the thread that was servicing the
+ # client request.
+ client_sock.close()
+ # Wait for the server-side request thread to notice we are closed.
+ server_side_thread.join()
+ # Stop the server, it should notice the connection has finished.
+ self.shutdown_server_cleanly(server, server_thread)
+ # The server should have noticed that all clients are gone before
+ # exiting.
+ self.assertEqual(0, len(server._active_connections))
+
+ def test_serve_reaps_finished_connections(self):
+ server, server_thread = self.make_server()
+ client_sock1 = self.connect_to_server(server)
+ # We send and receive on the connection, so that we know the
+ # server-side has seen the connect, and started handling the
+ # results.
+ self.say_hello(client_sock1)
+ server_handler1, server_side_thread1 = server._active_connections[0]
+ client_sock1.close()
+ server_side_thread1.join()
+ # By waiting until the first connection is fully done, the server
+ # should notice after another connection that the first has finished.
+ client_sock2 = self.connect_to_server(server)
+ self.say_hello(client_sock2)
+ server_handler2, server_side_thread2 = server._active_connections[-1]
+ # There is a race condition. We know that client_sock2 has been
+ # registered, but not that _poll_active_connections has been called. We
+ # know that it will be called before the server will accept a new
+ # connection, however. So connect one more time, and assert that we
+ # either have 1 or 2 active connections (never 3), and that the 'first'
+ # connection is not connection 1
+ client_sock3 = self.connect_to_server(server)
+ self.say_hello(client_sock3)
+ # Copy the list, so we don't have it mutating behind our back
+ conns = list(server._active_connections)
+ self.assertEqual(2, len(conns))
+ self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
+ self.assertEqual((server_handler2, server_side_thread2), conns[0])
+ client_sock2.close()
+ client_sock3.close()
+ self.shutdown_server_cleanly(server, server_thread)
+
+ def test_graceful_shutdown_waits_for_clients_to_stop(self):
+ server, server_thread = self.make_server()
+ # We need something big enough that it won't fit in a single recv. So
+ # the server thread gets blocked writing content to the client until we
+ # finish reading on the client.
+ server.backing_transport.put_bytes('bigfile',
+ 'a'*1024*1024)
+ client_sock = self.connect_to_server(server)
+ self.say_hello(client_sock)
+ _, server_side_thread = server._active_connections[0]
+ # Start the RPC, but don't finish reading the response
+ client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
+ 'base', client_sock)
+ client_client = client._SmartClient(client_medium)
+ resp, response_handler = client_client.call_expecting_body('get',
+ 'bigfile')
+ self.assertEqual(('ok',), resp)
+ # Ask the server to stop gracefully, and wait for it.
+ server._stop_gracefully()
+ self.connect_to_server_and_hangup(server)
+ server._stopped.wait()
+ # It should not be accepting another connection.
+ self.assertRaises(socket.error, self.connect_to_server, server)
+ # It should also not be fully stopped
+ server._fully_stopped.wait(0.01)
+ self.assertFalse(server._fully_stopped.isSet())
+ response_handler.read_body_bytes()
+ client_sock.close()
+ server_side_thread.join()
+ server_thread.join()
+ self.assertTrue(server._fully_stopped.isSet())
+ log = self.get_log()
+ self.assertThat(log, DocTestMatches("""\
+ INFO Requested to stop gracefully
+... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
+ INFO Waiting for 1 client(s) to finish
+""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
+
+ def test_stop_gracefully_tells_handlers_to_stop(self):
+ server, server_thread = self.make_server()
+ client_sock = self.connect_to_server(server)
+ self.say_hello(client_sock)
+ server_handler, server_side_thread = server._active_connections[0]
+ self.assertFalse(server_handler.finished)
+ server._stop_gracefully()
+ self.assertTrue(server_handler.finished)
+ client_sock.close()
+ self.connect_to_server_and_hangup(server)
+ server_thread.join()
+
+
+class SmartTCPTests(tests.TestCase):
+ """Tests for connection/end to end behaviour using the TCP server.
+
+ All of these tests are run with a server running in another thread serving
+ a MemoryTransport, and a connection to it already open.
+
+ the server is obtained by calling self.start_server(readonly=False).
+ """
+
+ def start_server(self, readonly=False, backing_transport=None):
+ """Setup the server.
+
+ :param readonly: Create a readonly server.
+ """
+ # NB: Tests using this fall into two categories: tests of the server,
+ # tests wanting a server. The latter should be updated to use
+ # self.vfs_transport_factory etc.
+ if backing_transport is None:
+ mem_server = memory.MemoryServer()
+ mem_server.start_server()
+ self.addCleanup(mem_server.stop_server)
+ self.permit_url(mem_server.get_url())
+ self.backing_transport = _mod_transport.get_transport_from_url(
+ mem_server.get_url())
+ else:
+ self.backing_transport = backing_transport
+ if readonly:
+ self.real_backing_transport = self.backing_transport
+ self.backing_transport = _mod_transport.get_transport_from_url(
+ "readonly+" + self.backing_transport.abspath('.'))
+ self.server = _mod_server.SmartTCPServer(self.backing_transport,
+ client_timeout=4.0)
+ self.server.start_server('127.0.0.1', 0)
+ self.server.start_background_thread('-' + self.id())
+ self.transport = remote.RemoteTCPTransport(self.server.get_url())
+ self.addCleanup(self.stop_server)
+ self.permit_url(self.server.get_url())
+
+ def stop_server(self):
+ """Disconnect the client and stop the server.
+
+ This must be re-entrant as some tests will call it explicitly in
+ addition to the normal cleanup.
+ """
+ if getattr(self, 'transport', None):
+ self.transport.disconnect()
+ del self.transport
+ if getattr(self, 'server', None):
+ self.server.stop_background_thread()
+ del self.server
+
+
+class TestServerSocketUsage(SmartTCPTests):
+
+ def test_server_start_stop(self):
+ """It should be safe to stop the server with no requests."""
+ self.start_server()
+ t = remote.RemoteTCPTransport(self.server.get_url())
+ self.stop_server()
+ self.assertRaises(errors.ConnectionError, t.has, '.')
+
+ def test_server_closes_listening_sock_on_shutdown_after_request(self):
+ """The server should close its listening socket when it's stopped."""
+ self.start_server()
+ server_url = self.server.get_url()
+ self.transport.has('.')
+ self.stop_server()
+ # if the listening socket has closed, we should get a BADFD error
+ # when connecting, rather than a hang.
+ t = remote.RemoteTCPTransport(server_url)
+ self.assertRaises(errors.ConnectionError, t.has, '.')
+
+
+class WritableEndToEndTests(SmartTCPTests):
+ """Client to server tests that require a writable transport."""
+
+ def setUp(self):
+ super(WritableEndToEndTests, self).setUp()
+ self.start_server()
+
+ def test_start_tcp_server(self):
+ url = self.server.get_url()
+ self.assertContainsRe(url, r'^bzr://127\.0\.0\.1:[0-9]{2,}/')
+
+ def test_smart_transport_has(self):
+ """Checking for file existence over smart."""
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ self.backing_transport.put_bytes("foo", "contents of foo\n")
+ self.assertTrue(self.transport.has("foo"))
+ self.assertFalse(self.transport.has("non-foo"))
+
+ def test_smart_transport_get(self):
+ """Read back a file over smart."""
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
+ fp = self.transport.get("foo")
+ self.assertEqual('contents\nof\nfoo\n', fp.read())
+
+ def test_get_error_enoent(self):
+ """Error reported from server getting nonexistent file."""
+ # The path in a raised NoSuchFile exception should be the precise path
+ # asked for by the client. This gives meaningful and unsurprising errors
+ # for users.
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ err = self.assertRaises(
+ errors.NoSuchFile, self.transport.get, 'not%20a%20file')
+ self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file'])
+
+ def test_simple_clone_conn(self):
+ """Test that cloning reuses the same connection."""
+ # we create a real connection not a loopback one, but it will use the
+ # same server and pipes
+ conn2 = self.transport.clone('.')
+ self.assertIs(self.transport.get_smart_medium(),
+ conn2.get_smart_medium())
+
+ def test__remote_path(self):
+ self.assertEquals('/foo/bar',
+ self.transport._remote_path('foo/bar'))
+
+ def test_clone_changes_base(self):
+ """Cloning transport produces one with a new base location"""
+ conn2 = self.transport.clone('subdir')
+ self.assertEquals(self.transport.base + 'subdir/',
+ conn2.base)
+
+ def test_open_dir(self):
+ """Test changing directory"""
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ transport = self.transport
+ self.backing_transport.mkdir('toffee')
+ self.backing_transport.mkdir('toffee/apple')
+ self.assertEquals('/toffee', transport._remote_path('toffee'))
+ toffee_trans = transport.clone('toffee')
+ # Check that each transport has only the contents of its directory
+ # directly visible. If state was being held in the wrong object, it's
+ # conceivable that cloning a transport would alter the state of the
+ # cloned-from transport.
+ self.assertTrue(transport.has('toffee'))
+ self.assertFalse(toffee_trans.has('toffee'))
+ self.assertFalse(transport.has('apple'))
+ self.assertTrue(toffee_trans.has('apple'))
+
+ def test_open_bzrdir(self):
+ """Open an existing bzrdir over smart transport"""
+ transport = self.transport
+ t = self.backing_transport
+ bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
+ result_dir = controldir.ControlDir.open_containing_from_transport(
+ transport)
+
+
+class ReadOnlyEndToEndTests(SmartTCPTests):
+ """Tests from the client to the server using a readonly backing transport."""
+
+ def test_mkdir_error_readonly(self):
+ """TransportNotPossible should be preserved from the backing transport."""
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ self.start_server(readonly=True)
+ self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
+ 'foo')
+
+
+class TestServerHooks(SmartTCPTests):
+
+ def capture_server_call(self, backing_urls, public_url):
+ """Record a server_started|stopped hook firing."""
+ self.hook_calls.append((backing_urls, public_url))
+
+ def test_server_started_hook_memory(self):
+ """The server_started hook fires when the server is started."""
+ self.hook_calls = []
+ _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
+ self.capture_server_call, None)
+ self.start_server()
+ # at this point, the server will be starting a thread up.
+ # there is no indicator at the moment, so bodge it by doing a request.
+ self.transport.has('.')
+ # The default test server uses MemoryTransport and that has no external
+ # url:
+ self.assertEqual([([self.backing_transport.base], self.transport.base)],
+ self.hook_calls)
+
+ def test_server_started_hook_file(self):
+ """The server_started hook fires when the server is started."""
+ self.hook_calls = []
+ _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
+ self.capture_server_call, None)
+ self.start_server(
+ backing_transport=_mod_transport.get_transport_from_path("."))
+ # at this point, the server will be starting a thread up.
+ # there is no indicator at the moment, so bodge it by doing a request.
+ self.transport.has('.')
+ # The default test server uses MemoryTransport and that has no external
+ # url:
+ self.assertEqual([([
+ self.backing_transport.base, self.backing_transport.external_url()],
+ self.transport.base)],
+ self.hook_calls)
+
+ def test_server_stopped_hook_simple_memory(self):
+ """The server_stopped hook fires when the server is stopped."""
+ self.hook_calls = []
+ _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
+ self.capture_server_call, None)
+ self.start_server()
+ result = [([self.backing_transport.base], self.transport.base)]
+ # check the stopping message isn't emitted up front.
+ self.assertEqual([], self.hook_calls)
+ # nor after a single message
+ self.transport.has('.')
+ self.assertEqual([], self.hook_calls)
+ # clean up the server
+ self.stop_server()
+ # now it should have fired.
+ self.assertEqual(result, self.hook_calls)
+
+ def test_server_stopped_hook_simple_file(self):
+ """The server_stopped hook fires when the server is stopped."""
+ self.hook_calls = []
+ _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
+ self.capture_server_call, None)
+ self.start_server(
+ backing_transport=_mod_transport.get_transport_from_path("."))
+ result = [(
+ [self.backing_transport.base, self.backing_transport.external_url()]
+ , self.transport.base)]
+ # check the stopping message isn't emitted up front.
+ self.assertEqual([], self.hook_calls)
+ # nor after a single message
+ self.transport.has('.')
+ self.assertEqual([], self.hook_calls)
+ # clean up the server
+ self.stop_server()
+ # now it should have fired.
+ self.assertEqual(result, self.hook_calls)
+
+# TODO: test that when the server suffers an exception that it calls the
+# server-stopped hook.
+
+
+class SmartServerCommandTests(tests.TestCaseWithTransport):
+ """Tests that call directly into the command objects, bypassing the network
+ and the request dispatching.
+
+ Note: these tests are rudimentary versions of the command object tests in
+ test_smart.py.
+ """
+
+ def test_hello(self):
+ cmd = _mod_request.HelloRequest(None, '/')
+ response = cmd.execute()
+ self.assertEqual(('ok', '2'), response.args)
+ self.assertEqual(None, response.body)
+
+ def test_get_bundle(self):
+ from bzrlib.bundle import serializer
+ wt = self.make_branch_and_tree('.')
+ self.build_tree_contents([('hello', 'hello world')])
+ wt.add('hello')
+ rev_id = wt.commit('add hello')
+
+ cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
+ response = cmd.execute('.', rev_id)
+ bundle = serializer.read_bundle(StringIO(response.body))
+ self.assertEqual((), response.args)
+
+
+class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
+ """Test that call directly into the handler logic, bypassing the network."""
+
+ def setUp(self):
+ super(SmartServerRequestHandlerTests, self).setUp()
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+
+ def build_handler(self, transport):
+ """Returns a handler for the commands in protocol version one."""
+ return _mod_request.SmartServerRequestHandler(
+ transport, _mod_request.request_handlers, '/')
+
+ def test_construct_request_handler(self):
+ """Constructing a request handler should be easy and set defaults."""
+ handler = _mod_request.SmartServerRequestHandler(None, commands=None,
+ root_client_path='/')
+ self.assertFalse(handler.finished_reading)
+
+ def test_hello(self):
+ handler = self.build_handler(None)
+ handler.args_received(('hello',))
+ self.assertEqual(('ok', '2'), handler.response.args)
+ self.assertEqual(None, handler.response.body)
+
+ def test_disable_vfs_handler_classes_via_environment(self):
+ # VFS handler classes will raise an error from "execute" if
+ # BZR_NO_SMART_VFS is set.
+ handler = vfs.HasRequest(None, '/')
+ # set environment variable after construction to make sure it's
+ # examined.
+ self.overrideEnv('BZR_NO_SMART_VFS', '')
+ self.assertRaises(errors.DisabledMethod, handler.execute)
+
+ def test_readonly_exception_becomes_transport_not_possible(self):
+ """The response for a read-only error is ('ReadOnlyError')."""
+ handler = self.build_handler(self.get_readonly_transport())
+ # send a mkdir for foo, with no explicit mode - should fail.
+ handler.args_received(('mkdir', 'foo', ''))
+ # and the failure should be an explicit ReadOnlyError
+ self.assertEqual(("ReadOnlyError", ), handler.response.args)
+ # XXX: TODO: test that other TransportNotPossible errors are
+ # presented as TransportNotPossible - not possible to do that
+ # until I figure out how to trigger that relatively cleanly via
+ # the api. RBC 20060918
+
+ def test_hello_has_finished_body_on_dispatch(self):
+ """The 'hello' command should set finished_reading."""
+ handler = self.build_handler(None)
+ handler.args_received(('hello',))
+ self.assertTrue(handler.finished_reading)
+ self.assertNotEqual(None, handler.response)
+
+ def test_put_bytes_non_atomic(self):
+ """'put_...' should set finished_reading after reading the bytes."""
+ handler = self.build_handler(self.get_transport())
+ handler.args_received(('put_non_atomic', 'a-file', '', 'F', ''))
+ self.assertFalse(handler.finished_reading)
+ handler.accept_body('1234')
+ self.assertFalse(handler.finished_reading)
+ handler.accept_body('5678')
+ handler.end_of_body()
+ self.assertTrue(handler.finished_reading)
+ self.assertEqual(('ok', ), handler.response.args)
+ self.assertEqual(None, handler.response.body)
+
+ def test_readv_accept_body(self):
+ """'readv' should set finished_reading after reading offsets."""
+ self.build_tree(['a-file'])
+ handler = self.build_handler(self.get_readonly_transport())
+ handler.args_received(('readv', 'a-file'))
+ self.assertFalse(handler.finished_reading)
+ handler.accept_body('2,')
+ self.assertFalse(handler.finished_reading)
+ handler.accept_body('3')
+ handler.end_of_body()
+ self.assertTrue(handler.finished_reading)
+ self.assertEqual(('readv', ), handler.response.args)
+ # co - nte - nt of a-file is the file contents we are extracting from.
+ self.assertEqual('nte', handler.response.body)
+
+ def test_readv_short_read_response_contents(self):
+ """'readv' when a short read occurs sets the response appropriately."""
+ self.build_tree(['a-file'])
+ handler = self.build_handler(self.get_readonly_transport())
+ handler.args_received(('readv', 'a-file'))
+ # read beyond the end of the file.
+ handler.accept_body('100,1')
+ handler.end_of_body()
+ self.assertTrue(handler.finished_reading)
+ self.assertEqual(('ShortReadvError', './a-file', '100', '1', '0'),
+ handler.response.args)
+ self.assertEqual(None, handler.response.body)
+
+
+class RemoteTransportRegistration(tests.TestCase):
+
+ def test_registration(self):
+ t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
+ self.assertIsInstance(t, remote.RemoteSSHTransport)
+ self.assertEqual('example.com', t._parsed_url.host)
+
+ def test_bzr_https(self):
+ # https://bugs.launchpad.net/bzr/+bug/128456
+ t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
+ self.assertIsInstance(t, remote.RemoteHTTPTransport)
+ self.assertStartsWith(
+ t._http_transport.base,
+ 'https://')
+
+
+class TestRemoteTransport(tests.TestCase):
+
+ def test_use_connection_factory(self):
+ # We want to be able to pass a client as a parameter to RemoteTransport.
+ input = StringIO('ok\n3\nbardone\n')
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ transport = remote.RemoteTransport(
+ 'bzr://localhost/', medium=client_medium)
+ # Disable version detection.
+ client_medium._protocol_version = 1
+
+ # We want to make sure the client is used when the first remote
+ # method is called. No data should have been sent, or read.
+ self.assertEqual(0, input.tell())
+ self.assertEqual('', output.getvalue())
+
+ # Now call a method that should result in one request: as the
+ # transport makes its own protocol instances, we check on the wire.
+ # XXX: TODO: give the transport a protocol factory, which can make
+ # an instrumented protocol for us.
+ self.assertEqual('bar', transport.get_bytes('foo'))
+ # only the needed data should have been sent/received.
+ self.assertEqual(13, input.tell())
+ self.assertEqual('get\x01/foo\n', output.getvalue())
+
+ def test__translate_error_readonly(self):
+ """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
+ client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
+ transport = remote.RemoteTransport(
+ 'bzr://localhost/', medium=client_medium)
+ err = errors.ErrorFromSmartServer(("ReadOnlyError", ))
+ self.assertRaises(errors.TransportNotPossible,
+ transport._translate_error, err)
+
+
+class TestSmartProtocol(tests.TestCase):
+ """Base class for smart protocol tests.
+
+ Each test case gets a smart_server and smart_client created during setUp().
+
+ It is planned that the client can be called with self.call_client() giving
+ it an expected server response, which will be fed into it when it tries to
+ read. Likewise, self.call_server will call a servers method with a canned
+ serialised client request. Output done by the client or server for these
+ calls will be captured to self.to_server and self.to_client. Each element
+ in the list is a write call from the client or server respectively.
+
+ Subclasses can override client_protocol_class and server_protocol_class.
+ """
+
+ request_encoder = None
+ response_decoder = None
+ server_protocol_class = None
+ client_protocol_class = None
+
+ def make_client_protocol_and_output(self, input_bytes=None):
+ """
+ :returns: a Request
+ """
+ # This is very similar to
+ # bzrlib.smart.client._SmartClient._build_client_protocol
+ # XXX: make this use _SmartClient!
+ if input_bytes is None:
+ input = StringIO()
+ else:
+ input = StringIO(input_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ if self.client_protocol_class is not None:
+ client_protocol = self.client_protocol_class(request)
+ return client_protocol, client_protocol, output
+ else:
+ self.assertNotEqual(None, self.request_encoder)
+ self.assertNotEqual(None, self.response_decoder)
+ requester = self.request_encoder(request)
+ response_handler = message.ConventionalResponseHandler()
+ response_protocol = self.response_decoder(
+ response_handler, expect_version_marker=True)
+ response_handler.setProtoAndMediumRequest(
+ response_protocol, request)
+ return requester, response_handler, output
+
+ def make_client_protocol(self, input_bytes=None):
+ result = self.make_client_protocol_and_output(input_bytes=input_bytes)
+ requester, response_handler, output = result
+ return requester, response_handler
+
+ def make_server_protocol(self):
+ out_stream = StringIO()
+ smart_protocol = self.server_protocol_class(None, out_stream.write)
+ return smart_protocol, out_stream
+
+ def setUp(self):
+ super(TestSmartProtocol, self).setUp()
+ self.response_marker = getattr(
+ self.client_protocol_class, 'response_marker', None)
+ self.request_marker = getattr(
+ self.client_protocol_class, 'request_marker', None)
+
+ def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
+ requester):
+ """Check that smart (de)serialises offsets as expected.
+
+ We check both serialisation and deserialisation at the same time
+ to ensure that the round tripping cannot skew: both directions should
+ be as expected.
+
+ :param expected_offsets: a readv offset list.
+ :param expected_seralised: an expected serial form of the offsets.
+ """
+ # XXX: '_deserialise_offsets' should be a method of the
+ # SmartServerRequestProtocol in future.
+ readv_cmd = vfs.ReadvRequest(None, '/')
+ offsets = readv_cmd._deserialise_offsets(expected_serialised)
+ self.assertEqual(expected_offsets, offsets)
+ serialised = requester._serialise_offsets(offsets)
+ self.assertEqual(expected_serialised, serialised)
+
+ def build_protocol_waiting_for_body(self):
+ smart_protocol, out_stream = self.make_server_protocol()
+ smart_protocol._has_dispatched = True
+ smart_protocol.request = _mod_request.SmartServerRequestHandler(
+ None, _mod_request.request_handlers, '/')
+ # GZ 2010-08-10: Cycle with closure affects 4 tests
+ class FakeCommand(_mod_request.SmartServerRequest):
+ def do_body(self_cmd, body_bytes):
+ self.end_received = True
+ self.assertEqual('abcdefg', body_bytes)
+ return _mod_request.SuccessfulSmartServerResponse(('ok', ))
+ smart_protocol.request._command = FakeCommand(None)
+ # Call accept_bytes to make sure that internal state like _body_decoder
+ # is initialised. This test should probably be given a clearer
+ # interface to work with that will not cause this inconsistency.
+ # -- Andrew Bennetts, 2006-09-28
+ smart_protocol.accept_bytes('')
+ return smart_protocol
+
+ def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
+ input_tuples):
+ """Assert that each input_tuple serialises as expected_bytes, and the
+ bytes deserialise as expected_tuple.
+ """
+ # check the encoding of the server for all input_tuples matches
+ # expected bytes
+ for input_tuple in input_tuples:
+ server_protocol, server_output = self.make_server_protocol()
+ server_protocol._send_response(
+ _mod_request.SuccessfulSmartServerResponse(input_tuple))
+ self.assertEqual(expected_bytes, server_output.getvalue())
+ # check the decoding of the client smart_protocol from expected_bytes:
+ requester, response_handler = self.make_client_protocol(expected_bytes)
+ requester.call('foo')
+ self.assertEqual(expected_tuple, response_handler.read_response_tuple())
+
+
+class CommonSmartProtocolTestMixin(object):
+
+ def test_connection_closed_reporting(self):
+ requester, response_handler = self.make_client_protocol()
+ requester.call('hello')
+ ex = self.assertRaises(errors.ConnectionReset,
+ response_handler.read_response_tuple)
+ self.assertEqual("Connection closed: "
+ "Unexpected end of message. Please check connectivity "
+ "and permissions, and report a bug if problems persist. ",
+ str(ex))
+
+ def test_server_offset_serialisation(self):
+ """The Smart protocol serialises offsets as a comma and \n string.
+
+ We check a number of boundary cases are as expected: empty, one offset,
+ one with the order of reads not increasing (an out of order read), and
+ one that should coalesce.
+ """
+ requester, response_handler = self.make_client_protocol()
+ self.assertOffsetSerialisation([], '', requester)
+ self.assertOffsetSerialisation([(1,2)], '1,2', requester)
+ self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
+ requester)
+ self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
+ '1,2\n3,4\n100,200', requester)
+
+
+class TestVersionOneFeaturesInProtocolOne(
+ TestSmartProtocol, CommonSmartProtocolTestMixin):
+ """Tests for version one smart protocol features as implemeted by version
+ one."""
+
+ client_protocol_class = protocol.SmartClientRequestProtocolOne
+ server_protocol_class = protocol.SmartServerRequestProtocolOne
+
+ def test_construct_version_one_server_protocol(self):
+ smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
+ self.assertEqual('', smart_protocol.unused_data)
+ self.assertEqual('', smart_protocol.in_buffer)
+ self.assertFalse(smart_protocol._has_dispatched)
+ self.assertEqual(1, smart_protocol.next_read_size())
+
+ def test_construct_version_one_client_protocol(self):
+ # we can construct a client protocol from a client medium request
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = client_medium.get_request()
+ client_protocol = protocol.SmartClientRequestProtocolOne(request)
+
+ def test_accept_bytes_of_bad_request_to_protocol(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolOne(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('abc')
+ self.assertEqual('abc', smart_protocol.in_buffer)
+ smart_protocol.accept_bytes('\n')
+ self.assertEqual(
+ "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
+ out_stream.getvalue())
+ self.assertTrue(smart_protocol._has_dispatched)
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test_accept_body_bytes_to_protocol(self):
+ protocol = self.build_protocol_waiting_for_body()
+ self.assertEqual(6, protocol.next_read_size())
+ protocol.accept_bytes('7\nabc')
+ self.assertEqual(9, protocol.next_read_size())
+ protocol.accept_bytes('defgd')
+ protocol.accept_bytes('one\n')
+ self.assertEqual(0, protocol.next_read_size())
+ self.assertTrue(self.end_received)
+
+ def test_accept_request_and_body_all_at_once(self):
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ mem_transport = memory.MemoryTransport()
+ mem_transport.put_bytes('foo', 'abcdefghij')
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
+ out_stream.write)
+ smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
+ self.assertEqual(0, smart_protocol.next_read_size())
+ self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
+ self.assertEqual('', smart_protocol.unused_data)
+ self.assertEqual('', smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_are_preserved(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolOne(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('hello\nhello\n')
+ self.assertEqual("ok\x012\n", out_stream.getvalue())
+ self.assertEqual("hello\n", smart_protocol.unused_data)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_body(self):
+ protocol = self.build_protocol_waiting_for_body()
+ protocol.accept_bytes('7\nabcdefgdone\nX')
+ self.assertTrue(self.end_received)
+ self.assertEqual("X", protocol.unused_data)
+ self.assertEqual("", protocol.in_buffer)
+ protocol.accept_bytes('Y')
+ self.assertEqual("XY", protocol.unused_data)
+ self.assertEqual("", protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_dispatch(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolOne(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('hello\n')
+ self.assertEqual("ok\x012\n", out_stream.getvalue())
+ smart_protocol.accept_bytes('hel')
+ self.assertEqual("hel", smart_protocol.unused_data)
+ smart_protocol.accept_bytes('lo\n')
+ self.assertEqual("hello\n", smart_protocol.unused_data)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test__send_response_sets_finished_reading(self):
+ smart_protocol = protocol.SmartServerRequestProtocolOne(
+ None, lambda x: None)
+ self.assertEqual(1, smart_protocol.next_read_size())
+ smart_protocol._send_response(
+ _mod_request.SuccessfulSmartServerResponse(('x',)))
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test__send_response_errors_with_base_response(self):
+ """Ensure that only the Successful/Failed subclasses are used."""
+ smart_protocol = protocol.SmartServerRequestProtocolOne(
+ None, lambda x: None)
+ self.assertRaises(AttributeError, smart_protocol._send_response,
+ _mod_request.SmartServerResponse(('x',)))
+
+ def test_query_version(self):
+ """query_version on a SmartClientProtocolOne should return a number.
+
+ The protocol provides the query_version because the domain level clients
+ may all need to be able to probe for capabilities.
+ """
+ # What we really want to test here is that SmartClientProtocolOne calls
+ # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
+ # response of tuple-encoded (ok, 1). Also, separately we should test
+ # the error if the response is a non-understood version.
+ input = StringIO('ok\x012\n')
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ self.assertEqual(2, smart_protocol.query_version())
+
+ def test_client_call_empty_response(self):
+ # protocol.call() can get back an empty tuple as a response. This occurs
+ # when the parsed line is an empty line, and results in a tuple with
+ # one element - an empty string.
+ self.assertServerToClientEncoding('\n', ('', ), [(), ('', )])
+
+ def test_client_call_three_element_response(self):
+ # protocol.call() can get back tuples of other lengths. A three element
+ # tuple should be unpacked as three strings.
+ self.assertServerToClientEncoding('a\x01b\x0134\n', ('a', 'b', '34'),
+ [('a', 'b', '34')])
+
+ def test_client_call_with_body_bytes_uploads(self):
+ # protocol.call_with_body_bytes should length-prefix the bytes onto the
+ # wire.
+ expected_bytes = "foo\n7\nabcdefgdone\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def test_client_call_with_body_readv_array(self):
+ # protocol.call_with_upload should encode the readv array and then
+ # length-prefix the bytes onto the wire.
+ expected_bytes = "foo\n7\n1,2\n5,6done\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
+ server_bytes):
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call('foo')
+ self.assertRaises(
+ errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
+ # The request has been finished. There is no body to read, and
+ # attempts to read one will fail.
+ self.assertRaises(
+ errors.ReadingCompleted, smart_protocol.read_body_bytes)
+
+ def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
+ """read_response_tuple raises UnknownSmartMethod if the response says
+ the server did not recognise the request.
+ """
+ server_bytes = (
+ "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
+ self._test_client_read_response_tuple_raises_UnknownSmartMethod(
+ server_bytes)
+
+ def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self):
+ """read_response_tuple also raises UnknownSmartMethod if the response
+ from a bzr 0.11 says the server did not recognise the request.
+
+ (bzr 0.11 sends a slightly different error message to later versions.)
+ """
+ server_bytes = (
+ "error\x01Generic bzr smart protocol error: bad request u'foo'\n")
+ self._test_client_read_response_tuple_raises_UnknownSmartMethod(
+ server_bytes)
+
+ def test_client_read_body_bytes_all(self):
+ # read_body_bytes should decode the body bytes from the wire into
+ # a response.
+ expected_bytes = "1234567"
+ server_bytes = "ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
+
+ def test_client_read_body_bytes_incremental(self):
+ # test reading a few bytes at a time from the body
+ # XXX: possibly we should test dribbling the bytes into the stringio
+ # to make the state machine work harder: however, as we use the
+ # LengthPrefixedBodyDecoder that is already well tested - we can skip
+ # that.
+ expected_bytes = "1234567"
+ server_bytes = "ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
+
+ def test_client_cancel_read_body_does_not_eat_body_bytes(self):
+ # cancelling the expected body needs to finish the request, but not
+ # read any more bytes.
+ expected_bytes = "1234567"
+ server_bytes = "ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolOne(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ smart_protocol.cancel_read_body()
+ self.assertEqual(3, input.tell())
+ self.assertRaises(
+ errors.ReadingCompleted, smart_protocol.read_body_bytes)
+
+ def test_client_read_body_bytes_interrupted_connection(self):
+ server_bytes = "ok\n999\nincomplete body"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertRaises(
+ errors.ConnectionReset, smart_protocol.read_body_bytes)
+
+
+class TestVersionOneFeaturesInProtocolTwo(
+ TestSmartProtocol, CommonSmartProtocolTestMixin):
+ """Tests for version one smart protocol features as implemeted by version
+ two.
+ """
+
+ client_protocol_class = protocol.SmartClientRequestProtocolTwo
+ server_protocol_class = protocol.SmartServerRequestProtocolTwo
+
+ def test_construct_version_two_server_protocol(self):
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
+ self.assertEqual('', smart_protocol.unused_data)
+ self.assertEqual('', smart_protocol.in_buffer)
+ self.assertFalse(smart_protocol._has_dispatched)
+ self.assertEqual(1, smart_protocol.next_read_size())
+
+ def test_construct_version_two_client_protocol(self):
+ # we can construct a client protocol from a client medium request
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ None, output, 'base')
+ request = client_medium.get_request()
+ client_protocol = protocol.SmartClientRequestProtocolTwo(request)
+
+ def test_accept_bytes_of_bad_request_to_protocol(self):
+ out_stream = StringIO()
+ smart_protocol = self.server_protocol_class(None, out_stream.write)
+ smart_protocol.accept_bytes('abc')
+ self.assertEqual('abc', smart_protocol.in_buffer)
+ smart_protocol.accept_bytes('\n')
+ self.assertEqual(
+ self.response_marker +
+ "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
+ out_stream.getvalue())
+ self.assertTrue(smart_protocol._has_dispatched)
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test_accept_body_bytes_to_protocol(self):
+ protocol = self.build_protocol_waiting_for_body()
+ self.assertEqual(6, protocol.next_read_size())
+ protocol.accept_bytes('7\nabc')
+ self.assertEqual(9, protocol.next_read_size())
+ protocol.accept_bytes('defgd')
+ protocol.accept_bytes('one\n')
+ self.assertEqual(0, protocol.next_read_size())
+ self.assertTrue(self.end_received)
+
+ def test_accept_request_and_body_all_at_once(self):
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+ mem_transport = memory.MemoryTransport()
+ mem_transport.put_bytes('foo', 'abcdefghij')
+ out_stream = StringIO()
+ smart_protocol = self.server_protocol_class(
+ mem_transport, out_stream.write)
+ smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
+ self.assertEqual(0, smart_protocol.next_read_size())
+ self.assertEqual(self.response_marker +
+ 'success\nreadv\n3\ndefdone\n',
+ out_stream.getvalue())
+ self.assertEqual('', smart_protocol.unused_data)
+ self.assertEqual('', smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_are_preserved(self):
+ out_stream = StringIO()
+ smart_protocol = self.server_protocol_class(None, out_stream.write)
+ smart_protocol.accept_bytes('hello\nhello\n')
+ self.assertEqual(self.response_marker + "success\nok\x012\n",
+ out_stream.getvalue())
+ self.assertEqual("hello\n", smart_protocol.unused_data)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_body(self):
+ # The excess bytes look like the start of another request.
+ server_protocol = self.build_protocol_waiting_for_body()
+ server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
+ self.assertTrue(self.end_received)
+ self.assertEqual(self.response_marker,
+ server_protocol.unused_data)
+ self.assertEqual("", server_protocol.in_buffer)
+ server_protocol.accept_bytes('Y')
+ self.assertEqual(self.response_marker + "Y",
+ server_protocol.unused_data)
+ self.assertEqual("", server_protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_dispatch(self):
+ out_stream = StringIO()
+ smart_protocol = self.server_protocol_class(None, out_stream.write)
+ smart_protocol.accept_bytes('hello\n')
+ self.assertEqual(self.response_marker + "success\nok\x012\n",
+ out_stream.getvalue())
+ smart_protocol.accept_bytes(self.request_marker + 'hel')
+ self.assertEqual(self.request_marker + "hel",
+ smart_protocol.unused_data)
+ smart_protocol.accept_bytes('lo\n')
+ self.assertEqual(self.request_marker + "hello\n",
+ smart_protocol.unused_data)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test__send_response_sets_finished_reading(self):
+ smart_protocol = self.server_protocol_class(None, lambda x: None)
+ self.assertEqual(1, smart_protocol.next_read_size())
+ smart_protocol._send_response(
+ _mod_request.SuccessfulSmartServerResponse(('x',)))
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test__send_response_errors_with_base_response(self):
+ """Ensure that only the Successful/Failed subclasses are used."""
+ smart_protocol = self.server_protocol_class(None, lambda x: None)
+ self.assertRaises(AttributeError, smart_protocol._send_response,
+ _mod_request.SmartServerResponse(('x',)))
+
+ def test_query_version(self):
+ """query_version on a SmartClientProtocolTwo should return a number.
+
+ The protocol provides the query_version because the domain level clients
+ may all need to be able to probe for capabilities.
+ """
+ # What we really want to test here is that SmartClientProtocolTwo calls
+ # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
+ # response of tuple-encoded (ok, 1). Also, separately we should test
+ # the error if the response is a non-understood version.
+ input = StringIO(self.response_marker + 'success\nok\x012\n')
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ self.assertEqual(2, smart_protocol.query_version())
+
+ def test_client_call_empty_response(self):
+ # protocol.call() can get back an empty tuple as a response. This occurs
+ # when the parsed line is an empty line, and results in a tuple with
+ # one element - an empty string.
+ self.assertServerToClientEncoding(
+ self.response_marker + 'success\n\n', ('', ), [(), ('', )])
+
+ def test_client_call_three_element_response(self):
+ # protocol.call() can get back tuples of other lengths. A three element
+ # tuple should be unpacked as three strings.
+ self.assertServerToClientEncoding(
+ self.response_marker + 'success\na\x01b\x0134\n',
+ ('a', 'b', '34'),
+ [('a', 'b', '34')])
+
+ def test_client_call_with_body_bytes_uploads(self):
+ # protocol.call_with_body_bytes should length-prefix the bytes onto the
+ # wire.
+ expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def test_client_call_with_body_readv_array(self):
+ # protocol.call_with_upload should encode the readv array and then
+ # length-prefix the bytes onto the wire.
+ expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def test_client_read_body_bytes_all(self):
+ # read_body_bytes should decode the body bytes from the wire into
+ # a response.
+ expected_bytes = "1234567"
+ server_bytes = (self.response_marker +
+ "success\nok\n7\n1234567done\n")
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
+
+ def test_client_read_body_bytes_incremental(self):
+ # test reading a few bytes at a time from the body
+ # XXX: possibly we should test dribbling the bytes into the stringio
+ # to make the state machine work harder: however, as we use the
+ # LengthPrefixedBodyDecoder that is already well tested - we can skip
+ # that.
+ expected_bytes = "1234567"
+ server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
+
+ def test_client_cancel_read_body_does_not_eat_body_bytes(self):
+ # cancelling the expected body needs to finish the request, but not
+ # read any more bytes.
+ server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ smart_protocol.cancel_read_body()
+ self.assertEqual(len(self.response_marker + 'success\nok\n'),
+ input.tell())
+ self.assertRaises(
+ errors.ReadingCompleted, smart_protocol.read_body_bytes)
+
+ def test_client_read_body_bytes_interrupted_connection(self):
+ server_bytes = (self.response_marker +
+ "success\nok\n999\nincomplete body")
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = self.client_protocol_class(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertRaises(
+ errors.ConnectionReset, smart_protocol.read_body_bytes)
+
+
+class TestSmartProtocolTwoSpecificsMixin(object):
+
+ def assertBodyStreamSerialisation(self, expected_serialisation,
+ body_stream):
+ """Assert that body_stream is serialised as expected_serialisation."""
+ out_stream = StringIO()
+ protocol._send_stream(body_stream, out_stream.write)
+ self.assertEqual(expected_serialisation, out_stream.getvalue())
+
+ def assertBodyStreamRoundTrips(self, body_stream):
+ """Assert that body_stream is the same after being serialised and
+ deserialised.
+ """
+ out_stream = StringIO()
+ protocol._send_stream(body_stream, out_stream.write)
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes(out_stream.getvalue())
+ decoded_stream = list(iter(decoder.read_next_chunk, None))
+ self.assertEqual(body_stream, decoded_stream)
+
+ def test_body_stream_serialisation_empty(self):
+ """A body_stream with no bytes can be serialised."""
+ self.assertBodyStreamSerialisation('chunked\nEND\n', [])
+ self.assertBodyStreamRoundTrips([])
+
+ def test_body_stream_serialisation(self):
+ stream = ['chunk one', 'chunk two', 'chunk three']
+ self.assertBodyStreamSerialisation(
+ 'chunked\n' + '9\nchunk one' + '9\nchunk two' + 'b\nchunk three' +
+ 'END\n',
+ stream)
+ self.assertBodyStreamRoundTrips(stream)
+
+ def test_body_stream_with_empty_element_serialisation(self):
+ """A body stream can include ''.
+
+ The empty string can be transmitted like any other string.
+ """
+ stream = ['', 'chunk']
+ self.assertBodyStreamSerialisation(
+ 'chunked\n' + '0\n' + '5\nchunk' + 'END\n', stream)
+ self.assertBodyStreamRoundTrips(stream)
+
+ def test_body_stream_error_serialistion(self):
+ stream = ['first chunk',
+ _mod_request.FailedSmartServerResponse(
+ ('FailureName', 'failure arg'))]
+ expected_bytes = (
+ 'chunked\n' + 'b\nfirst chunk' +
+ 'ERR\n' + 'b\nFailureName' + 'b\nfailure arg' +
+ 'END\n')
+ self.assertBodyStreamSerialisation(expected_bytes, stream)
+ self.assertBodyStreamRoundTrips(stream)
+
+ def test__send_response_includes_failure_marker(self):
+ """FailedSmartServerResponse have 'failed\n' after the version."""
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, out_stream.write)
+ smart_protocol._send_response(
+ _mod_request.FailedSmartServerResponse(('x',)))
+ self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
+ out_stream.getvalue())
+
+ def test__send_response_includes_success_marker(self):
+ """SuccessfulSmartServerResponse have 'success\n' after the version."""
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, out_stream.write)
+ smart_protocol._send_response(
+ _mod_request.SuccessfulSmartServerResponse(('x',)))
+ self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
+ out_stream.getvalue())
+
+ def test__send_response_with_body_stream_sets_finished_reading(self):
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, lambda x: None)
+ self.assertEqual(1, smart_protocol.next_read_size())
+ smart_protocol._send_response(
+ _mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test_streamed_body_bytes(self):
+ body_header = 'chunked\n'
+ two_body_chunks = "4\n1234" + "3\n567"
+ body_terminator = "END\n"
+ server_bytes = (protocol.RESPONSE_VERSION_TWO +
+ "success\nok\n" + body_header + two_body_chunks +
+ body_terminator)
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ stream = smart_protocol.read_streamed_body()
+ self.assertEqual(['1234', '567'], list(stream))
+
+ def test_read_streamed_body_error(self):
+ """When a stream is interrupted by an error..."""
+ body_header = 'chunked\n'
+ a_body_chunk = '4\naaaa'
+ err_signal = 'ERR\n'
+ err_chunks = 'a\nerror arg1' + '4\narg2'
+ finish = 'END\n'
+ body = body_header + a_body_chunk + err_signal + err_chunks + finish
+ server_bytes = (protocol.RESPONSE_VERSION_TWO +
+ "success\nok\n" + body)
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ smart_request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ expected_chunks = [
+ 'aaaa',
+ _mod_request.FailedSmartServerResponse(('error arg1', 'arg2'))]
+ stream = smart_protocol.read_streamed_body()
+ self.assertEqual(expected_chunks, list(stream))
+
+ def test_streamed_body_bytes_interrupted_connection(self):
+ body_header = 'chunked\n'
+ incomplete_body_chunk = "9999\nincomplete chunk"
+ server_bytes = (protocol.RESPONSE_VERSION_TWO +
+ "success\nok\n" + body_header + incomplete_body_chunk)
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ stream = smart_protocol.read_streamed_body()
+ self.assertRaises(errors.ConnectionReset, stream.next)
+
+ def test_client_read_response_tuple_sets_response_status(self):
+ server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(False)
+ self.assertEqual(True, smart_protocol.response_status)
+
+ def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
+ """read_response_tuple raises UnknownSmartMethod if the response says
+ the server did not recognise the request.
+ """
+ server_bytes = (
+ protocol.RESPONSE_VERSION_TWO +
+ "failed\n" +
+ "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'base')
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ self.assertRaises(
+ errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
+ # The request has been finished. There is no body to read, and
+ # attempts to read one will fail.
+ self.assertRaises(
+ errors.ReadingCompleted, smart_protocol.read_body_bytes)
+
+
+class TestSmartProtocolTwoSpecifics(
+ TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
+ """Tests for aspects of smart protocol version two that are unique to
+ version two.
+
+ Thus tests involving body streams and success/failure markers belong here.
+ """
+
+ client_protocol_class = protocol.SmartClientRequestProtocolTwo
+ server_protocol_class = protocol.SmartServerRequestProtocolTwo
+
+
+class TestVersionOneFeaturesInProtocolThree(
+ TestSmartProtocol, CommonSmartProtocolTestMixin):
+ """Tests for version one smart protocol features as implemented by version
+ three.
+ """
+
+ request_encoder = protocol.ProtocolThreeRequester
+ response_decoder = protocol.ProtocolThreeDecoder
+ # build_server_protocol_three is a function, so we can't set it as a class
+ # attribute directly, because then Python will assume it is actually a
+ # method. So we make server_protocol_class be a static method, rather than
+ # simply doing:
+ # "server_protocol_class = protocol.build_server_protocol_three".
+ server_protocol_class = staticmethod(protocol.build_server_protocol_three)
+
+ def setUp(self):
+ super(TestVersionOneFeaturesInProtocolThree, self).setUp()
+ self.response_marker = protocol.MESSAGE_VERSION_THREE
+ self.request_marker = protocol.MESSAGE_VERSION_THREE
+
+ def test_construct_version_three_server_protocol(self):
+ smart_protocol = protocol.ProtocolThreeDecoder(None)
+ self.assertEqual('', smart_protocol.unused_data)
+ self.assertEqual([], smart_protocol._in_buffer_list)
+ self.assertEqual(0, smart_protocol._in_buffer_len)
+ self.assertFalse(smart_protocol._has_dispatched)
+ # The protocol starts by expecting four bytes, a length prefix for the
+ # headers.
+ self.assertEqual(4, smart_protocol.next_read_size())
+
+
+class LoggingMessageHandler(object):
+
+ def __init__(self):
+ self.event_log = []
+
+ def _log(self, *args):
+ self.event_log.append(args)
+
+ def headers_received(self, headers):
+ self._log('headers', headers)
+
+ def protocol_error(self, exception):
+ self._log('protocol_error', exception)
+
+ def byte_part_received(self, byte):
+ self._log('byte', byte)
+
+ def bytes_part_received(self, bytes):
+ self._log('bytes', bytes)
+
+ def structure_part_received(self, structure):
+ self._log('structure', structure)
+
+ def end_received(self):
+ self._log('end')
+
+
+class TestProtocolThree(TestSmartProtocol):
+ """Tests for v3 of the server-side protocol."""
+
+ request_encoder = protocol.ProtocolThreeRequester
+ response_decoder = protocol.ProtocolThreeDecoder
+ server_protocol_class = protocol.ProtocolThreeDecoder
+
+ def test_trivial_request(self):
+ """Smoke test for the simplest possible v3 request: empty headers, no
+ message parts.
+ """
+ output = StringIO()
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ end = 'e'
+ request_bytes = headers + end
+ smart_protocol = self.server_protocol_class(LoggingMessageHandler())
+ smart_protocol.accept_bytes(request_bytes)
+ self.assertEqual(0, smart_protocol.next_read_size())
+ self.assertEqual('', smart_protocol.unused_data)
+
+ def test_repeated_excess(self):
+ """Repeated calls to accept_bytes after the message end has been parsed
+ accumlates the bytes in the unused_data attribute.
+ """
+ output = StringIO()
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ end = 'e'
+ request_bytes = headers + end
+ smart_protocol = self.server_protocol_class(LoggingMessageHandler())
+ smart_protocol.accept_bytes(request_bytes)
+ self.assertEqual('', smart_protocol.unused_data)
+ smart_protocol.accept_bytes('aaa')
+ self.assertEqual('aaa', smart_protocol.unused_data)
+ smart_protocol.accept_bytes('bbb')
+ self.assertEqual('aaabbb', smart_protocol.unused_data)
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def make_protocol_expecting_message_part(self):
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ message_handler = LoggingMessageHandler()
+ smart_protocol = self.server_protocol_class(message_handler)
+ smart_protocol.accept_bytes(headers)
+ # Clear the event log
+ del message_handler.event_log[:]
+ return smart_protocol, message_handler.event_log
+
+ def test_decode_one_byte(self):
+ """The protocol can decode a 'one byte' message part."""
+ smart_protocol, event_log = self.make_protocol_expecting_message_part()
+ smart_protocol.accept_bytes('ox')
+ self.assertEqual([('byte', 'x')], event_log)
+
+ def test_decode_bytes(self):
+ """The protocol can decode a 'bytes' message part."""
+ smart_protocol, event_log = self.make_protocol_expecting_message_part()
+ smart_protocol.accept_bytes(
+ 'b' # message part kind
+ '\0\0\0\x07' # length prefix
+ 'payload' # payload
+ )
+ self.assertEqual([('bytes', 'payload')], event_log)
+
+ def test_decode_structure(self):
+ """The protocol can decode a 'structure' message part."""
+ smart_protocol, event_log = self.make_protocol_expecting_message_part()
+ smart_protocol.accept_bytes(
+ 's' # message part kind
+ '\0\0\0\x07' # length prefix
+ 'l3:ARGe' # ['ARG']
+ )
+ self.assertEqual([('structure', ('ARG',))], event_log)
+
+ def test_decode_multiple_bytes(self):
+ """The protocol can decode a multiple 'bytes' message parts."""
+ smart_protocol, event_log = self.make_protocol_expecting_message_part()
+ smart_protocol.accept_bytes(
+ 'b' # message part kind
+ '\0\0\0\x05' # length prefix
+ 'first' # payload
+ 'b' # message part kind
+ '\0\0\0\x06'
+ 'second'
+ )
+ self.assertEqual(
+ [('bytes', 'first'), ('bytes', 'second')], event_log)
+
+
+class TestConventionalResponseHandlerBodyStream(tests.TestCase):
+
+ def make_response_handler(self, response_bytes):
+ from bzrlib.smart.message import ConventionalResponseHandler
+ response_handler = ConventionalResponseHandler()
+ protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
+ # put decoder in desired state (waiting for message parts)
+ protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ StringIO(response_bytes), output, 'base')
+ medium_request = client_medium.get_request()
+ medium_request.finished_writing()
+ response_handler.setProtoAndMediumRequest(
+ protocol_decoder, medium_request)
+ return response_handler
+
+ def test_interrupted_by_error(self):
+ response_handler = self.make_response_handler(interrupted_body_stream)
+ stream = response_handler.read_streamed_body()
+ self.assertEqual('aaa', stream.next())
+ self.assertEqual('bbb', stream.next())
+ exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
+ self.assertEqual(('error', 'Exception', 'Boom!'), exc.error_tuple)
+
+ def test_interrupted_by_connection_lost(self):
+ interrupted_body_stream = (
+ 'oS' # successful response
+ 's\0\0\0\x02le' # empty args
+ 'b\0\0\xff\xffincomplete chunk')
+ response_handler = self.make_response_handler(interrupted_body_stream)
+ stream = response_handler.read_streamed_body()
+ self.assertRaises(errors.ConnectionReset, stream.next)
+
+ def test_read_body_bytes_interrupted_by_connection_lost(self):
+ interrupted_body_stream = (
+ 'oS' # successful response
+ 's\0\0\0\x02le' # empty args
+ 'b\0\0\xff\xffincomplete chunk')
+ response_handler = self.make_response_handler(interrupted_body_stream)
+ self.assertRaises(
+ errors.ConnectionReset, response_handler.read_body_bytes)
+
+ def test_multiple_bytes_parts(self):
+ multiple_bytes_parts = (
+ 'oS' # successful response
+ 's\0\0\0\x02le' # empty args
+ 'b\0\0\0\x0bSome bytes\n' # some bytes
+ 'b\0\0\0\x0aMore bytes' # more bytes
+ 'e' # message end
+ )
+ response_handler = self.make_response_handler(multiple_bytes_parts)
+ self.assertEqual(
+ 'Some bytes\nMore bytes', response_handler.read_body_bytes())
+ response_handler = self.make_response_handler(multiple_bytes_parts)
+ self.assertEqual(
+ ['Some bytes\n', 'More bytes'],
+ list(response_handler.read_streamed_body()))
+
+
+class FakeResponder(object):
+
+ response_sent = False
+
+ def send_error(self, exc):
+ raise exc
+
+ def send_response(self, response):
+ pass
+
+
+class TestConventionalRequestHandlerBodyStream(tests.TestCase):
+ """Tests for ConventionalRequestHandler's handling of request bodies."""
+
+ def make_request_handler(self, request_bytes):
+ """Make a ConventionalRequestHandler for the given bytes using test
+ doubles for the request_handler and the responder.
+ """
+ from bzrlib.smart.message import ConventionalRequestHandler
+ request_handler = InstrumentedRequestHandler()
+ request_handler.response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
+ responder = FakeResponder()
+ message_handler = ConventionalRequestHandler(request_handler, responder)
+ protocol_decoder = protocol.ProtocolThreeDecoder(message_handler)
+ # put decoder in desired state (waiting for message parts)
+ protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
+ protocol_decoder.accept_bytes(request_bytes)
+ return request_handler
+
+ def test_multiple_bytes_parts(self):
+ """Each bytes part triggers a call to the request_handler's
+ accept_body method.
+ """
+ multiple_bytes_parts = (
+ 's\0\0\0\x07l3:fooe' # args
+ 'b\0\0\0\x0bSome bytes\n' # some bytes
+ 'b\0\0\0\x0aMore bytes' # more bytes
+ 'e' # message end
+ )
+ request_handler = self.make_request_handler(multiple_bytes_parts)
+ accept_body_calls = [
+ call_info[1] for call_info in request_handler.calls
+ if call_info[0] == 'accept_body']
+ self.assertEqual(
+ ['Some bytes\n', 'More bytes'], accept_body_calls)
+
+ def test_error_flag_after_body(self):
+ body_then_error = (
+ 's\0\0\0\x07l3:fooe' # request args
+ 'b\0\0\0\x0bSome bytes\n' # some bytes
+ 'b\0\0\0\x0aMore bytes' # more bytes
+ 'oE' # error flag
+ 's\0\0\0\x07l3:bare' # error args
+ 'e' # message end
+ )
+ request_handler = self.make_request_handler(body_then_error)
+ self.assertEqual(
+ [('post_body_error_received', ('bar',)), ('end_received',)],
+ request_handler.calls[-2:])
+
+
+class TestMessageHandlerErrors(tests.TestCase):
+ """Tests for v3 that unrecognised (but well-formed) requests/responses are
+ still fully read off the wire, so that subsequent requests/responses on the
+ same medium can be decoded.
+ """
+
+ def test_non_conventional_request(self):
+ """ConventionalRequestHandler (the default message handler on the
+ server side) will reject an unconventional message, but still consume
+ all the bytes of that message and signal when it has done so.
+
+ This is what allows a server to continue to accept requests after the
+ client sends a completely unrecognised request.
+ """
+ # Define an invalid request (but one that is a well-formed message).
+ # This particular invalid request not only lacks the mandatory
+ # verb+args tuple, it has a single-byte part, which is forbidden. In
+ # fact it has that part twice, to trigger multiple errors.
+ invalid_request = (
+ protocol.MESSAGE_VERSION_THREE + # protocol version marker
+ '\0\0\0\x02de' + # empty headers
+ 'oX' + # a single byte part: 'X'. ConventionalRequestHandler will
+ # error at this part.
+ 'oX' + # and again.
+ 'e' # end of message
+ )
+
+ to_server = StringIO(invalid_request)
+ from_server = StringIO()
+ transport = memory.MemoryTransport('memory:///')
+ server = medium.SmartServerPipeStreamMedium(
+ to_server, from_server, transport, timeout=4.0)
+ proto = server._build_protocol()
+ message_handler = proto.message_handler
+ server._serve_one_request(proto)
+ # All the bytes have been read from the medium...
+ self.assertEqual('', to_server.read())
+ # ...and the protocol decoder has consumed all the bytes, and has
+ # finished reading.
+ self.assertEqual('', proto.unused_data)
+ self.assertEqual(0, proto.next_read_size())
+
+
+class InstrumentedRequestHandler(object):
+ """Test Double of SmartServerRequestHandler."""
+
+ def __init__(self):
+ self.calls = []
+ self.finished_reading = False
+
+ def no_body_received(self):
+ self.calls.append(('no_body_received',))
+
+ def end_received(self):
+ self.calls.append(('end_received',))
+ self.finished_reading = True
+
+ def args_received(self, args):
+ self.calls.append(('args_received', args))
+
+ def accept_body(self, bytes):
+ self.calls.append(('accept_body', bytes))
+
+ def end_of_body(self):
+ self.calls.append(('end_of_body',))
+ self.finished_reading = True
+
+ def post_body_error_received(self, error_args):
+ self.calls.append(('post_body_error_received', error_args))
+
+
+class StubRequest(object):
+
+ def finished_reading(self):
+ pass
+
+
+class TestClientDecodingProtocolThree(TestSmartProtocol):
+ """Tests for v3 of the client-side protocol decoding."""
+
+ def make_logging_response_decoder(self):
+ """Make v3 response decoder using a test response handler."""
+ response_handler = LoggingMessageHandler()
+ decoder = protocol.ProtocolThreeDecoder(response_handler)
+ return decoder, response_handler
+
+ def make_conventional_response_decoder(self):
+ """Make v3 response decoder using a conventional response handler."""
+ response_handler = message.ConventionalResponseHandler()
+ decoder = protocol.ProtocolThreeDecoder(response_handler)
+ response_handler.setProtoAndMediumRequest(decoder, StubRequest())
+ return decoder, response_handler
+
+ def test_trivial_response_decoding(self):
+ """Smoke test for the simplest possible v3 response: empty headers,
+ status byte, empty args, no body.
+ """
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ response_status = 'oS' # success
+ args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
+ end = 'e' # end marker
+ message_bytes = headers + response_status + args + end
+ decoder, response_handler = self.make_logging_response_decoder()
+ decoder.accept_bytes(message_bytes)
+ # The protocol decoder has finished, and consumed all bytes
+ self.assertEqual(0, decoder.next_read_size())
+ self.assertEqual('', decoder.unused_data)
+ # The message handler has been invoked with all the parts of the
+ # trivial response: empty headers, status byte, no args, end.
+ self.assertEqual(
+ [('headers', {}), ('byte', 'S'), ('structure', ()), ('end',)],
+ response_handler.event_log)
+
+ def test_incomplete_message(self):
+ """A decoder will keep signalling that it needs more bytes via
+ next_read_size() != 0 until it has seen a complete message, regardless
+ which state it is in.
+ """
+ # Define a simple response that uses all possible message parts.
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ response_status = 'oS' # success
+ args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
+ body = 'b\0\0\0\x04BODY' # a body: 'BODY'
+ end = 'e' # end marker
+ simple_response = headers + response_status + args + body + end
+ # Feed the request to the decoder one byte at a time.
+ decoder, response_handler = self.make_logging_response_decoder()
+ for byte in simple_response:
+ self.assertNotEqual(0, decoder.next_read_size())
+ decoder.accept_bytes(byte)
+ # Now the response is complete
+ self.assertEqual(0, decoder.next_read_size())
+
+ def test_read_response_tuple_raises_UnknownSmartMethod(self):
+ """read_response_tuple raises UnknownSmartMethod if the server replied
+ with 'UnknownMethod'.
+ """
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ response_status = 'oE' # error flag
+ # args: ('UnknownMethod', 'method-name')
+ args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
+ end = 'e' # end marker
+ message_bytes = headers + response_status + args + end
+ decoder, response_handler = self.make_conventional_response_decoder()
+ decoder.accept_bytes(message_bytes)
+ error = self.assertRaises(
+ errors.UnknownSmartMethod, response_handler.read_response_tuple)
+ self.assertEqual('method-name', error.verb)
+
+ def test_read_response_tuple_error(self):
+ """If the response has an error, it is raised as an exception."""
+ headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
+ response_status = 'oE' # error
+ args = 's\0\0\0\x1al9:first arg10:second arge' # two args
+ end = 'e' # end marker
+ message_bytes = headers + response_status + args + end
+ decoder, response_handler = self.make_conventional_response_decoder()
+ decoder.accept_bytes(message_bytes)
+ error = self.assertRaises(
+ errors.ErrorFromSmartServer, response_handler.read_response_tuple)
+ self.assertEqual(('first arg', 'second arg'), error.error_tuple)
+
+
+class TestClientEncodingProtocolThree(TestSmartProtocol):
+
+ request_encoder = protocol.ProtocolThreeRequester
+ response_decoder = protocol.ProtocolThreeDecoder
+ server_protocol_class = protocol.ProtocolThreeDecoder
+
+ def make_client_encoder_and_output(self):
+ result = self.make_client_protocol_and_output()
+ requester, response_handler, output = result
+ return requester, output
+
+ def test_call_smoke_test(self):
+ """A smoke test for ProtocolThreeRequester.call.
+
+ This test checks that a particular simple invocation of call emits the
+ correct bytes for that invocation.
+ """
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({'header name': 'header value'})
+ requester.call('one arg')
+ self.assertEquals(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ 'e', # end
+ output.getvalue())
+
+ def test_call_with_body_bytes_smoke_test(self):
+ """A smoke test for ProtocolThreeRequester.call_with_body_bytes.
+
+ This test checks that a particular simple invocation of
+ call_with_body_bytes emits the correct bytes for that invocation.
+ """
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({'header name': 'header value'})
+ requester.call_with_body_bytes(('one arg',), 'body bytes')
+ self.assertEquals(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ 'b' # there is a prefixed body
+ '\x00\x00\x00\nbody bytes' # the prefixed body
+ 'e', # end
+ output.getvalue())
+
+ def test_call_writes_just_once(self):
+ """A bodyless request is written to the medium all at once."""
+ medium_request = StubMediumRequest()
+ encoder = protocol.ProtocolThreeRequester(medium_request)
+ encoder.call('arg1', 'arg2', 'arg3')
+ self.assertEqual(
+ ['accept_bytes', 'finished_writing'], medium_request.calls)
+
+ def test_call_with_body_bytes_writes_just_once(self):
+ """A request with body bytes is written to the medium all at once."""
+ medium_request = StubMediumRequest()
+ encoder = protocol.ProtocolThreeRequester(medium_request)
+ encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
+ self.assertEqual(
+ ['accept_bytes', 'finished_writing'], medium_request.calls)
+
+ def test_call_with_body_stream_smoke_test(self):
+ """A smoke test for ProtocolThreeRequester.call_with_body_stream.
+
+ This test checks that a particular simple invocation of
+ call_with_body_stream emits the correct bytes for that invocation.
+ """
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({'header name': 'header value'})
+ stream = ['chunk 1', 'chunk two']
+ requester.call_with_body_stream(('one arg',), stream)
+ self.assertEquals(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ 'b\x00\x00\x00\x07chunk 1' # a prefixed body chunk
+ 'b\x00\x00\x00\x09chunk two' # a prefixed body chunk
+ 'e', # end
+ output.getvalue())
+
+ def test_call_with_body_stream_empty_stream(self):
+ """call_with_body_stream with an empty stream."""
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({})
+ stream = []
+ requester.call_with_body_stream(('one arg',), stream)
+ self.assertEquals(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x02de' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ # no body chunks
+ 'e', # end
+ output.getvalue())
+
+ def test_call_with_body_stream_error(self):
+ """call_with_body_stream will abort the streamed body with an
+ error if the stream raises an error during iteration.
+
+ The resulting request will still be a complete message.
+ """
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({})
+ def stream_that_fails():
+ yield 'aaa'
+ yield 'bbb'
+ raise Exception('Boom!')
+ self.assertRaises(Exception, requester.call_with_body_stream,
+ ('one arg',), stream_that_fails())
+ self.assertEquals(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x02de' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ 'b\x00\x00\x00\x03aaa' # body
+ 'b\x00\x00\x00\x03bbb' # more body
+ 'oE' # error flag
+ 's\x00\x00\x00\x09l5:errore' # error args: ('error',)
+ 'e', # end
+ output.getvalue())
+
+ def test_records_start_of_body_stream(self):
+ requester, output = self.make_client_encoder_and_output()
+ requester.set_headers({})
+ in_stream = [False]
+ def stream_checker():
+ self.assertTrue(requester.body_stream_started)
+ in_stream[0] = True
+ yield 'content'
+ flush_called = []
+ orig_flush = requester.flush
+ def tracked_flush():
+ flush_called.append(in_stream[0])
+ if in_stream[0]:
+ self.assertTrue(requester.body_stream_started)
+ else:
+ self.assertFalse(requester.body_stream_started)
+ return orig_flush()
+ requester.flush = tracked_flush
+ requester.call_with_body_stream(('one arg',), stream_checker())
+ self.assertEqual(
+ 'bzr message 3 (bzr 1.6)\n' # protocol version
+ '\x00\x00\x00\x02de' # headers
+ 's\x00\x00\x00\x0bl7:one arge' # args
+ 'b\x00\x00\x00\x07content' # body
+ 'e', output.getvalue())
+ self.assertEqual([False, True, True], flush_called)
+
+
+class StubMediumRequest(object):
+ """A stub medium request that tracks the number of times accept_bytes is
+ called.
+ """
+
+ def __init__(self):
+ self.calls = []
+ self._medium = 'dummy medium'
+
+ def accept_bytes(self, bytes):
+ self.calls.append('accept_bytes')
+
+ def finished_writing(self):
+ self.calls.append('finished_writing')
+
+
+interrupted_body_stream = (
+ 'oS' # status flag (success)
+ 's\x00\x00\x00\x08l4:argse' # args struct ('args,')
+ 'b\x00\x00\x00\x03aaa' # body part ('aaa')
+ 'b\x00\x00\x00\x03bbb' # body part ('bbb')
+ 'oE' # status flag (error)
+ # err struct ('error', 'Exception', 'Boom!')
+ 's\x00\x00\x00\x1bl5:error9:Exception5:Boom!e'
+ 'e' # EOM
+ )
+
+
+class TestResponseEncodingProtocolThree(tests.TestCase):
+
+ def make_response_encoder(self):
+ out_stream = StringIO()
+ response_encoder = protocol.ProtocolThreeResponder(out_stream.write)
+ return response_encoder, out_stream
+
+ def test_send_error_unknown_method(self):
+ encoder, out_stream = self.make_response_encoder()
+ encoder.send_error(errors.UnknownSmartMethod('method name'))
+ # Use assertEndsWith so that we don't compare the header, which varies
+ # by bzrlib.__version__.
+ self.assertEndsWith(
+ out_stream.getvalue(),
+ # error status
+ 'oE' +
+ # tuple: 'UnknownMethod', 'method name'
+ 's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
+ # end of message
+ 'e')
+
+ def test_send_broken_body_stream(self):
+ encoder, out_stream = self.make_response_encoder()
+ encoder._headers = {}
+ def stream_that_fails():
+ yield 'aaa'
+ yield 'bbb'
+ raise Exception('Boom!')
+ response = _mod_request.SuccessfulSmartServerResponse(
+ ('args',), body_stream=stream_that_fails())
+ encoder.send_response(response)
+ expected_response = (
+ 'bzr message 3 (bzr 1.6)\n' # protocol marker
+ '\x00\x00\x00\x02de' # headers dict (empty)
+ + interrupted_body_stream)
+ self.assertEqual(expected_response, out_stream.getvalue())
+
+
+class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
+ """Tests for buffering of responses.
+
+ We want to avoid doing many small writes when one would do, to avoid
+ unnecessary network overhead.
+ """
+
+ def setUp(self):
+ tests.TestCase.setUp(self)
+ self.writes = []
+ self.responder = protocol.ProtocolThreeResponder(self.writes.append)
+
+ def assertWriteCount(self, expected_count):
+ # self.writes can be quite large; don't show the whole thing
+ self.assertEqual(
+ expected_count, len(self.writes),
+ "Too many writes: %d, expected %d" % (len(self.writes), expected_count))
+
+ def test_send_error_writes_just_once(self):
+ """An error response is written to the medium all at once."""
+ self.responder.send_error(Exception('An exception string.'))
+ self.assertWriteCount(1)
+
+ def test_send_response_writes_just_once(self):
+ """A normal response with no body is written to the medium all at once.
+ """
+ response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
+ self.responder.send_response(response)
+ self.assertWriteCount(1)
+
+ def test_send_response_with_body_writes_just_once(self):
+ """A normal response with a monolithic body is written to the medium
+ all at once.
+ """
+ response = _mod_request.SuccessfulSmartServerResponse(
+ ('arg', 'arg'), body='body bytes')
+ self.responder.send_response(response)
+ self.assertWriteCount(1)
+
+ def test_send_response_with_body_stream_buffers_writes(self):
+ """A normal response with a stream body writes to the medium once."""
+ # Construct a response with stream with 2 chunks in it.
+ response = _mod_request.SuccessfulSmartServerResponse(
+ ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
+ self.responder.send_response(response)
+ # Per the discussion in bug 590638 we flush once after the header and
+ # then once after each chunk
+ self.assertWriteCount(3)
+
+
+class TestSmartClientUnicode(tests.TestCase):
+ """_SmartClient tests for unicode arguments.
+
+ Unicode arguments to call_with_body_bytes are not correct (remote method
+ names, arguments, and bodies must all be expressed as byte strings), but
+ _SmartClient should gracefully reject them, rather than getting into a
+ broken state that prevents future correct calls from working. That is, it
+ should be possible to issue more requests on the medium afterwards, rather
+ than allowing one bad call to call_with_body_bytes to cause later calls to
+ mysteriously fail with TooManyConcurrentRequests.
+ """
+
+ def assertCallDoesNotBreakMedium(self, method, args, body):
+ """Call a medium with the given method, args and body, then assert that
+ the medium is left in a sane state, i.e. is capable of allowing further
+ requests.
+ """
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(
+ input, output, 'ignored base')
+ smart_client = client._SmartClient(client_medium)
+ self.assertRaises(TypeError,
+ smart_client.call_with_body_bytes, method, args, body)
+ self.assertEqual("", output.getvalue())
+ self.assertEqual(None, client_medium._current_request)
+
+ def test_call_with_body_bytes_unicode_method(self):
+ self.assertCallDoesNotBreakMedium(u'method', ('args',), 'body')
+
+ def test_call_with_body_bytes_unicode_args(self):
+ self.assertCallDoesNotBreakMedium('method', (u'args',), 'body')
+ self.assertCallDoesNotBreakMedium('method', ('arg1', u'arg2'), 'body')
+
+ def test_call_with_body_bytes_unicode_body(self):
+ self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
+
+
+class MockMedium(medium.SmartClientMedium):
+ """A mock medium that can be used to test _SmartClient.
+
+ It can be given a series of requests to expect (and responses it should
+ return for them). It can also be told when the client is expected to
+ disconnect a medium. Expectations must be satisfied in the order they are
+ given, or else an AssertionError will be raised.
+
+ Typical use looks like::
+
+ medium = MockMedium()
+ medium.expect_request(...)
+ medium.expect_request(...)
+ medium.expect_request(...)
+ """
+
+ def __init__(self):
+ super(MockMedium, self).__init__('dummy base')
+ self._mock_request = _MockMediumRequest(self)
+ self._expected_events = []
+
+ def expect_request(self, request_bytes, response_bytes,
+ allow_partial_read=False):
+ """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
+
+ No assumption is made about how many times accept_bytes should be
+ called to send the request. Similarly, no assumption is made about how
+ many times read_bytes/read_line are called by protocol code to read a
+ response. e.g.::
+
+ request.accept_bytes('ab')
+ request.accept_bytes('cd')
+ request.finished_writing()
+
+ and::
+
+ request.accept_bytes('abcd')
+ request.finished_writing()
+
+ Will both satisfy ``medium.expect_request('abcd', ...)``. Thus tests
+ using this should not break due to irrelevant changes in protocol
+ implementations.
+
+ :param allow_partial_read: if True, no assertion is raised if a
+ response is not fully read. Setting this is useful when the client
+ is expected to disconnect without needing to read the complete
+ response. Default is False.
+ """
+ self._expected_events.append(('send request', request_bytes))
+ if allow_partial_read:
+ self._expected_events.append(
+ ('read response (partial)', response_bytes))
+ else:
+ self._expected_events.append(('read response', response_bytes))
+
+ def expect_disconnect(self):
+ """Expect the client to call ``medium.disconnect()``."""
+ self._expected_events.append('disconnect')
+
+ def _assertEvent(self, observed_event):
+ """Raise AssertionError unless observed_event matches the next expected
+ event.
+
+ :seealso: expect_request
+ :seealso: expect_disconnect
+ """
+ try:
+ expected_event = self._expected_events.pop(0)
+ except IndexError:
+ raise AssertionError(
+ 'Mock medium observed event %r, but no more events expected'
+ % (observed_event,))
+ if expected_event[0] == 'read response (partial)':
+ if observed_event[0] != 'read response':
+ raise AssertionError(
+ 'Mock medium observed event %r, but expected event %r'
+ % (observed_event, expected_event))
+ elif observed_event != expected_event:
+ raise AssertionError(
+ 'Mock medium observed event %r, but expected event %r'
+ % (observed_event, expected_event))
+ if self._expected_events:
+ next_event = self._expected_events[0]
+ if next_event[0].startswith('read response'):
+ self._mock_request._response = next_event[1]
+
+ def get_request(self):
+ return self._mock_request
+
+ def disconnect(self):
+ if self._mock_request._read_bytes:
+ self._assertEvent(('read response', self._mock_request._read_bytes))
+ self._mock_request._read_bytes = ''
+ self._assertEvent('disconnect')
+
+
+class _MockMediumRequest(object):
+ """A mock ClientMediumRequest used by MockMedium."""
+
+ def __init__(self, mock_medium):
+ self._medium = mock_medium
+ self._written_bytes = ''
+ self._read_bytes = ''
+ self._response = None
+
+ def accept_bytes(self, bytes):
+ self._written_bytes += bytes
+
+ def finished_writing(self):
+ self._medium._assertEvent(('send request', self._written_bytes))
+ self._written_bytes = ''
+
+ def finished_reading(self):
+ self._medium._assertEvent(('read response', self._read_bytes))
+ self._read_bytes = ''
+
+ def read_bytes(self, size):
+ resp = self._response
+ bytes, resp = resp[:size], resp[size:]
+ self._response = resp
+ self._read_bytes += bytes
+ return bytes
+
+ def read_line(self):
+ resp = self._response
+ try:
+ line, resp = resp.split('\n', 1)
+ line += '\n'
+ except ValueError:
+ line, resp = resp, ''
+ self._response = resp
+ self._read_bytes += line
+ return line
+
+
+class Test_SmartClientVersionDetection(tests.TestCase):
+ """Tests for _SmartClient's automatic protocol version detection.
+
+ On the first remote call, _SmartClient will keep retrying the request with
+ different protocol versions until it finds one that works.
+ """
+
+ def test_version_three_server(self):
+ """With a protocol 3 server, only one request is needed."""
+ medium = MockMedium()
+ smart_client = client._SmartClient(medium, headers={})
+ message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
+ medium.expect_request(
+ message_start +
+ 's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
+ message_start + 's\0\0\0\x13l14:response valueee')
+ result = smart_client.call('method-name', 'arg 1', 'arg 2')
+ # The call succeeded without raising any exceptions from the mock
+ # medium, and the smart_client returns the response from the server.
+ self.assertEqual(('response value',), result)
+ self.assertEqual([], medium._expected_events)
+ # Also, the v3 works then the server should be assumed to support RPCs
+ # introduced in 1.6.
+ self.assertFalse(medium._is_remote_before((1, 6)))
+
+ def test_version_two_server(self):
+ """If the server only speaks protocol 2, the client will first try
+ version 3, then fallback to protocol 2.
+
+ Further, _SmartClient caches the detection, so future requests will all
+ use protocol 2 immediately.
+ """
+ medium = MockMedium()
+ smart_client = client._SmartClient(medium, headers={})
+ # First the client should send a v3 request, but the server will reply
+ # with a v2 error.
+ medium.expect_request(
+ 'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
+ 's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
+ 'bzr response 2\nfailed\n\n')
+ # So then the client should disconnect to reset the connection, because
+ # the client needs to assume the server cannot read any further
+ # requests off the original connection.
+ medium.expect_disconnect()
+ # The client should then retry the original request in v2
+ medium.expect_request(
+ 'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
+ 'bzr response 2\nsuccess\nresponse value\n')
+ result = smart_client.call('method-name', 'arg 1', 'arg 2')
+ # The smart_client object will return the result of the successful
+ # query.
+ self.assertEqual(('response value',), result)
+
+ # Now try another request, and this time the client will just use
+ # protocol 2. (i.e. the autodetection won't be repeated)
+ medium.expect_request(
+ 'bzr request 2\nanother-method\n',
+ 'bzr response 2\nsuccess\nanother response\n')
+ result = smart_client.call('another-method')
+ self.assertEqual(('another response',), result)
+ self.assertEqual([], medium._expected_events)
+
+ # Also, because v3 is not supported, the client medium should assume
+ # that RPCs introduced in 1.6 aren't supported either.
+ self.assertTrue(medium._is_remote_before((1, 6)))
+
+ def test_unknown_version(self):
+ """If the server does not use any known (or at least supported)
+ protocol version, a SmartProtocolError is raised.
+ """
+ medium = MockMedium()
+ smart_client = client._SmartClient(medium, headers={})
+ unknown_protocol_bytes = 'Unknown protocol!'
+ # The client will try v3 and v2 before eventually giving up.
+ medium.expect_request(
+ 'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
+ 's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
+ unknown_protocol_bytes)
+ medium.expect_disconnect()
+ medium.expect_request(
+ 'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
+ unknown_protocol_bytes)
+ medium.expect_disconnect()
+ self.assertRaises(
+ errors.SmartProtocolError,
+ smart_client.call, 'method-name', 'arg 1', 'arg 2')
+ self.assertEqual([], medium._expected_events)
+
+ def test_first_response_is_error(self):
+ """If the server replies with an error, then the version detection
+ should be complete.
+
+ This test is very similar to test_version_two_server, but catches a bug
+ we had in the case where the first reply was an error response.
+ """
+ medium = MockMedium()
+ smart_client = client._SmartClient(medium, headers={})
+ message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
+ # Issue a request that gets an error reply in a non-default protocol
+ # version.
+ medium.expect_request(
+ message_start +
+ 's\x00\x00\x00\x10l11:method-nameee',
+ 'bzr response 2\nfailed\n\n')
+ medium.expect_disconnect()
+ medium.expect_request(
+ 'bzr request 2\nmethod-name\n',
+ 'bzr response 2\nfailed\nFooBarError\n')
+ err = self.assertRaises(
+ errors.ErrorFromSmartServer,
+ smart_client.call, 'method-name')
+ self.assertEqual(('FooBarError',), err.error_tuple)
+ # Now the medium should have remembered the protocol version, so
+ # subsequent requests will use the remembered version immediately.
+ medium.expect_request(
+ 'bzr request 2\nmethod-name\n',
+ 'bzr response 2\nsuccess\nresponse value\n')
+ result = smart_client.call('method-name')
+ self.assertEqual(('response value',), result)
+ self.assertEqual([], medium._expected_events)
+
+
+class Test_SmartClient(tests.TestCase):
+
+ def test_call_default_headers(self):
+ """ProtocolThreeRequester.call by default sends a 'Software
+ version' header.
+ """
+ smart_client = client._SmartClient('dummy medium')
+ self.assertEqual(
+ bzrlib.__version__, smart_client._headers['Software version'])
+ # XXX: need a test that smart_client._headers is passed to the request
+ # encoder.
+
+
+class Test_SmartClientRequest(tests.TestCase):
+
+ def make_client_with_failing_medium(self, fail_at_write=True, response=''):
+ response_io = StringIO(response)
+ output = StringIO()
+ vendor = FirstRejectedStringIOSSHVendor(response_io, output,
+ fail_at_write=fail_at_write)
+ ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
+ client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
+ smart_client = client._SmartClient(client_medium, headers={})
+ return output, vendor, smart_client
+
+ def make_response(self, args, body=None, body_stream=None):
+ response_io = StringIO()
+ response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
+ body_stream=body_stream)
+ responder = protocol.ProtocolThreeResponder(response_io.write)
+ responder.send_response(response)
+ return response_io.getvalue()
+
+ def test__call_doesnt_retry_append(self):
+ response = self.make_response(('appended', '8'))
+ output, vendor, smart_client = self.make_client_with_failing_medium(
+ fail_at_write=False, response=response)
+ smart_request = client._SmartClientRequest(smart_client, 'append',
+ ('foo', ''), body='content\n')
+ self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
+
+ def test__call_retries_get_bytes(self):
+ response = self.make_response(('ok',), 'content\n')
+ output, vendor, smart_client = self.make_client_with_failing_medium(
+ fail_at_write=False, response=response)
+ smart_request = client._SmartClientRequest(smart_client, 'get',
+ ('foo',))
+ response, response_handler = smart_request._call(3)
+ self.assertEqual(('ok',), response)
+ self.assertEqual('content\n', response_handler.read_body_bytes())
+
+ def test__call_noretry_get_bytes(self):
+ debug.debug_flags.add('noretry')
+ response = self.make_response(('ok',), 'content\n')
+ output, vendor, smart_client = self.make_client_with_failing_medium(
+ fail_at_write=False, response=response)
+ smart_request = client._SmartClientRequest(smart_client, 'get',
+ ('foo',))
+ self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
+
+ def test__send_no_retry_pipes(self):
+ client_read, server_write = create_file_pipes()
+ server_read, client_write = create_file_pipes()
+ client_medium = medium.SmartSimplePipesClientMedium(client_read,
+ client_write, base='/')
+ smart_client = client._SmartClient(client_medium)
+ smart_request = client._SmartClientRequest(smart_client,
+ 'hello', ())
+ # Close the server side
+ server_read.close()
+ encoder, response_handler = smart_request._construct_protocol(3)
+ self.assertRaises(errors.ConnectionReset,
+ smart_request._send_no_retry, encoder)
+
+ def test__send_read_response_sockets(self):
+ listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(1)
+ host, port = listen_sock.getsockname()
+ client_medium = medium.SmartTCPClientMedium(host, port, '/')
+ client_medium._ensure_connection()
+ smart_client = client._SmartClient(client_medium)
+ smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+ # Accept the connection, but don't actually talk to the client.
+ server_sock, _ = listen_sock.accept()
+ server_sock.close()
+ # Sockets buffer and don't really notice that the server has closed the
+ # connection until we try to read again.
+ handler = smart_request._send(3)
+ self.assertRaises(errors.ConnectionReset,
+ handler.read_response_tuple, expect_body=False)
+
+ def test__send_retries_on_write(self):
+ output, vendor, smart_client = self.make_client_with_failing_medium()
+ smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+ handler = smart_request._send(3)
+ self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+ '\x00\x00\x00\x02de' # empty headers
+ 's\x00\x00\x00\tl5:helloee',
+ output.getvalue())
+ self.assertEqual(
+ [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close',),
+ ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ],
+ vendor.calls)
+
+ def test__send_doesnt_retry_read_failure(self):
+ output, vendor, smart_client = self.make_client_with_failing_medium(
+ fail_at_write=False)
+ smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+ handler = smart_request._send(3)
+ self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+ '\x00\x00\x00\x02de' # empty headers
+ 's\x00\x00\x00\tl5:helloee',
+ output.getvalue())
+ self.assertEqual(
+ [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ],
+ vendor.calls)
+ self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
+
+ def test__send_request_retries_body_stream_if_not_started(self):
+ output, vendor, smart_client = self.make_client_with_failing_medium()
+ smart_request = client._SmartClientRequest(smart_client, 'hello', (),
+ body_stream=['a', 'b'])
+ response_handler = smart_request._send(3)
+ # We connect, get disconnected, and notice before consuming the stream,
+ # so we try again one time and succeed.
+ self.assertEqual(
+ [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close',),
+ ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ],
+ vendor.calls)
+ self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+ '\x00\x00\x00\x02de' # empty headers
+ 's\x00\x00\x00\tl5:helloe'
+ 'b\x00\x00\x00\x01a'
+ 'b\x00\x00\x00\x01b'
+ 'e',
+ output.getvalue())
+
+ def test__send_request_stops_if_body_started(self):
+ # We intentionally use the python StringIO so that we can subclass it.
+ from StringIO import StringIO
+ response = StringIO()
+
+ class FailAfterFirstWrite(StringIO):
+ """Allow one 'write' call to pass, fail the rest"""
+ def __init__(self):
+ StringIO.__init__(self)
+ self._first = True
+
+ def write(self, s):
+ if self._first:
+ self._first = False
+ return StringIO.write(self, s)
+ raise IOError(errno.EINVAL, 'invalid file handle')
+ output = FailAfterFirstWrite()
+
+ vendor = FirstRejectedStringIOSSHVendor(response, output,
+ fail_at_write=False)
+ ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
+ client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
+ smart_client = client._SmartClient(client_medium, headers={})
+ smart_request = client._SmartClientRequest(smart_client, 'hello', (),
+ body_stream=['a', 'b'])
+ self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
+ # We connect, and manage to get to the point that we start consuming
+ # the body stream. The next write fails, so we just stop.
+ self.assertEqual(
+ [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close',),
+ ],
+ vendor.calls)
+ self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+ '\x00\x00\x00\x02de' # empty headers
+ 's\x00\x00\x00\tl5:helloe',
+ output.getvalue())
+
+ def test__send_disabled_retry(self):
+ debug.debug_flags.add('noretry')
+ output, vendor, smart_client = self.make_client_with_failing_medium()
+ smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+ self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
+ self.assertEqual(
+ [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+ ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+ ('close',),
+ ],
+ vendor.calls)
+
+
+class LengthPrefixedBodyDecoder(tests.TestCase):
+
+ # XXX: TODO: make accept_reading_trailer invoke translate_response or
+ # something similar to the ProtocolBase method.
+
+ def test_construct(self):
+ decoder = protocol.LengthPrefixedBodyDecoder()
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(6, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_accept_bytes(self):
+ decoder = protocol.LengthPrefixedBodyDecoder()
+ decoder.accept_bytes('')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(6, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('7')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(6, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('\na')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(11, decoder.next_read_size())
+ self.assertEqual('a', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('bcdefgd')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(4, decoder.next_read_size())
+ self.assertEqual('bcdefg', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('one')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(1, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('\nblarg')
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(1, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('blarg', decoder.unused_data)
+
+ def test_accept_bytes_all_at_once_with_excess(self):
+ decoder = protocol.LengthPrefixedBodyDecoder()
+ decoder.accept_bytes('1\nadone\nunused')
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(1, decoder.next_read_size())
+ self.assertEqual('a', decoder.read_pending_data())
+ self.assertEqual('unused', decoder.unused_data)
+
+ def test_accept_bytes_exact_end_of_body(self):
+ decoder = protocol.LengthPrefixedBodyDecoder()
+ decoder.accept_bytes('1\na')
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(5, decoder.next_read_size())
+ self.assertEqual('a', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+ decoder.accept_bytes('done\n')
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(1, decoder.next_read_size())
+ self.assertEqual('', decoder.read_pending_data())
+ self.assertEqual('', decoder.unused_data)
+
+
+class TestChunkedBodyDecoder(tests.TestCase):
+ """Tests for ChunkedBodyDecoder.
+
+ This is the body decoder used for protocol version two.
+ """
+
+ def test_construct(self):
+ decoder = protocol.ChunkedBodyDecoder()
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(8, decoder.next_read_size())
+ self.assertEqual(None, decoder.read_next_chunk())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_empty_content(self):
+ """'chunked\nEND\n' is the complete encoding of a zero-length body.
+ """
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ decoder.accept_bytes('END\n')
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(None, decoder.read_next_chunk())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_one_chunk(self):
+ """A body in a single chunk is decoded correctly."""
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_length = 'f\n'
+ chunk_content = '123456789abcdef'
+ finish = 'END\n'
+ decoder.accept_bytes(chunk_length + chunk_content + finish)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(chunk_content, decoder.read_next_chunk())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_incomplete_chunk(self):
+ """When there are less bytes in the chunk than declared by the length,
+ then we haven't finished reading yet.
+ """
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_length = '8\n'
+ three_bytes = '123'
+ decoder.accept_bytes(chunk_length + three_bytes)
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(
+ 5 + 4, decoder.next_read_size(),
+ "The next_read_size hint should be the number of missing bytes in "
+ "this chunk plus 4 (the length of the end-of-body marker: "
+ "'END\\n')")
+ self.assertEqual(None, decoder.read_next_chunk())
+
+ def test_incomplete_length(self):
+ """A chunk length hasn't been read until a newline byte has been read.
+ """
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ decoder.accept_bytes('9')
+ self.assertEqual(
+ 1, decoder.next_read_size(),
+ "The next_read_size hint should be 1, because we don't know the "
+ "length yet.")
+ decoder.accept_bytes('\n')
+ self.assertEqual(
+ 9 + 4, decoder.next_read_size(),
+ "The next_read_size hint should be the length of the chunk plus 4 "
+ "(the length of the end-of-body marker: 'END\\n')")
+ self.assertFalse(decoder.finished_reading)
+ self.assertEqual(None, decoder.read_next_chunk())
+
+ def test_two_chunks(self):
+ """Content from multiple chunks is concatenated."""
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_one = '3\naaa'
+ chunk_two = '5\nbbbbb'
+ finish = 'END\n'
+ decoder.accept_bytes(chunk_one + chunk_two + finish)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual('aaa', decoder.read_next_chunk())
+ self.assertEqual('bbbbb', decoder.read_next_chunk())
+ self.assertEqual(None, decoder.read_next_chunk())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_excess_bytes(self):
+ """Bytes after the chunked body are reported as unused bytes."""
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunked_body = "5\naaaaaEND\n"
+ excess_bytes = "excess bytes"
+ decoder.accept_bytes(chunked_body + excess_bytes)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual('aaaaa', decoder.read_next_chunk())
+ self.assertEqual(excess_bytes, decoder.unused_data)
+ self.assertEqual(
+ 1, decoder.next_read_size(),
+ "next_read_size hint should be 1 when finished_reading.")
+
+ def test_multidigit_length(self):
+ """Lengths in the chunk prefixes can have multiple digits."""
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ length = 0x123
+ chunk_prefix = hex(length) + '\n'
+ chunk_bytes = 'z' * length
+ finish = 'END\n'
+ decoder.accept_bytes(chunk_prefix + chunk_bytes + finish)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(chunk_bytes, decoder.read_next_chunk())
+
+ def test_byte_at_a_time(self):
+ """A complete body fed to the decoder one byte at a time should not
+ confuse the decoder. That is, it should give the same result as if the
+ bytes had been received in one batch.
+
+ This test is the same as test_one_chunk apart from the way accept_bytes
+ is called.
+ """
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_length = 'f\n'
+ chunk_content = '123456789abcdef'
+ finish = 'END\n'
+ for byte in (chunk_length + chunk_content + finish):
+ decoder.accept_bytes(byte)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual(chunk_content, decoder.read_next_chunk())
+ self.assertEqual('', decoder.unused_data)
+
+ def test_read_pending_data_resets(self):
+ """read_pending_data does not return the same bytes twice."""
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_one = '3\naaa'
+ chunk_two = '3\nbbb'
+ finish = 'END\n'
+ decoder.accept_bytes(chunk_one)
+ self.assertEqual('aaa', decoder.read_next_chunk())
+ decoder.accept_bytes(chunk_two)
+ self.assertEqual('bbb', decoder.read_next_chunk())
+ self.assertEqual(None, decoder.read_next_chunk())
+
+ def test_decode_error(self):
+ decoder = protocol.ChunkedBodyDecoder()
+ decoder.accept_bytes('chunked\n')
+ chunk_one = 'b\nfirst chunk'
+ error_signal = 'ERR\n'
+ error_chunks = '5\npart1' + '5\npart2'
+ finish = 'END\n'
+ decoder.accept_bytes(chunk_one + error_signal + error_chunks + finish)
+ self.assertTrue(decoder.finished_reading)
+ self.assertEqual('first chunk', decoder.read_next_chunk())
+ expected_failure = _mod_request.FailedSmartServerResponse(
+ ('part1', 'part2'))
+ self.assertEqual(expected_failure, decoder.read_next_chunk())
+
+ def test_bad_header(self):
+ """accept_bytes raises a SmartProtocolError if a chunked body does not
+ start with the right header.
+ """
+ decoder = protocol.ChunkedBodyDecoder()
+ self.assertRaises(
+ errors.SmartProtocolError, decoder.accept_bytes, 'bad header\n')
+
+
+class TestSuccessfulSmartServerResponse(tests.TestCase):
+
+ def test_construct_no_body(self):
+ response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'))
+ self.assertEqual(('foo', 'bar'), response.args)
+ self.assertEqual(None, response.body)
+
+ def test_construct_with_body(self):
+ response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'),
+ 'bytes')
+ self.assertEqual(('foo', 'bar'), response.args)
+ self.assertEqual('bytes', response.body)
+ # repr(response) doesn't trigger exceptions.
+ repr(response)
+
+ def test_construct_with_body_stream(self):
+ bytes_iterable = ['abc']
+ response = _mod_request.SuccessfulSmartServerResponse(
+ ('foo', 'bar'), body_stream=bytes_iterable)
+ self.assertEqual(('foo', 'bar'), response.args)
+ self.assertEqual(bytes_iterable, response.body_stream)
+
+ def test_construct_rejects_body_and_body_stream(self):
+ """'body' and 'body_stream' are mutually exclusive."""
+ self.assertRaises(
+ errors.BzrError,
+ _mod_request.SuccessfulSmartServerResponse, (), 'body', ['stream'])
+
+ def test_is_successful(self):
+ """is_successful should return True for SuccessfulSmartServerResponse."""
+ response = _mod_request.SuccessfulSmartServerResponse(('error',))
+ self.assertEqual(True, response.is_successful())
+
+
+class TestFailedSmartServerResponse(tests.TestCase):
+
+ def test_construct(self):
+ response = _mod_request.FailedSmartServerResponse(('foo', 'bar'))
+ self.assertEqual(('foo', 'bar'), response.args)
+ self.assertEqual(None, response.body)
+ response = _mod_request.FailedSmartServerResponse(('foo', 'bar'), 'bytes')
+ self.assertEqual(('foo', 'bar'), response.args)
+ self.assertEqual('bytes', response.body)
+ # repr(response) doesn't trigger exceptions.
+ repr(response)
+
+ def test_is_successful(self):
+ """is_successful should return False for FailedSmartServerResponse."""
+ response = _mod_request.FailedSmartServerResponse(('error',))
+ self.assertEqual(False, response.is_successful())
+
+
+class FakeHTTPMedium(object):
+ def __init__(self):
+ self.written_request = None
+ self._current_request = None
+ def send_http_smart_request(self, bytes):
+ self.written_request = bytes
+ return None
+
+
+class HTTPTunnellingSmokeTest(tests.TestCase):
+
+ def setUp(self):
+ super(HTTPTunnellingSmokeTest, self).setUp()
+ # We use the VFS layer as part of HTTP tunnelling tests.
+ self.overrideEnv('BZR_NO_SMART_VFS', None)
+
+ def test_smart_http_medium_request_accept_bytes(self):
+ medium = FakeHTTPMedium()
+ request = http.SmartClientHTTPMediumRequest(medium)
+ request.accept_bytes('abc')
+ request.accept_bytes('def')
+ self.assertEqual(None, medium.written_request)
+ request.finished_writing()
+ self.assertEqual('abcdef', medium.written_request)
+
+
+class RemoteHTTPTransportTestCase(tests.TestCase):
+
+ def test_remote_path_after_clone_child(self):
+ # If a user enters "bzr+http://host/foo", we want to sent all smart
+ # requests for child URLs of that to the original URL. i.e., we want to
+ # POST to "bzr+http://host/foo/.bzr/smart" and never something like
+ # "bzr+http://host/foo/.bzr/branch/.bzr/smart". So, a cloned
+ # RemoteHTTPTransport remembers the initial URL, and adjusts the
+ # relpaths it sends in smart requests accordingly.
+ base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
+ new_transport = base_transport.clone('child_dir')
+ self.assertEqual(base_transport._http_transport,
+ new_transport._http_transport)
+ self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
+ self.assertEqual(
+ 'child_dir/',
+ new_transport._client.remote_path_from_transport(new_transport))
+
+ def test_remote_path_unnormal_base(self):
+ # If the transport's base isn't normalised, the _remote_path should
+ # still be calculated correctly.
+ base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
+ self.assertEqual('c', base_transport._remote_path('c'))
+
+ def test_clone_unnormal_base(self):
+ # If the transport's base isn't normalised, cloned transports should
+ # still work correctly.
+ base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
+ new_transport = base_transport.clone('c')
+ self.assertEqual(base_transport.base + 'c/', new_transport.base)
+ self.assertEqual(
+ 'c/',
+ new_transport._client.remote_path_from_transport(new_transport))
+
+ def test__redirect_to(self):
+ t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
+ r = t._redirected_to('http://www.example.com/foo',
+ 'http://www.example.com/bar')
+ self.assertEquals(type(r), type(t))
+
+ def test__redirect_sibling_protocol(self):
+ t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
+ r = t._redirected_to('http://www.example.com/foo',
+ 'https://www.example.com/bar')
+ self.assertEquals(type(r), type(t))
+ self.assertStartsWith(r.base, 'bzr+https')
+
+ def test__redirect_to_with_user(self):
+ t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
+ r = t._redirected_to('http://www.example.com/foo',
+ 'http://www.example.com/bar')
+ self.assertEquals(type(r), type(t))
+ self.assertEquals('joe', t._parsed_url.user)
+ self.assertEquals(t._parsed_url.user, r._parsed_url.user)
+
+ def test_redirected_to_same_host_different_protocol(self):
+ t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
+ r = t._redirected_to('http://www.example.com/foo',
+ 'ftp://www.example.com/foo')
+ self.assertNotEquals(type(r), type(t))
+
+