From 079c963ddd4ebfd13a905829bc341dce85d94fbd Mon Sep 17 00:00:00 2001 From: Daniel Holth Date: Sun, 17 Nov 2019 22:45:52 -0500 Subject: use _ffi.from_buffer() to support bytearray (#852) * use _ffi.from_buffer(buf) in send, to support bytearray * add bytearray test * update CHANGELOG.rst * move from_buffer before 'buffer too long' check * context-managed from_buffer + black * don't shadow buf in send() * test return count for sendall * test sending an array * fix test * also use from_buffer in bio_write * de-format _util.py * formatting * add simple bio_write tests * wrap line --- .gitignore | 1 + CHANGELOG.rst | 3 ++- src/OpenSSL/SSL.py | 72 +++++++++++++++++++++++++--------------------------- src/OpenSSL/_util.py | 14 ++++++++++ tests/test_ssl.py | 42 ++++++++++++++++++++++++++++-- 5 files changed, 92 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index 2040b80..e3057ad 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ doc/_build/ examples/simple/*.cert examples/simple/*.pkey .cache +.mypy_cache \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2bf74f5..e0c034d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,7 +28,8 @@ Deprecations: Changes: ^^^^^^^^ -*none* +- Support ``bytearray`` in ``SSL.Connection.send()`` by using cffi's from_buffer. + `#852 `_ ---- diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 5521151..adcfd8f 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -15,6 +15,7 @@ from OpenSSL._util import ( UNSPECIFIED as _UNSPECIFIED, exception_from_error_queue as _exception_from_error_queue, ffi as _ffi, + from_buffer as _from_buffer, lib as _lib, make_assert as _make_assert, native as _native, @@ -1730,18 +1731,18 @@ class Connection(object): # Backward compatibility buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("data must be a memoryview, buffer or byte string") - if len(buf) > 2147483647: - raise ValueError("Cannot send more than 2**31-1 bytes at once.") + with _from_buffer(buf) as data: + # check len(buf) instead of len(data) for testability + if len(buf) > 2147483647: + raise ValueError( + "Cannot send more than 2**31-1 bytes at once." + ) + + result = _lib.SSL_write(self._ssl, data, len(data)) + self._raise_ssl_error(self._ssl, result) + + return result - result = _lib.SSL_write(self._ssl, buf, len(buf)) - self._raise_ssl_error(self._ssl, result) - return result write = send def sendall(self, buf, flags=0): @@ -1757,28 +1758,24 @@ class Connection(object): """ buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("buf must be a memoryview, buffer or byte string") - - left_to_send = len(buf) - total_sent = 0 - data = _ffi.new("char[]", buf) - - while left_to_send: - # SSL_write's num arg is an int, - # so we cannot send more than 2**31-1 bytes at once. - result = _lib.SSL_write( - self._ssl, - data + total_sent, - min(left_to_send, 2147483647) - ) - self._raise_ssl_error(self._ssl, result) - total_sent += result - left_to_send -= result + with _from_buffer(buf) as data: + + left_to_send = len(buf) + total_sent = 0 + + while left_to_send: + # SSL_write's num arg is an int, + # so we cannot send more than 2**31-1 bytes at once. + result = _lib.SSL_write( + self._ssl, + data + total_sent, + min(left_to_send, 2147483647) + ) + self._raise_ssl_error(self._ssl, result) + total_sent += result + left_to_send -= result + + return total_sent def recv(self, bufsiz, flags=None): """ @@ -1892,10 +1889,11 @@ class Connection(object): if self._into_ssl is None: raise TypeError("Connection sock was not None") - result = _lib.BIO_write(self._into_ssl, buf, len(buf)) - if result <= 0: - self._handle_bio_errors(self._into_ssl, result) - return result + with _from_buffer(buf) as data: + result = _lib.BIO_write(self._into_ssl, data, len(data)) + if result <= 0: + self._handle_bio_errors(self._into_ssl, result) + return result def renegotiate(self): """ diff --git a/src/OpenSSL/_util.py b/src/OpenSSL/_util.py index cdcacc8..d8e3f66 100644 --- a/src/OpenSSL/_util.py +++ b/src/OpenSSL/_util.py @@ -145,3 +145,17 @@ def text_to_bytes_and_warn(label, obj): ) return obj.encode('utf-8') return obj + + +try: + # newer versions of cffi free the buffer deterministically + with ffi.from_buffer(b""): + pass + from_buffer = ffi.from_buffer +except AttributeError: + # cffi < 0.12 frees the buffer with refcounting gc + from contextlib import contextmanager + + @contextmanager + def from_buffer(*args): + yield ffi.from_buffer(*args) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 6b9422c..16767e9 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -2087,6 +2087,29 @@ class TestConnection(object): with pytest.raises(TypeError): Connection(bad_context) + @pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]]) + def test_bio_write_wrong_args(self, bad_bio): + """ + `Connection.bio_write` raises `TypeError` if called with a non-bytes + (or text) argument. + """ + context = Context(TLSv1_METHOD) + connection = Connection(context, None) + with pytest.raises(TypeError): + connection.bio_write(bad_bio) + + def test_bio_write(self): + """ + `Connection.bio_write` does not raise if called with bytes or + bytearray, warns if called with text. + """ + context = Context(TLSv1_METHOD) + connection = Connection(context, None) + connection.bio_write(b'xy') + connection.bio_write(bytearray(b'za')) + with pytest.warns(DeprecationWarning): + connection.bio_write(u'deprecated') + def test_get_context(self): """ `Connection.get_context` returns the `Context` instance used to @@ -2807,6 +2830,8 @@ class TestConnectionSend(object): connection = Connection(Context(TLSv1_METHOD), None) with pytest.raises(TypeError): connection.send(object()) + with pytest.raises(TypeError): + connection.send([1, 2, 3]) def test_short_bytes(self): """ @@ -2845,6 +2870,16 @@ class TestConnectionSend(object): assert count == 2 assert client.recv(2) == b'xy' + def test_short_bytearray(self): + """ + When passed a short bytearray, `Connection.send` transmits all of + it and returns the number of bytes sent. + """ + server, client = loopback() + count = server.send(bytearray(b'xy')) + assert count == 2 + assert client.recv(2) == b'xy' + @skip_if_py3 def test_short_buffer(self): """ @@ -3015,6 +3050,8 @@ class TestConnectionSendall(object): connection = Connection(Context(TLSv1_METHOD), None) with pytest.raises(TypeError): connection.sendall(object()) + with pytest.raises(TypeError): + connection.sendall([1, 2, 3]) def test_short(self): """ @@ -3056,8 +3093,9 @@ class TestConnectionSendall(object): `Connection.sendall` transmits all of them. """ server, client = loopback() - server.sendall(buffer(b'x')) - assert client.recv(1) == b'x' + count = server.sendall(buffer(b'xy')) + assert count == 2 + assert client.recv(2) == b'xy' def test_long(self): """ -- cgit v1.2.1