diff options
author | Jakub Pieńkowski <8525083+Jakski@users.noreply.github.com> | 2022-04-18 14:10:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-18 18:10:38 +0600 |
commit | 661d92222eb485fd9131543c8a0a224824f5e0f0 (patch) | |
tree | 67c0469003d4b7e734dd798de13bda3bb27f1992 /kombu | |
parent | 0f9f554b7cb9a307b07bec74688095053034fd57 (diff) | |
download | kombu-661d92222eb485fd9131543c8a0a224824f5e0f0.tar.gz |
Support pymongo 4.x (#1536)
* Support pymongo 4.x
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix problems detected by CI
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Diffstat (limited to 'kombu')
-rw-r--r-- | kombu/transport/mongodb.py | 74 |
1 files changed, 44 insertions, 30 deletions
diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py index 8285b70f..b923f5f4 100644 --- a/kombu/transport/mongodb.py +++ b/kombu/transport/mongodb.py @@ -65,11 +65,10 @@ class BroadcastCursor: def __init__(self, cursor): self._cursor = cursor - self.purge(rewind=False) def get_size(self): - return self._cursor.count() - self._offset + return self._cursor.collection.count_documents({}) - self._offset def close(self): self._cursor.close() @@ -79,7 +78,7 @@ class BroadcastCursor: self._cursor.rewind() # Fast forward the cursor past old events - self._offset = self._cursor.count() + self._offset = self._cursor.collection.count_documents({}) self._cursor = self._cursor.skip(self._offset) def __iter__(self): @@ -151,11 +150,17 @@ class Channel(virtual.Channel): def _new_queue(self, queue, **kwargs): if self.ttl: - self.queues.update( + self.queues.update_one( {'_id': queue}, - {'_id': queue, - 'options': kwargs, - 'expire_at': self._get_expire(kwargs, 'x-expires')}, + { + '$set': { + '_id': queue, + 'options': kwargs, + 'expire_at': self._get_queue_expire( + kwargs, 'x-expires' + ), + }, + }, upsert=True) def _get(self, queue): @@ -165,10 +170,9 @@ class Channel(virtual.Channel): except StopIteration: msg = None else: - msg = self.messages.find_and_modify( - query={'queue': queue}, + msg = self.messages.find_one_and_delete( + {'queue': queue}, sort=[('priority', pymongo.ASCENDING)], - remove=True, ) if self.ttl: @@ -188,7 +192,7 @@ class Channel(virtual.Channel): if queue in self._fanout_queues: return self._get_broadcast_cursor(queue).get_size() - return self.messages.find({'queue': queue}).count() + return self.messages.count_documents({'queue': queue}) def _put(self, queue, message, **kwargs): data = { @@ -198,13 +202,18 @@ class Channel(virtual.Channel): } if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-message-ttl') + data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl') + msg_expire = self._get_message_expire(message) + if msg_expire is not None and ( + data['expire_at'] is None or msg_expire < data['expire_at'] + ): + data['expire_at'] = msg_expire - self.messages.insert(data) + self.messages.insert_one(data) def _put_fanout(self, exchange, message, routing_key, **kwargs): - self.broadcast.insert({'payload': dumps(message), - 'queue': exchange}) + self.broadcast.insert_one({'payload': dumps(message), + 'queue': exchange}) def _purge(self, queue): size = self._size(queue) @@ -243,9 +252,9 @@ class Channel(virtual.Channel): data = lookup.copy() if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-expires') + data['expire_at'] = self._get_queue_expire(queue, 'x-expires') - self.routing.update(lookup, data, upsert=True) + self.routing.update_one(lookup, {'$set': data}, upsert=True) def queue_delete(self, queue, **kwargs): self.routing.remove({'queue': queue}) @@ -348,7 +357,7 @@ class Channel(virtual.Channel): def _create_broadcast(self, database): """Create capped collection for broadcast messages.""" - if self.broadcast_collection in database.collection_names(): + if self.broadcast_collection in database.list_collection_names(): return database.create_collection(self.broadcast_collection, @@ -358,20 +367,20 @@ class Channel(virtual.Channel): def _ensure_indexes(self, database): """Ensure indexes on collections.""" messages = database[self.messages_collection] - messages.ensure_index( + messages.create_index( [('queue', 1), ('priority', 1), ('_id', 1)], background=True, ) - database[self.broadcast_collection].ensure_index([('queue', 1)]) + database[self.broadcast_collection].create_index([('queue', 1)]) routing = database[self.routing_collection] - routing.ensure_index([('queue', 1), ('exchange', 1)]) + routing.create_index([('queue', 1), ('exchange', 1)]) if self.ttl: - messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0) - routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0) + messages.create_index([('expire_at', 1)], expireAfterSeconds=0) + routing.create_index([('expire_at', 1)], expireAfterSeconds=0) - database[self.queues_collection].ensure_index( + database[self.queues_collection].create_index( [('expire_at', 1)], expireAfterSeconds=0) def _create_client(self): @@ -429,7 +438,12 @@ class Channel(virtual.Channel): ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor) return ret - def _get_expire(self, queue, argument): + def _get_message_expire(self, message): + value = message.get('properties', {}).get('expiration') + if value is not None: + return self.get_now() + datetime.timedelta(milliseconds=int(value)) + + def _get_queue_expire(self, queue, argument): """Get expiration header named `argument` of queue definition. Note: @@ -454,15 +468,15 @@ class Channel(virtual.Channel): def _update_queues_expire(self, queue): """Update expiration field on queues documents.""" - expire_at = self._get_expire(queue, 'x-expires') + expire_at = self._get_queue_expire(queue, 'x-expires') if not expire_at: return - self.routing.update( - {'queue': queue}, {'$set': {'expire_at': expire_at}}, multi=True) - self.queues.update( - {'_id': queue}, {'$set': {'expire_at': expire_at}}, multi=True) + self.routing.update_many( + {'queue': queue}, {'$set': {'expire_at': expire_at}}) + self.queues.update_many( + {'_id': queue}, {'$set': {'expire_at': expire_at}}) def get_now(self): """Return current time in UTC.""" |