From 079c963ddd4ebfd13a905829bc341dce85d94fbd Mon Sep 17 00:00:00 2001 From: Daniel Holth Date: Sun, 17 Nov 2019 22:45:52 -0500 Subject: use _ffi.from_buffer() to support bytearray (#852) * use _ffi.from_buffer(buf) in send, to support bytearray * add bytearray test * update CHANGELOG.rst * move from_buffer before 'buffer too long' check * context-managed from_buffer + black * don't shadow buf in send() * test return count for sendall * test sending an array * fix test * also use from_buffer in bio_write * de-format _util.py * formatting * add simple bio_write tests * wrap line --- src/OpenSSL/SSL.py | 72 ++++++++++++++++++++++++++---------------------------- 1 file changed, 35 insertions(+), 37 deletions(-) (limited to 'src/OpenSSL/SSL.py') diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 5521151..adcfd8f 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -15,6 +15,7 @@ from OpenSSL._util import ( UNSPECIFIED as _UNSPECIFIED, exception_from_error_queue as _exception_from_error_queue, ffi as _ffi, + from_buffer as _from_buffer, lib as _lib, make_assert as _make_assert, native as _native, @@ -1730,18 +1731,18 @@ class Connection(object): # Backward compatibility buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("data must be a memoryview, buffer or byte string") - if len(buf) > 2147483647: - raise ValueError("Cannot send more than 2**31-1 bytes at once.") + with _from_buffer(buf) as data: + # check len(buf) instead of len(data) for testability + if len(buf) > 2147483647: + raise ValueError( + "Cannot send more than 2**31-1 bytes at once." + ) + + result = _lib.SSL_write(self._ssl, data, len(data)) + self._raise_ssl_error(self._ssl, result) + + return result - result = _lib.SSL_write(self._ssl, buf, len(buf)) - self._raise_ssl_error(self._ssl, result) - return result write = send def sendall(self, buf, flags=0): @@ -1757,28 +1758,24 @@ class Connection(object): """ buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("buf must be a memoryview, buffer or byte string") - - left_to_send = len(buf) - total_sent = 0 - data = _ffi.new("char[]", buf) - - while left_to_send: - # SSL_write's num arg is an int, - # so we cannot send more than 2**31-1 bytes at once. - result = _lib.SSL_write( - self._ssl, - data + total_sent, - min(left_to_send, 2147483647) - ) - self._raise_ssl_error(self._ssl, result) - total_sent += result - left_to_send -= result + with _from_buffer(buf) as data: + + left_to_send = len(buf) + total_sent = 0 + + while left_to_send: + # SSL_write's num arg is an int, + # so we cannot send more than 2**31-1 bytes at once. + result = _lib.SSL_write( + self._ssl, + data + total_sent, + min(left_to_send, 2147483647) + ) + self._raise_ssl_error(self._ssl, result) + total_sent += result + left_to_send -= result + + return total_sent def recv(self, bufsiz, flags=None): """ @@ -1892,10 +1889,11 @@ class Connection(object): if self._into_ssl is None: raise TypeError("Connection sock was not None") - result = _lib.BIO_write(self._into_ssl, buf, len(buf)) - if result <= 0: - self._handle_bio_errors(self._into_ssl, result) - return result + with _from_buffer(buf) as data: + result = _lib.BIO_write(self._into_ssl, data, len(data)) + if result <= 0: + self._handle_bio_errors(self._into_ssl, result) + return result def renegotiate(self): """ -- cgit v1.2.1