From c83e4e523008337a6f8f638e65ca217177d8ed5c Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Tue, 28 Feb 2023 19:05:55 -0600 Subject: Share SQLite connection objects among threads and use lock for write operations instead of using thread-local connections --- requests_cache/backends/sqlite.py | 46 ++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 20 deletions(-) (limited to 'requests_cache') diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 0a892ec..2b03d4a 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -100,7 +100,7 @@ class SQLiteCache(BaseCache): def _delete_expired(self): """A more efficient implementation deleting expired responses in SQL""" - with self.responses._lock, self.responses.connection(commit=True) as con: + with self.responses.connection(commit=True) as con: con.execute( f'DELETE FROM {self.responses.table_name} WHERE expires <= ?', (round(time()),) ) @@ -181,9 +181,10 @@ class SQLiteDict(BaseStorage): ): super().__init__(serializer=serializer, **kwargs) self._can_commit = True - self._local_context = threading.local() + self._connection: Optional[sqlite3.Connection] = None self._lock = threading.RLock() self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) + self.connection_kwargs.setdefault('check_same_thread', False) if use_memory: self.connection_kwargs['uri'] = True self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory) @@ -195,7 +196,7 @@ class SQLiteDict(BaseStorage): def init_db(self): """Initialize the database, if it hasn't already been""" self.close() - with self._lock, self.connection() as con: + with self.connection(commit=True) as con: # Add new column to tables created before 1.0 try: con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER') @@ -214,26 +215,31 @@ class SQLiteDict(BaseStorage): @contextmanager def connection(self, commit=False) -> Iterator[sqlite3.Connection]: """Get a thread-local database connection""" - if not getattr(self._local_context, 'con', None): + if not self._connection: logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') - self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs) + self._connection = sqlite3.connect(self.db_path, **self.connection_kwargs) if self.fast_save: - self._local_context.con.execute('PRAGMA synchronous = 0;') + self._connection.execute('PRAGMA synchronous = 0;') if self.wal: - self._local_context.con.execute('PRAGMA journal_mode = wal') - yield self._local_context.con + self._connection.execute('PRAGMA journal_mode = wal') + + # Any write operations need to be run in serial + if commit and self._can_commit: + self._lock.acquire() + yield self._connection if commit and self._can_commit: - self._local_context.con.commit() + self._connection.commit() + self._lock.release() def close(self): """Close any active connections""" - if getattr(self._local_context, 'con', None): - self._local_context.con.close() - self._local_context.con = None + if self._connection: + self._connection.close() + self._connection = None @contextmanager def bulk_commit(self): - """Context manager used to speed up insertion of a large number of records + """Insert a large number of records within a single transaction Example: @@ -243,13 +249,13 @@ class SQLiteDict(BaseStorage): ... d1[i] = i * 2 """ - self._can_commit = False - try: - yield - if hasattr(self._local_context, 'con'): - self._local_context.con.commit() - finally: - self._can_commit = True + with self._lock: + self._can_commit = False + try: + yield + self._connection.commit() + finally: + self._can_commit = True def __del__(self): self.close() -- cgit v1.2.1