summaryrefslogtreecommitdiff
path: root/kombu
diff options
context:
space:
mode:
authorJakub Pieńkowski <8525083+Jakski@users.noreply.github.com>2022-04-18 14:10:38 +0200
committerGitHub <noreply@github.com>2022-04-18 18:10:38 +0600
commit661d92222eb485fd9131543c8a0a224824f5e0f0 (patch)
tree67c0469003d4b7e734dd798de13bda3bb27f1992 /kombu
parent0f9f554b7cb9a307b07bec74688095053034fd57 (diff)
downloadkombu-661d92222eb485fd9131543c8a0a224824f5e0f0.tar.gz
Support pymongo 4.x (#1536)
* Support pymongo 4.x * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix problems detected by CI Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Diffstat (limited to 'kombu')
-rw-r--r--kombu/transport/mongodb.py74
1 files changed, 44 insertions, 30 deletions
diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py
index 8285b70f..b923f5f4 100644
--- a/kombu/transport/mongodb.py
+++ b/kombu/transport/mongodb.py
@@ -65,11 +65,10 @@ class BroadcastCursor:
def __init__(self, cursor):
self._cursor = cursor
-
self.purge(rewind=False)
def get_size(self):
- return self._cursor.count() - self._offset
+ return self._cursor.collection.count_documents({}) - self._offset
def close(self):
self._cursor.close()
@@ -79,7 +78,7 @@ class BroadcastCursor:
self._cursor.rewind()
# Fast forward the cursor past old events
- self._offset = self._cursor.count()
+ self._offset = self._cursor.collection.count_documents({})
self._cursor = self._cursor.skip(self._offset)
def __iter__(self):
@@ -151,11 +150,17 @@ class Channel(virtual.Channel):
def _new_queue(self, queue, **kwargs):
if self.ttl:
- self.queues.update(
+ self.queues.update_one(
{'_id': queue},
- {'_id': queue,
- 'options': kwargs,
- 'expire_at': self._get_expire(kwargs, 'x-expires')},
+ {
+ '$set': {
+ '_id': queue,
+ 'options': kwargs,
+ 'expire_at': self._get_queue_expire(
+ kwargs, 'x-expires'
+ ),
+ },
+ },
upsert=True)
def _get(self, queue):
@@ -165,10 +170,9 @@ class Channel(virtual.Channel):
except StopIteration:
msg = None
else:
- msg = self.messages.find_and_modify(
- query={'queue': queue},
+ msg = self.messages.find_one_and_delete(
+ {'queue': queue},
sort=[('priority', pymongo.ASCENDING)],
- remove=True,
)
if self.ttl:
@@ -188,7 +192,7 @@ class Channel(virtual.Channel):
if queue in self._fanout_queues:
return self._get_broadcast_cursor(queue).get_size()
- return self.messages.find({'queue': queue}).count()
+ return self.messages.count_documents({'queue': queue})
def _put(self, queue, message, **kwargs):
data = {
@@ -198,13 +202,18 @@ class Channel(virtual.Channel):
}
if self.ttl:
- data['expire_at'] = self._get_expire(queue, 'x-message-ttl')
+ data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl')
+ msg_expire = self._get_message_expire(message)
+ if msg_expire is not None and (
+ data['expire_at'] is None or msg_expire < data['expire_at']
+ ):
+ data['expire_at'] = msg_expire
- self.messages.insert(data)
+ self.messages.insert_one(data)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
- self.broadcast.insert({'payload': dumps(message),
- 'queue': exchange})
+ self.broadcast.insert_one({'payload': dumps(message),
+ 'queue': exchange})
def _purge(self, queue):
size = self._size(queue)
@@ -243,9 +252,9 @@ class Channel(virtual.Channel):
data = lookup.copy()
if self.ttl:
- data['expire_at'] = self._get_expire(queue, 'x-expires')
+ data['expire_at'] = self._get_queue_expire(queue, 'x-expires')
- self.routing.update(lookup, data, upsert=True)
+ self.routing.update_one(lookup, {'$set': data}, upsert=True)
def queue_delete(self, queue, **kwargs):
self.routing.remove({'queue': queue})
@@ -348,7 +357,7 @@ class Channel(virtual.Channel):
def _create_broadcast(self, database):
"""Create capped collection for broadcast messages."""
- if self.broadcast_collection in database.collection_names():
+ if self.broadcast_collection in database.list_collection_names():
return
database.create_collection(self.broadcast_collection,
@@ -358,20 +367,20 @@ class Channel(virtual.Channel):
def _ensure_indexes(self, database):
"""Ensure indexes on collections."""
messages = database[self.messages_collection]
- messages.ensure_index(
+ messages.create_index(
[('queue', 1), ('priority', 1), ('_id', 1)], background=True,
)
- database[self.broadcast_collection].ensure_index([('queue', 1)])
+ database[self.broadcast_collection].create_index([('queue', 1)])
routing = database[self.routing_collection]
- routing.ensure_index([('queue', 1), ('exchange', 1)])
+ routing.create_index([('queue', 1), ('exchange', 1)])
if self.ttl:
- messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0)
- routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0)
+ messages.create_index([('expire_at', 1)], expireAfterSeconds=0)
+ routing.create_index([('expire_at', 1)], expireAfterSeconds=0)
- database[self.queues_collection].ensure_index(
+ database[self.queues_collection].create_index(
[('expire_at', 1)], expireAfterSeconds=0)
def _create_client(self):
@@ -429,7 +438,12 @@ class Channel(virtual.Channel):
ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
return ret
- def _get_expire(self, queue, argument):
+ def _get_message_expire(self, message):
+ value = message.get('properties', {}).get('expiration')
+ if value is not None:
+ return self.get_now() + datetime.timedelta(milliseconds=int(value))
+
+ def _get_queue_expire(self, queue, argument):
"""Get expiration header named `argument` of queue definition.
Note:
@@ -454,15 +468,15 @@ class Channel(virtual.Channel):
def _update_queues_expire(self, queue):
"""Update expiration field on queues documents."""
- expire_at = self._get_expire(queue, 'x-expires')
+ expire_at = self._get_queue_expire(queue, 'x-expires')
if not expire_at:
return
- self.routing.update(
- {'queue': queue}, {'$set': {'expire_at': expire_at}}, multi=True)
- self.queues.update(
- {'_id': queue}, {'$set': {'expire_at': expire_at}}, multi=True)
+ self.routing.update_many(
+ {'queue': queue}, {'$set': {'expire_at': expire_at}})
+ self.queues.update_many(
+ {'_id': queue}, {'$set': {'expire_at': expire_at}})
def get_now(self):
"""Return current time in UTC."""