diff options
Diffstat (limited to 'buildscripts/cost_model/database_instance.py')
-rw-r--r-- | buildscripts/cost_model/database_instance.py | 63 |
1 files changed, 31 insertions, 32 deletions
diff --git a/buildscripts/cost_model/database_instance.py b/buildscripts/cost_model/database_instance.py index 7fd271709d0..6d1323f8d8b 100644 --- a/buildscripts/cost_model/database_instance.py +++ b/buildscripts/cost_model/database_instance.py @@ -30,7 +30,7 @@ from __future__ import annotations from typing import Sequence, Mapping, NewType, Any import subprocess -from pymongo import MongoClient, InsertOne +from motor.motor_asyncio import AsyncIOMotorClient from config import DatabaseConfig, RestoreMode __all__ = ['DatabaseInstance', 'Pipeline'] @@ -44,13 +44,13 @@ class DatabaseInstance: def __init__(self, config: DatabaseConfig) -> None: """Initialize wrapper.""" self.config = config - self.client = MongoClient(config.connection_string) - self.database = self.client.get_database(config.database_name) + self.client = AsyncIOMotorClient(config.connection_string) + self.database = self.client[config.database_name] def __enter__(self): if self.config.restore_from_dump == RestoreMode.ALWAYS or ( self.config.restore_from_dump == RestoreMode.ONLY_NEW - and self.config.database_name not in self.client.database_names()): + and self.config.database_name not in self.client.list_database_names()): self.restore() return self @@ -59,9 +59,9 @@ class DatabaseInstance: self.enable_cascades(False) self.dump() - def drop(self): + async def drop(self): """Drop the database.""" - self.client.drop_database(self.config.database_name) + await self.client.drop_database(self.config.database_name) def restore(self): """Restore the database from the 'self.dump_directory'.""" @@ -73,69 +73,68 @@ class DatabaseInstance: subprocess.run(['mongodump', '--db', self.config.database_name], cwd=self.config.dump_path, check=True) - def enable_sbe(self, state: bool) -> None: + async def enable_sbe(self, state: bool) -> None: """Enable new query execution engine. Throw pymongo.errors.OperationFailure in case of failure.""" - self.client.admin.command({ + await self.client.admin.command({ 'setParameter': 1, 'internalQueryFrameworkControl': 'trySbeEngine' if state else 'forceClassicEngine' }) - def enable_cascades(self, state: bool) -> None: + async def enable_cascades(self, state: bool) -> None: """Enable new query optimizer. Requires featureFlagCommonQueryFramework set to True.""" - self.client.admin.command( + await self.client.admin.command( {'configureFailPoint': 'enableExplainInBonsai', 'mode': 'alwaysOn'}) - self.client.admin.command({ + await self.client.admin.command({ 'setParameter': 1, 'internalQueryFrameworkControl': 'tryBonsai' if state else 'trySbeEngine' }) - def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]: + async def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]: """Return explain for the given pipeline.""" - return self.database.command( + return await self.database.command( 'explain', {'aggregate': collection_name, 'pipeline': pipeline, 'cursor': {}}, verbosity='executionStats') - def hide_index(self, collection_name: str, index_name: str) -> None: + async def hide_index(self, collection_name: str, index_name: str) -> None: """Hide the given index from the query optimizer.""" - self.database.command( + await self.database.command( {'collMod': collection_name, 'index': {'name': index_name, 'hidden': True}}) - def unhide_index(self, collection_name: str, index_name: str) -> None: + async def unhide_index(self, collection_name: str, index_name: str) -> None: """Make the given index visible for the query optimizer.""" - self.database.command( + await self.database.command( {'collMod': collection_name, 'index': {'name': index_name, 'hidden': False}}) - def hide_all_indexes(self, collection_name: str) -> None: + async def hide_all_indexes(self, collection_name: str) -> None: """Hide all indexes of the given collection from the query optimizer.""" for index in self.database[collection_name].list_indexes(): if index['name'] != '_id_': - self.hide_index(collection_name, index['name']) + await self.hide_index(collection_name, index['name']) - def unhide_all_indexes(self, collection_name: str) -> None: + async def unhide_all_indexes(self, collection_name: str) -> None: """Make all indexes of the given collection visible fpr the query optimizer.""" for index in self.database[collection_name].list_indexes(): if index['name'] != '_id_': - self.unhide_index(collection_name, index['name']) + await self.unhide_index(collection_name, index['name']) - def drop_collection(self, collection_name: str) -> None: + async def drop_collection(self, collection_name: str) -> None: """Drop collection.""" - self.database[collection_name].drop() + await self.database[collection_name].drop() - def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None: + async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None: """Insert documents into the collection with the given name.""" - requests = [InsertOne(doc) for doc in docs] - self.database[collection_name].bulk_write(requests, ordered=False) + await self.database[collection_name].insert_many(docs, ordered=False) - def get_all_documents(self, collection_name: str): + async def get_all_documents(self, collection_name: str): """Get all documents from the collection with the given name.""" - return self.database[collection_name].find({}) + return await self.database[collection_name].find({}).to_list(length=None) - def get_stats(self, collection_name: str): + async def get_stats(self, collection_name: str): """Get collection statistics.""" - return self.database.command('collstats', collection_name) + return await self.database.command('collstats', collection_name) - def get_average_document_size(self, collection_name: str) -> float: + async def get_average_document_size(self, collection_name: str) -> float: """Get average document size for the given collection.""" - stats = self.get_stats(collection_name) + stats = await self.get_stats(collection_name) avg_size = stats.get('avgObjSize') return avg_size if avg_size is not None else 0 |