author    Mark Roberts <wizzat@gmail.com>  2014-11-26 10:30:02 -0800
committer Mark Roberts <wizzat@gmail.com>  2014-11-26 10:30:02 -0800
commit    3689529a6127d55fd0b580ddd621ed3ee5abcb6a (patch)
tree      02d7a7fc69748552d2e478b344e6fb8d4cc22a42
parent    52ec0782b751d03a6cf293d97922fc0f5bd4aeb1 (diff)
parent    a9e77bdaa2490d4c8c343d18f32d7b256c50ddd7 (diff)
Merge pull request #268 from se7entyse7en/keyed_message
Pass key to message sent by `KeyedProducer`
-rw-r--r--  kafka/producer/base.py            | 18
-rw-r--r--  kafka/producer/keyed.py           |  2
-rw-r--r--  kafka/protocol.py                 |  8
-rw-r--r--  test/test_producer_integration.py | 42
-rw-r--r--  test/testutil.py                  |  4

5 files changed, 47 insertions(+), 27 deletions(-)
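
For context, a minimal sketch of what this merge enables, assuming a broker
reachable at localhost:9092 and the era's top-level imports (the address and
topic name are placeholders):

from kafka import KafkaClient, KeyedProducer, HashedPartitioner

client = KafkaClient("localhost:9092")
producer = KeyedProducer(client, partitioner=HashedPartitioner)

# The key must be bytes; it now both selects the partition and is stored
# in the produced message, instead of being dropped after routing.
producer.send("my-topic", b"key1", b"payload one")
producer.send("my-topic", b"key2", b"payload two")

producer.stop()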
diff --git a/kafka/producer/base.py b/kafka/producer/base.py
index 6c91364..6e19b92 100644
--- a/kafka/producer/base.py
+++ b/kafka/producer/base.py
@@ -49,7 +49,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
# timeout is reached
while count > 0 and timeout >= 0:
try:
- topic_partition, msg = queue.get(timeout=timeout)
+ topic_partition, msg, key = queue.get(timeout=timeout)
except Empty:
break
@@ -67,7 +67,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
# Send collected requests upstream
reqs = []
for topic_partition, msg in msgset.items():
- messages = create_message_set(msg, codec)
+ messages = create_message_set(msg, codec, key)
req = ProduceRequest(topic_partition.topic,
topic_partition.partition,
messages)
@@ -169,6 +169,10 @@ class Producer(object):
All messages produced via this method will set the message 'key' to Null
"""
+ return self._send_messages(topic, partition, *msg)
+
+ def _send_messages(self, topic, partition, *msg, **kwargs):
+ key = kwargs.pop('key', None)
# Guarantee that msg is actually a list or tuple (should always be true)
if not isinstance(msg, (list, tuple)):
@@ -178,12 +182,16 @@ class Producer(object):
if any(not isinstance(m, six.binary_type) for m in msg):
raise TypeError("all produce message payloads must be type bytes")
+ # Raise TypeError if the key is not encoded as bytes
+ if key is not None and not isinstance(key, six.binary_type):
+ raise TypeError("the key must be type bytes")
+
if self.async:
for m in msg:
- self.queue.put((TopicAndPartition(topic, partition), m))
+ self.queue.put((TopicAndPartition(topic, partition), m, key))
resp = []
else:
- messages = create_message_set(msg, self.codec)
+ messages = create_message_set(msg, self.codec, key)
req = ProduceRequest(topic, partition, messages)
try:
resp = self.client.send_produce_request([req], acks=self.req_acks,
@@ -199,7 +207,7 @@ class Producer(object):
forcefully cleaning up.
"""
if self.async:
- self.queue.put((STOP_ASYNC_PRODUCER, None))
+ self.queue.put((STOP_ASYNC_PRODUCER, None, None))
self.proc.join(timeout)
if self.proc.is_alive():
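
The key validation added to _send_messages mirrors the existing payload
check; a standalone sketch of the rule:

import six

def _check_key(key):
    # As in Producer._send_messages above: the key is optional, but if
    # supplied it must already be encoded as bytes.
    if key is not None and not isinstance(key, six.binary_type):
        raise TypeError("the key must be type bytes")

_check_key(None)       # fine: unkeyed message
_check_key(b"key1")    # fine: bytes key
# _check_key(u"key1")  # would raise TypeError: encode text before sending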
diff --git a/kafka/producer/keyed.py b/kafka/producer/keyed.py
index d2311c6..473f70a 100644
--- a/kafka/producer/keyed.py
+++ b/kafka/producer/keyed.py
@@ -56,7 +56,7 @@ class KeyedProducer(Producer):
def send(self, topic, key, msg):
partition = self._next_partition(topic, key)
- return self.send_messages(topic, partition, msg)
+ return self._send_messages(topic, partition, msg, key=key)
def __repr__(self):
return '<KeyedProducer batch=%s>' % self.async
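
Spelled out, the one-line change means KeyedProducer.send(topic, key, msg)
now amounts to the following (illustrative only; these are the same internal
calls the diff touches, not public API):

partition = producer._next_partition(topic, key)          # route by key, as before
producer._send_messages(topic, partition, msg, key=key)   # and store the key too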
diff --git a/kafka/protocol.py b/kafka/protocol.py
index 266e963..13b973e 100644
--- a/kafka/protocol.py
+++ b/kafka/protocol.py
@@ -597,17 +597,17 @@ def create_snappy_message(payloads, key=None):
return Message(0, 0x00 | codec, key, snapped)
-def create_message_set(messages, codec=CODEC_NONE):
+def create_message_set(messages, codec=CODEC_NONE, key=None):
"""Create a message set using the given codec.
If codec is CODEC_NONE, return a list of raw Kafka messages. Otherwise,
return a list containing a single codec-encoded message.
"""
if codec == CODEC_NONE:
- return [create_message(m) for m in messages]
+ return [create_message(m, key) for m in messages]
elif codec == CODEC_GZIP:
- return [create_gzip_message(messages)]
+ return [create_gzip_message(messages, key)]
elif codec == CODEC_SNAPPY:
- return [create_snappy_message(messages)]
+ return [create_snappy_message(messages, key)]
else:
raise UnsupportedCodecError("Codec 0x%02x unsupported" % codec)
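
With the extra parameter, create_message_set stamps the key onto whatever it
builds; a quick sketch of two of the code paths:

from kafka.protocol import create_message_set, CODEC_NONE, CODEC_GZIP

# CODEC_NONE: one keyed Message per payload.
plain = create_message_set([b"a", b"b"], CODEC_NONE, b"k")
assert [m.key for m in plain] == [b"k", b"k"]

# CODEC_GZIP: the batch is compressed into a single keyed wrapper Message.
batched = create_message_set([b"a", b"b"], CODEC_GZIP, b"k")
assert len(batched) == 1 and batched[0].key == b"k"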
diff --git a/test/test_producer_integration.py b/test/test_producer_integration.py
index 1ebbc86..d68af72 100644
--- a/test/test_producer_integration.py
+++ b/test/test_producer_integration.py
@@ -199,10 +199,10 @@ class TestKafkaProducerIntegration(KafkaIntegrationTestCase):
start_offset1 = self.current_offset(self.topic, 1)
producer = KeyedProducer(self.client, partitioner=RoundRobinPartitioner)
- resp1 = producer.send(self.topic, "key1", self.msg("one"))
- resp2 = producer.send(self.topic, "key2", self.msg("two"))
- resp3 = producer.send(self.topic, "key3", self.msg("three"))
- resp4 = producer.send(self.topic, "key4", self.msg("four"))
+ resp1 = producer.send(self.topic, self.key("key1"), self.msg("one"))
+ resp2 = producer.send(self.topic, self.key("key2"), self.msg("two"))
+ resp3 = producer.send(self.topic, self.key("key3"), self.msg("three"))
+ resp4 = producer.send(self.topic, self.key("key4"), self.msg("four"))
self.assert_produce_response(resp1, start_offset0+0)
self.assert_produce_response(resp2, start_offset1+0)
@@ -220,20 +220,28 @@ class TestKafkaProducerIntegration(KafkaIntegrationTestCase):
start_offset1 = self.current_offset(self.topic, 1)
producer = KeyedProducer(self.client, partitioner=HashedPartitioner)
- resp1 = producer.send(self.topic, 1, self.msg("one"))
- resp2 = producer.send(self.topic, 2, self.msg("two"))
- resp3 = producer.send(self.topic, 3, self.msg("three"))
- resp4 = producer.send(self.topic, 3, self.msg("four"))
- resp5 = producer.send(self.topic, 4, self.msg("five"))
+ resp1 = producer.send(self.topic, self.key("1"), self.msg("one"))
+ resp2 = producer.send(self.topic, self.key("2"), self.msg("two"))
+ resp3 = producer.send(self.topic, self.key("3"), self.msg("three"))
+ resp4 = producer.send(self.topic, self.key("3"), self.msg("four"))
+ resp5 = producer.send(self.topic, self.key("4"), self.msg("five"))
- self.assert_produce_response(resp1, start_offset1+0)
- self.assert_produce_response(resp2, start_offset0+0)
- self.assert_produce_response(resp3, start_offset1+1)
- self.assert_produce_response(resp4, start_offset1+2)
- self.assert_produce_response(resp5, start_offset0+1)
+ offsets = {0: start_offset0, 1: start_offset1}
+ messages = {0: [], 1: []}
- self.assert_fetch_offset(0, start_offset0, [ self.msg("two"), self.msg("five") ])
- self.assert_fetch_offset(1, start_offset1, [ self.msg("one"), self.msg("three"), self.msg("four") ])
+ keys = [self.key(k) for k in ["1", "2", "3", "3", "4"]]
+ resps = [resp1, resp2, resp3, resp4, resp5]
+ msgs = [self.msg(m) for m in ["one", "two", "three", "four", "five"]]
+
+ for key, resp, msg in zip(keys, resps, msgs):
+ k = hash(key) % 2
+ offset = offsets[k]
+ self.assert_produce_response(resp, offset)
+ offsets[k] += 1
+ messages[k].append(msg)
+
+ self.assert_fetch_offset(0, start_offset0, messages[0])
+ self.assert_fetch_offset(1, start_offset1, messages[1])
producer.stop()
@@ -393,7 +401,7 @@ class TestKafkaProducerIntegration(KafkaIntegrationTestCase):
producer = KeyedProducer(self.client, partitioner = RoundRobinPartitioner, async=True)
- resp = producer.send(self.topic, "key1", self.msg("one"))
+ resp = producer.send(self.topic, self.key("key1"), self.msg("one"))
self.assertEquals(len(resp), 0)
self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
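
The rewritten hashed-partitioner test no longer hard-codes which partition
each key lands on; it recomputes the expectation the same way
HashedPartitioner does. A sketch of that bookkeeping over the two-partition
test fixture:

keys = [b"1", b"2", b"3", b"3", b"4"]
partitions = [hash(k) % 2 for k in keys]
# Identical keys always map to the same partition, so the two b"3"
# messages are guaranteed to land on the same partition.
assert partitions[2] == partitions[3]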
diff --git a/test/testutil.py b/test/testutil.py
index fba3869..7661cbc 100644
--- a/test/testutil.py
+++ b/test/testutil.py
@@ -89,6 +89,10 @@ class KafkaIntegrationTestCase(unittest.TestCase):
return self._messages[s].encode('utf-8')
+ def key(self, k):
+ return k.encode('utf-8')
+
+
class Timer(object):
def __enter__(self):
self.start = time.time()