diff options
-rw-r--r-- | test/test_consumer_group.py | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py new file mode 100644 index 0000000..0fd65b8 --- /dev/null +++ b/test/test_consumer_group.py @@ -0,0 +1,170 @@ +import collections +import logging +import threading +import os +import time + +import pytest +import six + +from kafka import KafkaClient, SimpleProducer +from kafka.common import TopicPartition +from kafka.conn import BrokerConnection, ConnectionStates +from kafka.consumer.group import KafkaConsumer + +from test.fixtures import KafkaFixture, ZookeeperFixture +from test.testutil import random_string + + +@pytest.fixture(scope="module") +def version(): + if 'KAFKA_VERSION' not in os.environ: + return () + return tuple(map(int, os.environ['KAFKA_VERSION'].split('.'))) + + +@pytest.fixture(scope="module") +def zookeeper(version, request): + assert version + zk = ZookeeperFixture.instance() + def fin(): + zk.close() + request.addfinalizer(fin) + return zk + + +@pytest.fixture(scope="module") +def kafka_broker(version, zookeeper, request): + assert version + k = KafkaFixture.instance(0, zookeeper.host, zookeeper.port, + partitions=4) + def fin(): + k.close() + request.addfinalizer(fin) + return k + + +@pytest.fixture +def simple_client(kafka_broker): + connect_str = 'localhost:' + str(kafka_broker.port) + return KafkaClient(connect_str) + + +@pytest.fixture +def topic(simple_client): + topic = random_string(5) + simple_client.ensure_topic_exists(topic) + return topic + + +@pytest.fixture +def topic_with_messages(simple_client, topic): + producer = SimpleProducer(simple_client) + for i in six.moves.xrange(100): + producer.send_messages(topic, 'msg_%d' % i) + return topic + + +@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") +def test_consumer(kafka_broker, version): + + # 0.8.2 brokers need a topic to function well + if version >= (0, 8, 2) and version < (0, 9): + topic(simple_client(kafka_broker)) + + connect_str = 'localhost:' + str(kafka_broker.port) + consumer = KafkaConsumer(bootstrap_servers=connect_str) + consumer.poll(500) + assert len(consumer._client._conns) > 0 + node_id = list(consumer._client._conns.keys())[0] + assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED + + +@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version') +@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") +def test_group(kafka_broker, topic): + num_partitions = 4 + connect_str = 'localhost:' + str(kafka_broker.port) + consumers = {} + stop = {} + messages = collections.defaultdict(list) + def consumer_thread(i): + assert i not in consumers + assert i not in stop + stop[i] = threading.Event() + consumers[i] = KafkaConsumer(topic, + bootstrap_servers=connect_str, + request_timeout_ms=1000) + while not stop[i].is_set(): + for tp, records in six.itervalues(consumers[i].poll()): + messages[i][tp].extend(records) + consumers[i].close() + del consumers[i] + del stop[i] + + num_consumers = 4 + for i in range(num_consumers): + threading.Thread(target=consumer_thread, args=(i,)).start() + + try: + timeout = time.time() + 35 + while True: + for c in range(num_consumers): + if c not in consumers: + break + elif not consumers[c].assignment(): + break + else: + for c in range(num_consumers): + logging.info("%s: %s", c, consumers[c].assignment()) + break + assert time.time() < timeout, "timeout waiting for assignments" + + group_assignment = set() + for c in range(num_consumers): + assert len(consumers[c].assignment()) != 0 + assert set.isdisjoint(consumers[c].assignment(), group_assignment) + group_assignment.update(consumers[c].assignment()) + + assert group_assignment == set([ + TopicPartition(topic, partition) + for partition in range(num_partitions)]) + + finally: + for c in range(num_consumers): + stop[c].set() + + +@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") +def test_correlation_id_rollover(kafka_broker): + logging.getLogger('kafka.conn').setLevel(logging.ERROR) + from kafka.protocol.metadata import MetadataRequest + conn = BrokerConnection('localhost', kafka_broker.port, + receive_buffer_bytes=131072, + max_in_flight_requests_per_connection=100) + req = MetadataRequest([]) + while not conn.connected(): + conn.connect() + futures = collections.deque() + start = time.time() + done = 0 + for i in six.moves.xrange(2**13): + if not conn.can_send_more(): + conn.recv(timeout=None) + futures.append(conn.send(req)) + conn.recv() + while futures and futures[0].is_done: + f = futures.popleft() + if not f.succeeded(): + raise f.exception + done += 1 + if time.time() > start + 10: + print ("%d done" % done) + start = time.time() + + while futures: + conn.recv() + if futures[0].is_done: + f = futures.popleft() + if not f.succeeded(): + raise f.exception |