summaryrefslogtreecommitdiff
path: root/buildscripts/cost_model/database_instance.py
diff options
context:
space:
mode:
Diffstat (limited to 'buildscripts/cost_model/database_instance.py')
-rw-r--r--buildscripts/cost_model/database_instance.py63
1 files changed, 31 insertions, 32 deletions
diff --git a/buildscripts/cost_model/database_instance.py b/buildscripts/cost_model/database_instance.py
index 7fd271709d0..6d1323f8d8b 100644
--- a/buildscripts/cost_model/database_instance.py
+++ b/buildscripts/cost_model/database_instance.py
@@ -30,7 +30,7 @@
from __future__ import annotations
from typing import Sequence, Mapping, NewType, Any
import subprocess
-from pymongo import MongoClient, InsertOne
+from motor.motor_asyncio import AsyncIOMotorClient
from config import DatabaseConfig, RestoreMode
__all__ = ['DatabaseInstance', 'Pipeline']
@@ -44,13 +44,13 @@ class DatabaseInstance:
def __init__(self, config: DatabaseConfig) -> None:
"""Initialize wrapper."""
self.config = config
- self.client = MongoClient(config.connection_string)
- self.database = self.client.get_database(config.database_name)
+ self.client = AsyncIOMotorClient(config.connection_string)
+ self.database = self.client[config.database_name]
def __enter__(self):
if self.config.restore_from_dump == RestoreMode.ALWAYS or (
self.config.restore_from_dump == RestoreMode.ONLY_NEW
- and self.config.database_name not in self.client.database_names()):
+ and self.config.database_name not in self.client.list_database_names()):
self.restore()
return self
@@ -59,9 +59,9 @@ class DatabaseInstance:
self.enable_cascades(False)
self.dump()
- def drop(self):
+ async def drop(self):
"""Drop the database."""
- self.client.drop_database(self.config.database_name)
+ await self.client.drop_database(self.config.database_name)
def restore(self):
"""Restore the database from the 'self.dump_directory'."""
@@ -73,69 +73,68 @@ class DatabaseInstance:
subprocess.run(['mongodump', '--db', self.config.database_name], cwd=self.config.dump_path,
check=True)
- def enable_sbe(self, state: bool) -> None:
+ async def enable_sbe(self, state: bool) -> None:
"""Enable new query execution engine. Throw pymongo.errors.OperationFailure in case of failure."""
- self.client.admin.command({
+ await self.client.admin.command({
'setParameter': 1,
'internalQueryFrameworkControl': 'trySbeEngine' if state else 'forceClassicEngine'
})
- def enable_cascades(self, state: bool) -> None:
+ async def enable_cascades(self, state: bool) -> None:
"""Enable new query optimizer. Requires featureFlagCommonQueryFramework set to True."""
- self.client.admin.command(
+ await self.client.admin.command(
{'configureFailPoint': 'enableExplainInBonsai', 'mode': 'alwaysOn'})
- self.client.admin.command({
+ await self.client.admin.command({
'setParameter': 1,
'internalQueryFrameworkControl': 'tryBonsai' if state else 'trySbeEngine'
})
- def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]:
+ async def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]:
"""Return explain for the given pipeline."""
- return self.database.command(
+ return await self.database.command(
'explain', {'aggregate': collection_name, 'pipeline': pipeline, 'cursor': {}},
verbosity='executionStats')
- def hide_index(self, collection_name: str, index_name: str) -> None:
+ async def hide_index(self, collection_name: str, index_name: str) -> None:
"""Hide the given index from the query optimizer."""
- self.database.command(
+ await self.database.command(
{'collMod': collection_name, 'index': {'name': index_name, 'hidden': True}})
- def unhide_index(self, collection_name: str, index_name: str) -> None:
+ async def unhide_index(self, collection_name: str, index_name: str) -> None:
"""Make the given index visible for the query optimizer."""
- self.database.command(
+ await self.database.command(
{'collMod': collection_name, 'index': {'name': index_name, 'hidden': False}})
- def hide_all_indexes(self, collection_name: str) -> None:
+ async def hide_all_indexes(self, collection_name: str) -> None:
"""Hide all indexes of the given collection from the query optimizer."""
for index in self.database[collection_name].list_indexes():
if index['name'] != '_id_':
- self.hide_index(collection_name, index['name'])
+ await self.hide_index(collection_name, index['name'])
- def unhide_all_indexes(self, collection_name: str) -> None:
+ async def unhide_all_indexes(self, collection_name: str) -> None:
"""Make all indexes of the given collection visible fpr the query optimizer."""
for index in self.database[collection_name].list_indexes():
if index['name'] != '_id_':
- self.unhide_index(collection_name, index['name'])
+ await self.unhide_index(collection_name, index['name'])
- def drop_collection(self, collection_name: str) -> None:
+ async def drop_collection(self, collection_name: str) -> None:
"""Drop collection."""
- self.database[collection_name].drop()
+ await self.database[collection_name].drop()
- def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None:
+ async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None:
"""Insert documents into the collection with the given name."""
- requests = [InsertOne(doc) for doc in docs]
- self.database[collection_name].bulk_write(requests, ordered=False)
+ await self.database[collection_name].insert_many(docs, ordered=False)
- def get_all_documents(self, collection_name: str):
+ async def get_all_documents(self, collection_name: str):
"""Get all documents from the collection with the given name."""
- return self.database[collection_name].find({})
+ return await self.database[collection_name].find({}).to_list(length=None)
- def get_stats(self, collection_name: str):
+ async def get_stats(self, collection_name: str):
"""Get collection statistics."""
- return self.database.command('collstats', collection_name)
+ return await self.database.command('collstats', collection_name)
- def get_average_document_size(self, collection_name: str) -> float:
+ async def get_average_document_size(self, collection_name: str) -> float:
"""Get average document size for the given collection."""
- stats = self.get_stats(collection_name)
+ stats = await self.get_stats(collection_name)
avg_size = stats.get('avgObjSize')
return avg_size if avg_size is not None else 0