summaryrefslogtreecommitdiff
path: root/bzrlib/smart/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bzrlib/smart/client.py')
-rw-r--r--bzrlib/smart/client.py352
1 files changed, 352 insertions, 0 deletions
diff --git a/bzrlib/smart/client.py b/bzrlib/smart/client.py
new file mode 100644
index 0000000..8c1e21b
--- /dev/null
+++ b/bzrlib/smart/client.py
@@ -0,0 +1,352 @@
+# Copyright (C) 2006-2010 Canonical Ltd
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+from __future__ import absolute_import
+
+from bzrlib import lazy_import
+lazy_import.lazy_import(globals(), """
+from bzrlib.smart import request as _mod_request
+""")
+
+import bzrlib
+from bzrlib.smart import message, protocol
+from bzrlib import (
+ debug,
+ errors,
+ hooks,
+ trace,
+ )
+
+
+class _SmartClient(object):
+
+ def __init__(self, medium, headers=None):
+ """Constructor.
+
+ :param medium: a SmartClientMedium
+ """
+ self._medium = medium
+ if headers is None:
+ self._headers = {'Software version': bzrlib.__version__}
+ else:
+ self._headers = dict(headers)
+
+ def __repr__(self):
+ return '%s(%r)' % (self.__class__.__name__, self._medium)
+
+ def _call_and_read_response(self, method, args, body=None, readv_body=None,
+ body_stream=None, expect_response_body=True):
+ request = _SmartClientRequest(self, method, args, body=body,
+ readv_body=readv_body, body_stream=body_stream,
+ expect_response_body=expect_response_body)
+ return request.call_and_read_response()
+
+ def call(self, method, *args):
+ """Call a method on the remote server."""
+ result, protocol = self.call_expecting_body(method, *args)
+ protocol.cancel_read_body()
+ return result
+
+ def call_expecting_body(self, method, *args):
+ """Call a method and return the result and the protocol object.
+
+ The body can be read like so::
+
+ result, smart_protocol = smart_client.call_expecting_body(...)
+ body = smart_protocol.read_body_bytes()
+ """
+ return self._call_and_read_response(
+ method, args, expect_response_body=True)
+
+ def call_with_body_bytes(self, method, args, body):
+ """Call a method on the remote server with body bytes."""
+ if type(method) is not str:
+ raise TypeError('method must be a byte string, not %r' % (method,))
+ for arg in args:
+ if type(arg) is not str:
+ raise TypeError('args must be byte strings, not %r' % (args,))
+ if type(body) is not str:
+ raise TypeError('body must be byte string, not %r' % (body,))
+ response, response_handler = self._call_and_read_response(
+ method, args, body=body, expect_response_body=False)
+ return response
+
+ def call_with_body_bytes_expecting_body(self, method, args, body):
+ """Call a method on the remote server with body bytes."""
+ if type(method) is not str:
+ raise TypeError('method must be a byte string, not %r' % (method,))
+ for arg in args:
+ if type(arg) is not str:
+ raise TypeError('args must be byte strings, not %r' % (args,))
+ if type(body) is not str:
+ raise TypeError('body must be byte string, not %r' % (body,))
+ response, response_handler = self._call_and_read_response(
+ method, args, body=body, expect_response_body=True)
+ return (response, response_handler)
+
+ def call_with_body_readv_array(self, args, body):
+ response, response_handler = self._call_and_read_response(
+ args[0], args[1:], readv_body=body, expect_response_body=True)
+ return (response, response_handler)
+
+ def call_with_body_stream(self, args, stream):
+ response, response_handler = self._call_and_read_response(
+ args[0], args[1:], body_stream=stream,
+ expect_response_body=False)
+ return (response, response_handler)
+
+ def remote_path_from_transport(self, transport):
+ """Convert transport into a path suitable for using in a request.
+
+ Note that the resulting remote path doesn't encode the host name or
+ anything but path, so it is only safe to use it in requests sent over
+ the medium from the matching transport.
+ """
+ return self._medium.remote_path_from_transport(transport)
+
+
+class _SmartClientRequest(object):
+ """Encapsulate the logic for a single request.
+
+ This class handles things like reconnecting and sending the request a
+ second time when the connection is reset in the middle. It also handles the
+ multiple requests that get made if we don't know what protocol the server
+ supports yet.
+
+ Generally, you build up one of these objects, passing in the arguments that
+ you want to send to the server, and then use 'call_and_read_response' to
+ get the response from the server.
+ """
+
+ def __init__(self, client, method, args, body=None, readv_body=None,
+ body_stream=None, expect_response_body=True):
+ self.client = client
+ self.method = method
+ self.args = args
+ self.body = body
+ self.readv_body = readv_body
+ self.body_stream = body_stream
+ self.expect_response_body = expect_response_body
+
+ def call_and_read_response(self):
+ """Send the request to the server, and read the initial response.
+
+ This doesn't read all of the body content of the response, instead it
+ returns (response_tuple, response_handler). response_tuple is the 'ok',
+ or 'error' information, and 'response_handler' can be used to get the
+ content stream out.
+ """
+ self._run_call_hooks()
+ protocol_version = self.client._medium._protocol_version
+ if protocol_version is None:
+ return self._call_determining_protocol_version()
+ else:
+ return self._call(protocol_version)
+
+ def _is_safe_to_send_twice(self):
+ """Check if the current method is re-entrant safe."""
+ if self.body_stream is not None or 'noretry' in debug.debug_flags:
+ # We can't restart a body stream that has already been consumed.
+ return False
+ request_type = _mod_request.request_handlers.get_info(self.method)
+ if request_type in ('read', 'idem', 'semi'):
+ return True
+ # If we have gotten this far, 'stream' cannot be retried, because we
+ # already consumed the local stream.
+ if request_type in ('semivfs', 'mutate', 'stream'):
+ return False
+ trace.mutter('Unknown request type: %s for method %s'
+ % (request_type, self.method))
+ return False
+
+ def _run_call_hooks(self):
+ if not _SmartClient.hooks['call']:
+ return
+ params = CallHookParams(self.method, self.args, self.body,
+ self.readv_body, self.client._medium)
+ for hook in _SmartClient.hooks['call']:
+ hook(params)
+
+ def _call(self, protocol_version):
+ """We know the protocol version.
+
+ So this just sends the request, and then reads the response. This is
+ where the code will be to retry requests if the connection is closed.
+ """
+ response_handler = self._send(protocol_version)
+ try:
+ response_tuple = response_handler.read_response_tuple(
+ expect_body=self.expect_response_body)
+ except errors.ConnectionReset, e:
+ self.client._medium.reset()
+ if not self._is_safe_to_send_twice():
+ raise
+ trace.warning('ConnectionReset reading response for %r, retrying'
+ % (self.method,))
+ trace.log_exception_quietly()
+ encoder, response_handler = self._construct_protocol(
+ protocol_version)
+ self._send_no_retry(encoder)
+ response_tuple = response_handler.read_response_tuple(
+ expect_body=self.expect_response_body)
+ return (response_tuple, response_handler)
+
+ def _call_determining_protocol_version(self):
+ """Determine what protocol the remote server supports.
+
+ We do this by placing a request in the most recent protocol, and
+ handling the UnexpectedProtocolVersionMarker from the server.
+ """
+ for protocol_version in [3, 2]:
+ if protocol_version == 2:
+ # If v3 doesn't work, the remote side is older than 1.6.
+ self.client._medium._remember_remote_is_before((1, 6))
+ try:
+ response_tuple, response_handler = self._call(protocol_version)
+ except errors.UnexpectedProtocolVersionMarker, err:
+ # TODO: We could recover from this without disconnecting if
+ # we recognise the protocol version.
+ trace.warning(
+ 'Server does not understand Bazaar network protocol %d,'
+ ' reconnecting. (Upgrade the server to avoid this.)'
+ % (protocol_version,))
+ self.client._medium.disconnect()
+ continue
+ except errors.ErrorFromSmartServer:
+ # If we received an error reply from the server, then it
+ # must be ok with this protocol version.
+ self.client._medium._protocol_version = protocol_version
+ raise
+ else:
+ self.client._medium._protocol_version = protocol_version
+ return response_tuple, response_handler
+ raise errors.SmartProtocolError(
+ 'Server is not a Bazaar server: ' + str(err))
+
+ def _construct_protocol(self, version):
+ """Build the encoding stack for a given protocol version."""
+ request = self.client._medium.get_request()
+ if version == 3:
+ request_encoder = protocol.ProtocolThreeRequester(request)
+ response_handler = message.ConventionalResponseHandler()
+ response_proto = protocol.ProtocolThreeDecoder(
+ response_handler, expect_version_marker=True)
+ response_handler.setProtoAndMediumRequest(response_proto, request)
+ elif version == 2:
+ request_encoder = protocol.SmartClientRequestProtocolTwo(request)
+ response_handler = request_encoder
+ else:
+ request_encoder = protocol.SmartClientRequestProtocolOne(request)
+ response_handler = request_encoder
+ return request_encoder, response_handler
+
+ def _send(self, protocol_version):
+ """Encode the request, and send it to the server.
+
+ This will retry a request if we get a ConnectionReset while sending the
+ request to the server. (Unless we have a body_stream that we have
+ already started consuming, since we can't restart body_streams)
+
+ :return: response_handler as defined by _construct_protocol
+ """
+ encoder, response_handler = self._construct_protocol(protocol_version)
+ try:
+ self._send_no_retry(encoder)
+ except errors.ConnectionReset, e:
+ # If we fail during the _send_no_retry phase, then we can
+ # be confident that the server did not get our request, because we
+ # haven't started waiting for the reply yet. So try the request
+ # again. We only issue a single retry, because if the connection
+ # really is down, there is no reason to loop endlessly.
+
+ # Connection is dead, so close our end of it.
+ self.client._medium.reset()
+ if (('noretry' in debug.debug_flags)
+ or (self.body_stream is not None
+ and encoder.body_stream_started)):
+ # We can't restart a body_stream that has been partially
+ # consumed, so we don't retry.
+ # Note: We don't have to worry about
+ # SmartClientRequestProtocolOne or Two, because they don't
+ # support client-side body streams.
+ raise
+ trace.warning('ConnectionReset calling %r, retrying'
+ % (self.method,))
+ trace.log_exception_quietly()
+ encoder, response_handler = self._construct_protocol(
+ protocol_version)
+ self._send_no_retry(encoder)
+ return response_handler
+
+ def _send_no_retry(self, encoder):
+ """Just encode the request and try to send it."""
+ encoder.set_headers(self.client._headers)
+ if self.body is not None:
+ if self.readv_body is not None:
+ raise AssertionError(
+ "body and readv_body are mutually exclusive.")
+ if self.body_stream is not None:
+ raise AssertionError(
+ "body and body_stream are mutually exclusive.")
+ encoder.call_with_body_bytes((self.method, ) + self.args, self.body)
+ elif self.readv_body is not None:
+ if self.body_stream is not None:
+ raise AssertionError(
+ "readv_body and body_stream are mutually exclusive.")
+ encoder.call_with_body_readv_array((self.method, ) + self.args,
+ self.readv_body)
+ elif self.body_stream is not None:
+ encoder.call_with_body_stream((self.method, ) + self.args,
+ self.body_stream)
+ else:
+ encoder.call(self.method, *self.args)
+
+
+class SmartClientHooks(hooks.Hooks):
+
+ def __init__(self):
+ hooks.Hooks.__init__(self, "bzrlib.smart.client", "_SmartClient.hooks")
+ self.add_hook('call',
+ "Called when the smart client is submitting a request to the "
+ "smart server. Called with a bzrlib.smart.client.CallHookParams "
+ "object. Streaming request bodies, and responses, are not "
+ "accessible.", None)
+
+
+_SmartClient.hooks = SmartClientHooks()
+
+
+class CallHookParams(object):
+
+ def __init__(self, method, args, body, readv_body, medium):
+ self.method = method
+ self.args = args
+ self.body = body
+ self.readv_body = readv_body
+ self.medium = medium
+
+ def __repr__(self):
+ attrs = dict((k, v) for (k, v) in self.__dict__.iteritems()
+ if v is not None)
+ return '<%s %r>' % (self.__class__.__name__, attrs)
+
+ def __eq__(self, other):
+ if type(other) is not type(self):
+ return NotImplemented
+ return self.__dict__ == other.__dict__
+
+ def __ne__(self, other):
+ return not self == other