# Copyright (C) 2006-2010 Canonical Ltd # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 2 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA """Wire-level encoding and decoding of requests and responses for the smart client and server. """ from __future__ import absolute_import import collections from cStringIO import StringIO import struct import sys import thread import time import bzrlib from bzrlib import ( debug, errors, osutils, ) from bzrlib.smart import message, request from bzrlib.trace import log_exception_quietly, mutter from bzrlib.bencode import bdecode_as_tuple, bencode # Protocol version strings. These are sent as prefixes of bzr requests and # responses to identify the protocol version being used. (There are no version # one strings because that version doesn't send any). REQUEST_VERSION_TWO = 'bzr request 2\n' RESPONSE_VERSION_TWO = 'bzr response 2\n' MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n' RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE def _recv_tuple(from_file): req_line = from_file.readline() return _decode_tuple(req_line) def _decode_tuple(req_line): if req_line is None or req_line == '': return None if req_line[-1] != '\n': raise errors.SmartProtocolError("request %r not terminated" % req_line) return tuple(req_line[:-1].split('\x01')) def _encode_tuple(args): """Encode the tuple args to a bytestream.""" joined = '\x01'.join(args) + '\n' if type(joined) is unicode: # XXX: We should fix things so this never happens! -AJB, 20100304 mutter('response args contain unicode, should be only bytes: %r', joined) joined = joined.encode('ascii') return joined class Requester(object): """Abstract base class for an object that can issue requests on a smart medium. """ def call(self, *args): """Make a remote call. :param args: the arguments of this call. """ raise NotImplementedError(self.call) def call_with_body_bytes(self, args, body): """Make a remote call with a body. :param args: the arguments of this call. :type body: str :param body: the body to send with the request. """ raise NotImplementedError(self.call_with_body_bytes) def call_with_body_readv_array(self, args, body): """Make a remote call with a readv array. :param args: the arguments of this call. :type body: iterable of (start, length) tuples. :param body: the readv ranges to send with this request. """ raise NotImplementedError(self.call_with_body_readv_array) def set_headers(self, headers): raise NotImplementedError(self.set_headers) class SmartProtocolBase(object): """Methods common to client and server""" # TODO: this only actually accomodates a single block; possibly should # support multiple chunks? def _encode_bulk_data(self, body): """Encode body as a bulk data chunk.""" return ''.join(('%d\n' % len(body), body, 'done\n')) def _serialise_offsets(self, offsets): """Serialise a readv offset list.""" txt = [] for start, length in offsets: txt.append('%d,%d' % (start, length)) return '\n'.join(txt) class SmartServerRequestProtocolOne(SmartProtocolBase): """Server-side encoding and decoding logic for smart version 1.""" def __init__(self, backing_transport, write_func, root_client_path='/', jail_root=None): self._backing_transport = backing_transport self._root_client_path = root_client_path self._jail_root = jail_root self.unused_data = '' self._finished = False self.in_buffer = '' self._has_dispatched = False self.request = None self._body_decoder = None self._write_func = write_func def accept_bytes(self, bytes): """Take bytes, and advance the internal state machine appropriately. :param bytes: must be a byte string """ if not isinstance(bytes, str): raise ValueError(bytes) self.in_buffer += bytes if not self._has_dispatched: if '\n' not in self.in_buffer: # no command line yet return self._has_dispatched = True try: first_line, self.in_buffer = self.in_buffer.split('\n', 1) first_line += '\n' req_args = _decode_tuple(first_line) self.request = request.SmartServerRequestHandler( self._backing_transport, commands=request.request_handlers, root_client_path=self._root_client_path, jail_root=self._jail_root) self.request.args_received(req_args) if self.request.finished_reading: # trivial request self.unused_data = self.in_buffer self.in_buffer = '' self._send_response(self.request.response) except KeyboardInterrupt: raise except errors.UnknownSmartMethod, err: protocol_error = errors.SmartProtocolError( "bad request %r" % (err.verb,)) failure = request.FailedSmartServerResponse( ('error', str(protocol_error))) self._send_response(failure) return except Exception, exception: # everything else: pass to client, flush, and quit log_exception_quietly() self._send_response(request.FailedSmartServerResponse( ('error', str(exception)))) return if self._has_dispatched: if self._finished: # nothing to do.XXX: this routine should be a single state # machine too. self.unused_data += self.in_buffer self.in_buffer = '' return if self._body_decoder is None: self._body_decoder = LengthPrefixedBodyDecoder() self._body_decoder.accept_bytes(self.in_buffer) self.in_buffer = self._body_decoder.unused_data body_data = self._body_decoder.read_pending_data() self.request.accept_body(body_data) if self._body_decoder.finished_reading: self.request.end_of_body() if not self.request.finished_reading: raise AssertionError("no more body, request not finished") if self.request.response is not None: self._send_response(self.request.response) self.unused_data = self.in_buffer self.in_buffer = '' else: if self.request.finished_reading: raise AssertionError( "no response and we have finished reading.") def _send_response(self, response): """Send a smart server response down the output stream.""" if self._finished: raise AssertionError('response already sent') args = response.args body = response.body self._finished = True self._write_protocol_version() self._write_success_or_failure_prefix(response) self._write_func(_encode_tuple(args)) if body is not None: if not isinstance(body, str): raise ValueError(body) bytes = self._encode_bulk_data(body) self._write_func(bytes) def _write_protocol_version(self): """Write any prefixes this protocol requires. Version one doesn't send protocol versions. """ def _write_success_or_failure_prefix(self, response): """Write the protocol specific success/failure prefix. For SmartServerRequestProtocolOne this is omitted but we call is_successful to ensure that the response is valid. """ response.is_successful() def next_read_size(self): if self._finished: return 0 if self._body_decoder is None: return 1 else: return self._body_decoder.next_read_size() class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne): r"""Version two of the server side of the smart protocol. This prefixes responses with the value of RESPONSE_VERSION_TWO. """ response_marker = RESPONSE_VERSION_TWO request_marker = REQUEST_VERSION_TWO def _write_success_or_failure_prefix(self, response): """Write the protocol specific success/failure prefix.""" if response.is_successful(): self._write_func('success\n') else: self._write_func('failed\n') def _write_protocol_version(self): r"""Write any prefixes this protocol requires. Version two sends the value of RESPONSE_VERSION_TWO. """ self._write_func(self.response_marker) def _send_response(self, response): """Send a smart server response down the output stream.""" if (self._finished): raise AssertionError('response already sent') self._finished = True self._write_protocol_version() self._write_success_or_failure_prefix(response) self._write_func(_encode_tuple(response.args)) if response.body is not None: if not isinstance(response.body, str): raise AssertionError('body must be a str') if not (response.body_stream is None): raise AssertionError( 'body_stream and body cannot both be set') bytes = self._encode_bulk_data(response.body) self._write_func(bytes) elif response.body_stream is not None: _send_stream(response.body_stream, self._write_func) def _send_stream(stream, write_func): write_func('chunked\n') _send_chunks(stream, write_func) write_func('END\n') def _send_chunks(stream, write_func): for chunk in stream: if isinstance(chunk, str): bytes = "%x\n%s" % (len(chunk), chunk) write_func(bytes) elif isinstance(chunk, request.FailedSmartServerResponse): write_func('ERR\n') _send_chunks(chunk.args, write_func) return else: raise errors.BzrError( 'Chunks must be str or FailedSmartServerResponse, got %r' % chunk) class _NeedMoreBytes(Exception): """Raise this inside a _StatefulDecoder to stop decoding until more bytes have been received. """ def __init__(self, count=None): """Constructor. :param count: the total number of bytes needed by the current state. May be None if the number of bytes needed is unknown. """ self.count = count class _StatefulDecoder(object): """Base class for writing state machines to decode byte streams. Subclasses should provide a self.state_accept attribute that accepts bytes and, if appropriate, updates self.state_accept to a different function. accept_bytes will call state_accept as often as necessary to make sure the state machine has progressed as far as possible before it returns. See ProtocolThreeDecoder for an example subclass. """ def __init__(self): self.finished_reading = False self._in_buffer_list = [] self._in_buffer_len = 0 self.unused_data = '' self.bytes_left = None self._number_needed_bytes = None def _get_in_buffer(self): if len(self._in_buffer_list) == 1: return self._in_buffer_list[0] in_buffer = ''.join(self._in_buffer_list) if len(in_buffer) != self._in_buffer_len: raise AssertionError( "Length of buffer did not match expected value: %s != %s" % self._in_buffer_len, len(in_buffer)) self._in_buffer_list = [in_buffer] return in_buffer def _get_in_bytes(self, count): """Grab X bytes from the input_buffer. Callers should have already checked that self._in_buffer_len is > count. Note, this does not consume the bytes from the buffer. The caller will still need to call _get_in_buffer() and then _set_in_buffer() if they actually need to consume the bytes. """ # check if we can yield the bytes from just the first entry in our list if len(self._in_buffer_list) == 0: raise AssertionError('Callers must be sure we have buffered bytes' ' before calling _get_in_bytes') if len(self._in_buffer_list[0]) > count: return self._in_buffer_list[0][:count] # We can't yield it from the first buffer, so collapse all buffers, and # yield it from that in_buf = self._get_in_buffer() return in_buf[:count] def _set_in_buffer(self, new_buf): if new_buf is not None: self._in_buffer_list = [new_buf] self._in_buffer_len = len(new_buf) else: self._in_buffer_list = [] self._in_buffer_len = 0 def accept_bytes(self, bytes): """Decode as much of bytes as possible. If 'bytes' contains too much data it will be appended to self.unused_data. finished_reading will be set when no more data is required. Further data will be appended to self.unused_data. """ # accept_bytes is allowed to change the state self._number_needed_bytes = None # lsprof puts a very large amount of time on this specific call for # large readv arrays self._in_buffer_list.append(bytes) self._in_buffer_len += len(bytes) try: # Run the function for the current state. current_state = self.state_accept self.state_accept() while current_state != self.state_accept: # The current state has changed. Run the function for the new # current state, so that it can: # - decode any unconsumed bytes left in a buffer, and # - signal how many more bytes are expected (via raising # _NeedMoreBytes). current_state = self.state_accept self.state_accept() except _NeedMoreBytes, e: self._number_needed_bytes = e.count class ChunkedBodyDecoder(_StatefulDecoder): """Decoder for chunked body data. This is very similar the HTTP's chunked encoding. See the description of streamed body data in `doc/developers/network-protocol.txt` for details. """ def __init__(self): _StatefulDecoder.__init__(self) self.state_accept = self._state_accept_expecting_header self.chunk_in_progress = None self.chunks = collections.deque() self.error = False self.error_in_progress = None def next_read_size(self): # Note: the shortest possible chunk is 2 bytes: '0\n', and the # end-of-body marker is 4 bytes: 'END\n'. if self.state_accept == self._state_accept_reading_chunk: # We're expecting more chunk content. So we're expecting at least # the rest of this chunk plus an END chunk. return self.bytes_left + 4 elif self.state_accept == self._state_accept_expecting_length: if self._in_buffer_len == 0: # We're expecting a chunk length. There's at least two bytes # left: a digit plus '\n'. return 2 else: # We're in the middle of reading a chunk length. So there's at # least one byte left, the '\n' that terminates the length. return 1 elif self.state_accept == self._state_accept_reading_unused: return 1 elif self.state_accept == self._state_accept_expecting_header: return max(0, len('chunked\n') - self._in_buffer_len) else: raise AssertionError("Impossible state: %r" % (self.state_accept,)) def read_next_chunk(self): try: return self.chunks.popleft() except IndexError: return None def _extract_line(self): in_buf = self._get_in_buffer() pos = in_buf.find('\n') if pos == -1: # We haven't read a complete line yet, so request more bytes before # we continue. raise _NeedMoreBytes(1) line = in_buf[:pos] # Trim the prefix (including '\n' delimiter) from the _in_buffer. self._set_in_buffer(in_buf[pos+1:]) return line def _finished(self): self.unused_data = self._get_in_buffer() self._in_buffer_list = [] self._in_buffer_len = 0 self.state_accept = self._state_accept_reading_unused if self.error: error_args = tuple(self.error_in_progress) self.chunks.append(request.FailedSmartServerResponse(error_args)) self.error_in_progress = None self.finished_reading = True def _state_accept_expecting_header(self): prefix = self._extract_line() if prefix == 'chunked': self.state_accept = self._state_accept_expecting_length else: raise errors.SmartProtocolError( 'Bad chunked body header: "%s"' % (prefix,)) def _state_accept_expecting_length(self): prefix = self._extract_line() if prefix == 'ERR': self.error = True self.error_in_progress = [] self._state_accept_expecting_length() return elif prefix == 'END': # We've read the end-of-body marker. # Any further bytes are unused data, including the bytes left in # the _in_buffer. self._finished() return else: self.bytes_left = int(prefix, 16) self.chunk_in_progress = '' self.state_accept = self._state_accept_reading_chunk def _state_accept_reading_chunk(self): in_buf = self._get_in_buffer() in_buffer_len = len(in_buf) self.chunk_in_progress += in_buf[:self.bytes_left] self._set_in_buffer(in_buf[self.bytes_left:]) self.bytes_left -= in_buffer_len if self.bytes_left <= 0: # Finished with chunk self.bytes_left = None if self.error: self.error_in_progress.append(self.chunk_in_progress) else: self.chunks.append(self.chunk_in_progress) self.chunk_in_progress = None self.state_accept = self._state_accept_expecting_length def _state_accept_reading_unused(self): self.unused_data += self._get_in_buffer() self._in_buffer_list = [] class LengthPrefixedBodyDecoder(_StatefulDecoder): """Decodes the length-prefixed bulk data.""" def __init__(self): _StatefulDecoder.__init__(self) self.state_accept = self._state_accept_expecting_length self.state_read = self._state_read_no_data self._body = '' self._trailer_buffer = '' def next_read_size(self): if self.bytes_left is not None: # Ideally we want to read all the remainder of the body and the # trailer in one go. return self.bytes_left + 5 elif self.state_accept == self._state_accept_reading_trailer: # Just the trailer left return 5 - len(self._trailer_buffer) elif self.state_accept == self._state_accept_expecting_length: # There's still at least 6 bytes left ('\n' to end the length, plus # 'done\n'). return 6 else: # Reading excess data. Either way, 1 byte at a time is fine. return 1 def read_pending_data(self): """Return any pending data that has been decoded.""" return self.state_read() def _state_accept_expecting_length(self): in_buf = self._get_in_buffer() pos = in_buf.find('\n') if pos == -1: return self.bytes_left = int(in_buf[:pos]) self._set_in_buffer(in_buf[pos+1:]) self.state_accept = self._state_accept_reading_body self.state_read = self._state_read_body_buffer def _state_accept_reading_body(self): in_buf = self._get_in_buffer() self._body += in_buf self.bytes_left -= len(in_buf) self._set_in_buffer(None) if self.bytes_left <= 0: # Finished with body if self.bytes_left != 0: self._trailer_buffer = self._body[self.bytes_left:] self._body = self._body[:self.bytes_left] self.bytes_left = None self.state_accept = self._state_accept_reading_trailer def _state_accept_reading_trailer(self): self._trailer_buffer += self._get_in_buffer() self._set_in_buffer(None) # TODO: what if the trailer does not match "done\n"? Should this raise # a ProtocolViolation exception? if self._trailer_buffer.startswith('done\n'): self.unused_data = self._trailer_buffer[len('done\n'):] self.state_accept = self._state_accept_reading_unused self.finished_reading = True def _state_accept_reading_unused(self): self.unused_data += self._get_in_buffer() self._set_in_buffer(None) def _state_read_no_data(self): return '' def _state_read_body_buffer(self): result = self._body self._body = '' return result class SmartClientRequestProtocolOne(SmartProtocolBase, Requester, message.ResponseHandler): """The client-side protocol for smart version 1.""" def __init__(self, request): """Construct a SmartClientRequestProtocolOne. :param request: A SmartClientMediumRequest to serialise onto and deserialise from. """ self._request = request self._body_buffer = None self._request_start_time = None self._last_verb = None self._headers = None def set_headers(self, headers): self._headers = dict(headers) def call(self, *args): if 'hpss' in debug.debug_flags: mutter('hpss call: %s', repr(args)[1:-1]) if getattr(self._request._medium, 'base', None) is not None: mutter(' (to %s)', self._request._medium.base) self._request_start_time = osutils.timer_func() self._write_args(args) self._request.finished_writing() self._last_verb = args[0] def call_with_body_bytes(self, args, body): """Make a remote call of args with body bytes 'body'. After calling this, call read_response_tuple to find the result out. """ if 'hpss' in debug.debug_flags: mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20]) if getattr(self._request._medium, '_path', None) is not None: mutter(' (to %s)', self._request._medium._path) mutter(' %d bytes', len(body)) self._request_start_time = osutils.timer_func() if 'hpssdetail' in debug.debug_flags: mutter('hpss body content: %s', body) self._write_args(args) bytes = self._encode_bulk_data(body) self._request.accept_bytes(bytes) self._request.finished_writing() self._last_verb = args[0] def call_with_body_readv_array(self, args, body): """Make a remote call with a readv array. The body is encoded with one line per readv offset pair. The numbers in each pair are separated by a comma, and no trailing \\n is emitted. """ if 'hpss' in debug.debug_flags: mutter('hpss call w/readv: %s', repr(args)[1:-1]) if getattr(self._request._medium, '_path', None) is not None: mutter(' (to %s)', self._request._medium._path) self._request_start_time = osutils.timer_func() self._write_args(args) readv_bytes = self._serialise_offsets(body) bytes = self._encode_bulk_data(readv_bytes) self._request.accept_bytes(bytes) self._request.finished_writing() if 'hpss' in debug.debug_flags: mutter(' %d bytes in readv request', len(readv_bytes)) self._last_verb = args[0] def call_with_body_stream(self, args, stream): # Protocols v1 and v2 don't support body streams. So it's safe to # assume that a v1/v2 server doesn't support whatever method we're # trying to call with a body stream. self._request.finished_writing() self._request.finished_reading() raise errors.UnknownSmartMethod(args[0]) def cancel_read_body(self): """After expecting a body, a response code may indicate one otherwise. This method lets the domain client inform the protocol that no body will be transmitted. This is a terminal method: after calling it the protocol is not able to be used further. """ self._request.finished_reading() def _read_response_tuple(self): result = self._recv_tuple() if 'hpss' in debug.debug_flags: if self._request_start_time is not None: mutter(' result: %6.3fs %s', osutils.timer_func() - self._request_start_time, repr(result)[1:-1]) self._request_start_time = None else: mutter(' result: %s', repr(result)[1:-1]) return result def read_response_tuple(self, expect_body=False): """Read a response tuple from the wire. This should only be called once. """ result = self._read_response_tuple() self._response_is_unknown_method(result) self._raise_args_if_error(result) if not expect_body: self._request.finished_reading() return result def _raise_args_if_error(self, result_tuple): # Later protocol versions have an explicit flag in the protocol to say # if an error response is "failed" or not. In version 1 we don't have # that luxury. So here is a complete list of errors that can be # returned in response to existing version 1 smart requests. Responses # starting with these codes are always "failed" responses. v1_error_codes = [ 'norepository', 'NoSuchFile', 'FileExists', 'DirectoryNotEmpty', 'ShortReadvError', 'UnicodeEncodeError', 'UnicodeDecodeError', 'ReadOnlyError', 'nobranch', 'NoSuchRevision', 'nosuchrevision', 'LockContention', 'UnlockableTransport', 'LockFailed', 'TokenMismatch', 'ReadError', 'PermissionDenied', ] if result_tuple[0] in v1_error_codes: self._request.finished_reading() raise errors.ErrorFromSmartServer(result_tuple) def _response_is_unknown_method(self, result_tuple): """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn method' response to the request. :param response: The response from a smart client call_expecting_body call. :param verb: The verb used in that call. :raises: UnexpectedSmartServerResponse """ if (result_tuple == ('error', "Generic bzr smart protocol error: " "bad request '%s'" % self._last_verb) or result_tuple == ('error', "Generic bzr smart protocol error: " "bad request u'%s'" % self._last_verb)): # The response will have no body, so we've finished reading. self._request.finished_reading() raise errors.UnknownSmartMethod(self._last_verb) def read_body_bytes(self, count=-1): """Read bytes from the body, decoding into a byte stream. We read all bytes at once to ensure we've checked the trailer for errors, and then feed the buffer back as read_body_bytes is called. """ if self._body_buffer is not None: return self._body_buffer.read(count) _body_decoder = LengthPrefixedBodyDecoder() while not _body_decoder.finished_reading: bytes = self._request.read_bytes(_body_decoder.next_read_size()) if bytes == '': # end of file encountered reading from server raise errors.ConnectionReset( "Connection lost while reading response body.") _body_decoder.accept_bytes(bytes) self._request.finished_reading() self._body_buffer = StringIO(_body_decoder.read_pending_data()) # XXX: TODO check the trailer result. if 'hpss' in debug.debug_flags: mutter(' %d body bytes read', len(self._body_buffer.getvalue())) return self._body_buffer.read(count) def _recv_tuple(self): """Receive a tuple from the medium request.""" return _decode_tuple(self._request.read_line()) def query_version(self): """Return protocol version number of the server.""" self.call('hello') resp = self.read_response_tuple() if resp == ('ok', '1'): return 1 elif resp == ('ok', '2'): return 2 else: raise errors.SmartProtocolError("bad response %r" % (resp,)) def _write_args(self, args): self._write_protocol_version() bytes = _encode_tuple(args) self._request.accept_bytes(bytes) def _write_protocol_version(self): """Write any prefixes this protocol requires. Version one doesn't send protocol versions. """ class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne): """Version two of the client side of the smart protocol. This prefixes the request with the value of REQUEST_VERSION_TWO. """ response_marker = RESPONSE_VERSION_TWO request_marker = REQUEST_VERSION_TWO def read_response_tuple(self, expect_body=False): """Read a response tuple from the wire. This should only be called once. """ version = self._request.read_line() if version != self.response_marker: self._request.finished_reading() raise errors.UnexpectedProtocolVersionMarker(version) response_status = self._request.read_line() result = SmartClientRequestProtocolOne._read_response_tuple(self) self._response_is_unknown_method(result) if response_status == 'success\n': self.response_status = True if not expect_body: self._request.finished_reading() return result elif response_status == 'failed\n': self.response_status = False self._request.finished_reading() raise errors.ErrorFromSmartServer(result) else: raise errors.SmartProtocolError( 'bad protocol status %r' % response_status) def _write_protocol_version(self): """Write any prefixes this protocol requires. Version two sends the value of REQUEST_VERSION_TWO. """ self._request.accept_bytes(self.request_marker) def read_streamed_body(self): """Read bytes from the body, decoding into a byte stream. """ # Read no more than 64k at a time so that we don't risk error 10055 (no # buffer space available) on Windows. _body_decoder = ChunkedBodyDecoder() while not _body_decoder.finished_reading: bytes = self._request.read_bytes(_body_decoder.next_read_size()) if bytes == '': # end of file encountered reading from server raise errors.ConnectionReset( "Connection lost while reading streamed body.") _body_decoder.accept_bytes(bytes) for body_bytes in iter(_body_decoder.read_next_chunk, None): if 'hpss' in debug.debug_flags and type(body_bytes) is str: mutter(' %d byte chunk read', len(body_bytes)) yield body_bytes self._request.finished_reading() def build_server_protocol_three(backing_transport, write_func, root_client_path, jail_root=None): request_handler = request.SmartServerRequestHandler( backing_transport, commands=request.request_handlers, root_client_path=root_client_path, jail_root=jail_root) responder = ProtocolThreeResponder(write_func) message_handler = message.ConventionalRequestHandler(request_handler, responder) return ProtocolThreeDecoder(message_handler) class ProtocolThreeDecoder(_StatefulDecoder): response_marker = RESPONSE_VERSION_THREE request_marker = REQUEST_VERSION_THREE def __init__(self, message_handler, expect_version_marker=False): _StatefulDecoder.__init__(self) self._has_dispatched = False # Initial state if expect_version_marker: self.state_accept = self._state_accept_expecting_protocol_version # We're expecting at least the protocol version marker + some # headers. self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4 else: self.state_accept = self._state_accept_expecting_headers self._number_needed_bytes = 4 self.decoding_failed = False self.request_handler = self.message_handler = message_handler def accept_bytes(self, bytes): self._number_needed_bytes = None try: _StatefulDecoder.accept_bytes(self, bytes) except KeyboardInterrupt: raise except errors.SmartMessageHandlerError, exception: # We do *not* set self.decoding_failed here. The message handler # has raised an error, but the decoder is still able to parse bytes # and determine when this message ends. if not isinstance(exception.exc_value, errors.UnknownSmartMethod): log_exception_quietly() self.message_handler.protocol_error(exception.exc_value) # The state machine is ready to continue decoding, but the # exception has interrupted the loop that runs the state machine. # So we call accept_bytes again to restart it. self.accept_bytes('') except Exception, exception: # The decoder itself has raised an exception. We cannot continue # decoding. self.decoding_failed = True if isinstance(exception, errors.UnexpectedProtocolVersionMarker): # This happens during normal operation when the client tries a # protocol version the server doesn't understand, so no need to # log a traceback every time. # Note that this can only happen when # expect_version_marker=True, which is only the case on the # client side. pass else: log_exception_quietly() self.message_handler.protocol_error(exception) def _extract_length_prefixed_bytes(self): if self._in_buffer_len < 4: # A length prefix by itself is 4 bytes, and we don't even have that # many yet. raise _NeedMoreBytes(4) (length,) = struct.unpack('!L', self._get_in_bytes(4)) end_of_bytes = 4 + length if self._in_buffer_len < end_of_bytes: # We haven't yet read as many bytes as the length-prefix says there # are. raise _NeedMoreBytes(end_of_bytes) # Extract the bytes from the buffer. in_buf = self._get_in_buffer() bytes = in_buf[4:end_of_bytes] self._set_in_buffer(in_buf[end_of_bytes:]) return bytes def _extract_prefixed_bencoded_data(self): prefixed_bytes = self._extract_length_prefixed_bytes() try: decoded = bdecode_as_tuple(prefixed_bytes) except ValueError: raise errors.SmartProtocolError( 'Bytes %r not bencoded' % (prefixed_bytes,)) return decoded def _extract_single_byte(self): if self._in_buffer_len == 0: # The buffer is empty raise _NeedMoreBytes(1) in_buf = self._get_in_buffer() one_byte = in_buf[0] self._set_in_buffer(in_buf[1:]) return one_byte def _state_accept_expecting_protocol_version(self): needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len in_buf = self._get_in_buffer() if needed_bytes > 0: # We don't have enough bytes to check if the protocol version # marker is right. But we can check if it is already wrong by # checking that the start of MESSAGE_VERSION_THREE matches what # we've read so far. # [In fact, if the remote end isn't bzr we might never receive # len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far # are wrong then we should just raise immediately rather than # stall.] if not MESSAGE_VERSION_THREE.startswith(in_buf): # We have enough bytes to know the protocol version is wrong raise errors.UnexpectedProtocolVersionMarker(in_buf) raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE)) if not in_buf.startswith(MESSAGE_VERSION_THREE): raise errors.UnexpectedProtocolVersionMarker(in_buf) self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):]) self.state_accept = self._state_accept_expecting_headers def _state_accept_expecting_headers(self): decoded = self._extract_prefixed_bencoded_data() if type(decoded) is not dict: raise errors.SmartProtocolError( 'Header object %r is not a dict' % (decoded,)) self.state_accept = self._state_accept_expecting_message_part try: self.message_handler.headers_received(decoded) except: raise errors.SmartMessageHandlerError(sys.exc_info()) def _state_accept_expecting_message_part(self): message_part_kind = self._extract_single_byte() if message_part_kind == 'o': self.state_accept = self._state_accept_expecting_one_byte elif message_part_kind == 's': self.state_accept = self._state_accept_expecting_structure elif message_part_kind == 'b': self.state_accept = self._state_accept_expecting_bytes elif message_part_kind == 'e': self.done() else: raise errors.SmartProtocolError( 'Bad message kind byte: %r' % (message_part_kind,)) def _state_accept_expecting_one_byte(self): byte = self._extract_single_byte() self.state_accept = self._state_accept_expecting_message_part try: self.message_handler.byte_part_received(byte) except: raise errors.SmartMessageHandlerError(sys.exc_info()) def _state_accept_expecting_bytes(self): # XXX: this should not buffer whole message part, but instead deliver # the bytes as they arrive. prefixed_bytes = self._extract_length_prefixed_bytes() self.state_accept = self._state_accept_expecting_message_part try: self.message_handler.bytes_part_received(prefixed_bytes) except: raise errors.SmartMessageHandlerError(sys.exc_info()) def _state_accept_expecting_structure(self): structure = self._extract_prefixed_bencoded_data() self.state_accept = self._state_accept_expecting_message_part try: self.message_handler.structure_part_received(structure) except: raise errors.SmartMessageHandlerError(sys.exc_info()) def done(self): self.unused_data = self._get_in_buffer() self._set_in_buffer(None) self.state_accept = self._state_accept_reading_unused try: self.message_handler.end_received() except: raise errors.SmartMessageHandlerError(sys.exc_info()) def _state_accept_reading_unused(self): self.unused_data += self._get_in_buffer() self._set_in_buffer(None) def next_read_size(self): if self.state_accept == self._state_accept_reading_unused: return 0 elif self.decoding_failed: # An exception occured while processing this message, probably from # self.message_handler. We're not sure that this state machine is # in a consistent state, so just signal that we're done (i.e. give # up). return 0 else: if self._number_needed_bytes is not None: return self._number_needed_bytes - self._in_buffer_len else: raise AssertionError("don't know how many bytes are expected!") class _ProtocolThreeEncoder(object): response_marker = request_marker = MESSAGE_VERSION_THREE BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing def __init__(self, write_func): self._buf = [] self._buf_len = 0 self._real_write_func = write_func def _write_func(self, bytes): # TODO: Another possibility would be to turn this into an async model. # Where we let another thread know that we have some bytes if # they want it, but we don't actually block for it # Note that osutils.send_all always sends 64kB chunks anyway, so # we might just push out smaller bits at a time? self._buf.append(bytes) self._buf_len += len(bytes) if self._buf_len > self.BUFFER_SIZE: self.flush() def flush(self): if self._buf: self._real_write_func(''.join(self._buf)) del self._buf[:] self._buf_len = 0 def _serialise_offsets(self, offsets): """Serialise a readv offset list.""" txt = [] for start, length in offsets: txt.append('%d,%d' % (start, length)) return '\n'.join(txt) def _write_protocol_version(self): self._write_func(MESSAGE_VERSION_THREE) def _write_prefixed_bencode(self, structure): bytes = bencode(structure) self._write_func(struct.pack('!L', len(bytes))) self._write_func(bytes) def _write_headers(self, headers): self._write_prefixed_bencode(headers) def _write_structure(self, args): self._write_func('s') utf8_args = [] for arg in args: if type(arg) is unicode: utf8_args.append(arg.encode('utf8')) else: utf8_args.append(arg) self._write_prefixed_bencode(utf8_args) def _write_end(self): self._write_func('e') self.flush() def _write_prefixed_body(self, bytes): self._write_func('b') self._write_func(struct.pack('!L', len(bytes))) self._write_func(bytes) def _write_chunked_body_start(self): self._write_func('oC') def _write_error_status(self): self._write_func('oE') def _write_success_status(self): self._write_func('oS') class ProtocolThreeResponder(_ProtocolThreeEncoder): def __init__(self, write_func): _ProtocolThreeEncoder.__init__(self, write_func) self.response_sent = False self._headers = {'Software version': bzrlib.__version__} if 'hpss' in debug.debug_flags: self._thread_id = thread.get_ident() self._response_start_time = None def _trace(self, action, message, extra_bytes=None, include_time=False): if self._response_start_time is None: self._response_start_time = osutils.timer_func() if include_time: t = '%5.3fs ' % (time.clock() - self._response_start_time) else: t = '' if extra_bytes is None: extra = '' else: extra = ' ' + repr(extra_bytes[:40]) if len(extra) > 33: extra = extra[:29] + extra[-1] + '...' mutter('%12s: [%s] %s%s%s' % (action, self._thread_id, t, message, extra)) def send_error(self, exception): if self.response_sent: raise AssertionError( "send_error(%s) called, but response already sent." % (exception,)) if isinstance(exception, errors.UnknownSmartMethod): failure = request.FailedSmartServerResponse( ('UnknownMethod', exception.verb)) self.send_response(failure) return if 'hpss' in debug.debug_flags: self._trace('error', str(exception)) self.response_sent = True self._write_protocol_version() self._write_headers(self._headers) self._write_error_status() self._write_structure(('error', str(exception))) self._write_end() def send_response(self, response): if self.response_sent: raise AssertionError( "send_response(%r) called, but response already sent." % (response,)) self.response_sent = True self._write_protocol_version() self._write_headers(self._headers) if response.is_successful(): self._write_success_status() else: self._write_error_status() if 'hpss' in debug.debug_flags: self._trace('response', repr(response.args)) self._write_structure(response.args) if response.body is not None: self._write_prefixed_body(response.body) if 'hpss' in debug.debug_flags: self._trace('body', '%d bytes' % (len(response.body),), response.body, include_time=True) elif response.body_stream is not None: count = num_bytes = 0 first_chunk = None for exc_info, chunk in _iter_with_errors(response.body_stream): count += 1 if exc_info is not None: self._write_error_status() error_struct = request._translate_error(exc_info[1]) self._write_structure(error_struct) break else: if isinstance(chunk, request.FailedSmartServerResponse): self._write_error_status() self._write_structure(chunk.args) break num_bytes += len(chunk) if first_chunk is None: first_chunk = chunk self._write_prefixed_body(chunk) self.flush() if 'hpssdetail' in debug.debug_flags: # Not worth timing separately, as _write_func is # actually buffered self._trace('body chunk', '%d bytes' % (len(chunk),), chunk, suppress_time=True) if 'hpss' in debug.debug_flags: self._trace('body stream', '%d bytes %d chunks' % (num_bytes, count), first_chunk) self._write_end() if 'hpss' in debug.debug_flags: self._trace('response end', '', include_time=True) def _iter_with_errors(iterable): """Handle errors from iterable.next(). Use like:: for exc_info, value in _iter_with_errors(iterable): ... This is a safer alternative to:: try: for value in iterable: ... except: ... Because the latter will catch errors from the for-loop body, not just iterable.next() If an error occurs, exc_info will be a exc_info tuple, and the generator will terminate. Otherwise exc_info will be None, and value will be the value from iterable.next(). Note that KeyboardInterrupt and SystemExit will not be itercepted. """ iterator = iter(iterable) while True: try: yield None, iterator.next() except StopIteration: return except (KeyboardInterrupt, SystemExit): raise except Exception: mutter('_iter_with_errors caught error') log_exception_quietly() yield sys.exc_info(), None return class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester): def __init__(self, medium_request): _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes) self._medium_request = medium_request self._headers = {} self.body_stream_started = None def set_headers(self, headers): self._headers = headers.copy() def call(self, *args): if 'hpss' in debug.debug_flags: mutter('hpss call: %s', repr(args)[1:-1]) base = getattr(self._medium_request._medium, 'base', None) if base is not None: mutter(' (to %s)', base) self._request_start_time = osutils.timer_func() self._write_protocol_version() self._write_headers(self._headers) self._write_structure(args) self._write_end() self._medium_request.finished_writing() def call_with_body_bytes(self, args, body): """Make a remote call of args with body bytes 'body'. After calling this, call read_response_tuple to find the result out. """ if 'hpss' in debug.debug_flags: mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20]) path = getattr(self._medium_request._medium, '_path', None) if path is not None: mutter(' (to %s)', path) mutter(' %d bytes', len(body)) self._request_start_time = osutils.timer_func() self._write_protocol_version() self._write_headers(self._headers) self._write_structure(args) self._write_prefixed_body(body) self._write_end() self._medium_request.finished_writing() def call_with_body_readv_array(self, args, body): """Make a remote call with a readv array. The body is encoded with one line per readv offset pair. The numbers in each pair are separated by a comma, and no trailing \\n is emitted. """ if 'hpss' in debug.debug_flags: mutter('hpss call w/readv: %s', repr(args)[1:-1]) path = getattr(self._medium_request._medium, '_path', None) if path is not None: mutter(' (to %s)', path) self._request_start_time = osutils.timer_func() self._write_protocol_version() self._write_headers(self._headers) self._write_structure(args) readv_bytes = self._serialise_offsets(body) if 'hpss' in debug.debug_flags: mutter(' %d bytes in readv request', len(readv_bytes)) self._write_prefixed_body(readv_bytes) self._write_end() self._medium_request.finished_writing() def call_with_body_stream(self, args, stream): if 'hpss' in debug.debug_flags: mutter('hpss call w/body stream: %r', args) path = getattr(self._medium_request._medium, '_path', None) if path is not None: mutter(' (to %s)', path) self._request_start_time = osutils.timer_func() self.body_stream_started = False self._write_protocol_version() self._write_headers(self._headers) self._write_structure(args) # TODO: notice if the server has sent an early error reply before we # have finished sending the stream. We would notice at the end # anyway, but if the medium can deliver it early then it's good # to short-circuit the whole request... # Provoke any ConnectionReset failures before we start the body stream. self.flush() self.body_stream_started = True for exc_info, part in _iter_with_errors(stream): if exc_info is not None: # Iterating the stream failed. Cleanly abort the request. self._write_error_status() # Currently the client unconditionally sends ('error',) as the # error args. self._write_structure(('error',)) self._write_end() self._medium_request.finished_writing() raise exc_info[0], exc_info[1], exc_info[2] else: self._write_prefixed_body(part) self.flush() self._write_end() self._medium_request.finished_writing()