summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntonio Ojea <6450081+aojea@users.noreply.github.com>2019-01-25 19:06:42 +0100
committerAsif Saif Uddin <auvipy@gmail.com>2019-01-26 00:06:42 +0600
commite45ea3ece36ee5e58eaa3421d49b77ad86fdb5e5 (patch)
treebaee29191d5871887eed0376ceefa8594e513d6d
parent734305d58cdf025bb939540f2b7bbd2a569a37f5 (diff)
downloadpy-amqp-e45ea3ece36ee5e58eaa3421d49b77ad86fdb5e5.tar.gz
read_frame python3 compatible for large payloads (#248)
read_frame is using str.join to concatenate the payload if the received frame is bigger than SIGNED_INT_MAX. That's fine with python2, however in python3 documentation is stated str.joinReturn a string which is the concatenation of the strings in the iterable iterable. A TypeError will be raised if there are any non-string values in iterable, including bytes objects. The separator between elements is the string providing this method. So we have to use the byte object join() method Signed-off-by: aojeagarcia <aojeagarcia@suse.com>
-rw-r--r--amqp/transport.py2
-rw-r--r--t/unit/test_transport.py11
2 files changed, 12 insertions, 1 deletions
diff --git a/amqp/transport.py b/amqp/transport.py
index 4363fb7..80b058e 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -255,7 +255,7 @@ class _AbstractTransport(object):
if size > SIGNED_INT_MAX:
part1 = read(SIGNED_INT_MAX)
part2 = read(size - SIGNED_INT_MAX)
- payload = ''.join([part1, part2])
+ payload = b''.join([part1, part2])
else:
payload = read(size)
read_frame_buffer += payload
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index 6a345b4..85bd301 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -11,6 +11,8 @@ from amqp.exceptions import UnexpectedFrame
from amqp.platform import pack
from amqp.transport import _AbstractTransport
+SIGNED_INT_MAX = 0x7FFFFFFF
+
class DummyException(Exception):
pass
@@ -319,6 +321,15 @@ class test_AbstractTransport:
with pytest.raises(UnexpectedFrame):
self.t.read_frame()
+ def test_read_frame__long(self):
+ self.t._read = Mock()
+ self.t._read.side_effect = [pack('>BHI', 1, 1, SIGNED_INT_MAX + 16),
+ b'read1', b'read2', b'\xce']
+ frame_type, channel, payload = self.t.read_frame()
+ assert frame_type == 1
+ assert channel == 1
+ assert payload == b'read1read2'
+
def transport_read_EOF(self):
for host, ssl in (('localhost:5672', False),
('localhost:5671', True),):