diff options
Diffstat (limited to 'bzrlib/smart/client.py')
-rw-r--r-- | bzrlib/smart/client.py | 352 |
1 files changed, 352 insertions, 0 deletions
diff --git a/bzrlib/smart/client.py b/bzrlib/smart/client.py new file mode 100644 index 0000000..8c1e21b --- /dev/null +++ b/bzrlib/smart/client.py @@ -0,0 +1,352 @@ +# Copyright (C) 2006-2010 Canonical Ltd +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +from __future__ import absolute_import + +from bzrlib import lazy_import +lazy_import.lazy_import(globals(), """ +from bzrlib.smart import request as _mod_request +""") + +import bzrlib +from bzrlib.smart import message, protocol +from bzrlib import ( + debug, + errors, + hooks, + trace, + ) + + +class _SmartClient(object): + + def __init__(self, medium, headers=None): + """Constructor. + + :param medium: a SmartClientMedium + """ + self._medium = medium + if headers is None: + self._headers = {'Software version': bzrlib.__version__} + else: + self._headers = dict(headers) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self._medium) + + def _call_and_read_response(self, method, args, body=None, readv_body=None, + body_stream=None, expect_response_body=True): + request = _SmartClientRequest(self, method, args, body=body, + readv_body=readv_body, body_stream=body_stream, + expect_response_body=expect_response_body) + return request.call_and_read_response() + + def call(self, method, *args): + """Call a method on the remote server.""" + result, protocol = self.call_expecting_body(method, *args) + protocol.cancel_read_body() + return result + + def call_expecting_body(self, method, *args): + """Call a method and return the result and the protocol object. + + The body can be read like so:: + + result, smart_protocol = smart_client.call_expecting_body(...) + body = smart_protocol.read_body_bytes() + """ + return self._call_and_read_response( + method, args, expect_response_body=True) + + def call_with_body_bytes(self, method, args, body): + """Call a method on the remote server with body bytes.""" + if type(method) is not str: + raise TypeError('method must be a byte string, not %r' % (method,)) + for arg in args: + if type(arg) is not str: + raise TypeError('args must be byte strings, not %r' % (args,)) + if type(body) is not str: + raise TypeError('body must be byte string, not %r' % (body,)) + response, response_handler = self._call_and_read_response( + method, args, body=body, expect_response_body=False) + return response + + def call_with_body_bytes_expecting_body(self, method, args, body): + """Call a method on the remote server with body bytes.""" + if type(method) is not str: + raise TypeError('method must be a byte string, not %r' % (method,)) + for arg in args: + if type(arg) is not str: + raise TypeError('args must be byte strings, not %r' % (args,)) + if type(body) is not str: + raise TypeError('body must be byte string, not %r' % (body,)) + response, response_handler = self._call_and_read_response( + method, args, body=body, expect_response_body=True) + return (response, response_handler) + + def call_with_body_readv_array(self, args, body): + response, response_handler = self._call_and_read_response( + args[0], args[1:], readv_body=body, expect_response_body=True) + return (response, response_handler) + + def call_with_body_stream(self, args, stream): + response, response_handler = self._call_and_read_response( + args[0], args[1:], body_stream=stream, + expect_response_body=False) + return (response, response_handler) + + def remote_path_from_transport(self, transport): + """Convert transport into a path suitable for using in a request. + + Note that the resulting remote path doesn't encode the host name or + anything but path, so it is only safe to use it in requests sent over + the medium from the matching transport. + """ + return self._medium.remote_path_from_transport(transport) + + +class _SmartClientRequest(object): + """Encapsulate the logic for a single request. + + This class handles things like reconnecting and sending the request a + second time when the connection is reset in the middle. It also handles the + multiple requests that get made if we don't know what protocol the server + supports yet. + + Generally, you build up one of these objects, passing in the arguments that + you want to send to the server, and then use 'call_and_read_response' to + get the response from the server. + """ + + def __init__(self, client, method, args, body=None, readv_body=None, + body_stream=None, expect_response_body=True): + self.client = client + self.method = method + self.args = args + self.body = body + self.readv_body = readv_body + self.body_stream = body_stream + self.expect_response_body = expect_response_body + + def call_and_read_response(self): + """Send the request to the server, and read the initial response. + + This doesn't read all of the body content of the response, instead it + returns (response_tuple, response_handler). response_tuple is the 'ok', + or 'error' information, and 'response_handler' can be used to get the + content stream out. + """ + self._run_call_hooks() + protocol_version = self.client._medium._protocol_version + if protocol_version is None: + return self._call_determining_protocol_version() + else: + return self._call(protocol_version) + + def _is_safe_to_send_twice(self): + """Check if the current method is re-entrant safe.""" + if self.body_stream is not None or 'noretry' in debug.debug_flags: + # We can't restart a body stream that has already been consumed. + return False + request_type = _mod_request.request_handlers.get_info(self.method) + if request_type in ('read', 'idem', 'semi'): + return True + # If we have gotten this far, 'stream' cannot be retried, because we + # already consumed the local stream. + if request_type in ('semivfs', 'mutate', 'stream'): + return False + trace.mutter('Unknown request type: %s for method %s' + % (request_type, self.method)) + return False + + def _run_call_hooks(self): + if not _SmartClient.hooks['call']: + return + params = CallHookParams(self.method, self.args, self.body, + self.readv_body, self.client._medium) + for hook in _SmartClient.hooks['call']: + hook(params) + + def _call(self, protocol_version): + """We know the protocol version. + + So this just sends the request, and then reads the response. This is + where the code will be to retry requests if the connection is closed. + """ + response_handler = self._send(protocol_version) + try: + response_tuple = response_handler.read_response_tuple( + expect_body=self.expect_response_body) + except errors.ConnectionReset, e: + self.client._medium.reset() + if not self._is_safe_to_send_twice(): + raise + trace.warning('ConnectionReset reading response for %r, retrying' + % (self.method,)) + trace.log_exception_quietly() + encoder, response_handler = self._construct_protocol( + protocol_version) + self._send_no_retry(encoder) + response_tuple = response_handler.read_response_tuple( + expect_body=self.expect_response_body) + return (response_tuple, response_handler) + + def _call_determining_protocol_version(self): + """Determine what protocol the remote server supports. + + We do this by placing a request in the most recent protocol, and + handling the UnexpectedProtocolVersionMarker from the server. + """ + for protocol_version in [3, 2]: + if protocol_version == 2: + # If v3 doesn't work, the remote side is older than 1.6. + self.client._medium._remember_remote_is_before((1, 6)) + try: + response_tuple, response_handler = self._call(protocol_version) + except errors.UnexpectedProtocolVersionMarker, err: + # TODO: We could recover from this without disconnecting if + # we recognise the protocol version. + trace.warning( + 'Server does not understand Bazaar network protocol %d,' + ' reconnecting. (Upgrade the server to avoid this.)' + % (protocol_version,)) + self.client._medium.disconnect() + continue + except errors.ErrorFromSmartServer: + # If we received an error reply from the server, then it + # must be ok with this protocol version. + self.client._medium._protocol_version = protocol_version + raise + else: + self.client._medium._protocol_version = protocol_version + return response_tuple, response_handler + raise errors.SmartProtocolError( + 'Server is not a Bazaar server: ' + str(err)) + + def _construct_protocol(self, version): + """Build the encoding stack for a given protocol version.""" + request = self.client._medium.get_request() + if version == 3: + request_encoder = protocol.ProtocolThreeRequester(request) + response_handler = message.ConventionalResponseHandler() + response_proto = protocol.ProtocolThreeDecoder( + response_handler, expect_version_marker=True) + response_handler.setProtoAndMediumRequest(response_proto, request) + elif version == 2: + request_encoder = protocol.SmartClientRequestProtocolTwo(request) + response_handler = request_encoder + else: + request_encoder = protocol.SmartClientRequestProtocolOne(request) + response_handler = request_encoder + return request_encoder, response_handler + + def _send(self, protocol_version): + """Encode the request, and send it to the server. + + This will retry a request if we get a ConnectionReset while sending the + request to the server. (Unless we have a body_stream that we have + already started consuming, since we can't restart body_streams) + + :return: response_handler as defined by _construct_protocol + """ + encoder, response_handler = self._construct_protocol(protocol_version) + try: + self._send_no_retry(encoder) + except errors.ConnectionReset, e: + # If we fail during the _send_no_retry phase, then we can + # be confident that the server did not get our request, because we + # haven't started waiting for the reply yet. So try the request + # again. We only issue a single retry, because if the connection + # really is down, there is no reason to loop endlessly. + + # Connection is dead, so close our end of it. + self.client._medium.reset() + if (('noretry' in debug.debug_flags) + or (self.body_stream is not None + and encoder.body_stream_started)): + # We can't restart a body_stream that has been partially + # consumed, so we don't retry. + # Note: We don't have to worry about + # SmartClientRequestProtocolOne or Two, because they don't + # support client-side body streams. + raise + trace.warning('ConnectionReset calling %r, retrying' + % (self.method,)) + trace.log_exception_quietly() + encoder, response_handler = self._construct_protocol( + protocol_version) + self._send_no_retry(encoder) + return response_handler + + def _send_no_retry(self, encoder): + """Just encode the request and try to send it.""" + encoder.set_headers(self.client._headers) + if self.body is not None: + if self.readv_body is not None: + raise AssertionError( + "body and readv_body are mutually exclusive.") + if self.body_stream is not None: + raise AssertionError( + "body and body_stream are mutually exclusive.") + encoder.call_with_body_bytes((self.method, ) + self.args, self.body) + elif self.readv_body is not None: + if self.body_stream is not None: + raise AssertionError( + "readv_body and body_stream are mutually exclusive.") + encoder.call_with_body_readv_array((self.method, ) + self.args, + self.readv_body) + elif self.body_stream is not None: + encoder.call_with_body_stream((self.method, ) + self.args, + self.body_stream) + else: + encoder.call(self.method, *self.args) + + +class SmartClientHooks(hooks.Hooks): + + def __init__(self): + hooks.Hooks.__init__(self, "bzrlib.smart.client", "_SmartClient.hooks") + self.add_hook('call', + "Called when the smart client is submitting a request to the " + "smart server. Called with a bzrlib.smart.client.CallHookParams " + "object. Streaming request bodies, and responses, are not " + "accessible.", None) + + +_SmartClient.hooks = SmartClientHooks() + + +class CallHookParams(object): + + def __init__(self, method, args, body, readv_body, medium): + self.method = method + self.args = args + self.body = body + self.readv_body = readv_body + self.medium = medium + + def __repr__(self): + attrs = dict((k, v) for (k, v) in self.__dict__.iteritems() + if v is not None) + return '<%s %r>' % (self.__class__.__name__, attrs) + + def __eq__(self, other): + if type(other) is not type(self): + return NotImplemented + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self == other |