summaryrefslogtreecommitdiff
path: root/requests_cache
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2023-02-28 19:05:55 -0600
committerJordan Cook <jordan.cook.git@proton.me>2023-03-01 15:22:18 -0600
commitc83e4e523008337a6f8f638e65ca217177d8ed5c (patch)
tree40a8b5a6c0773afe4e08bec93743a9650f555dd5 /requests_cache
parent0f330b54e15966ff74582cfa7d794f6b844d324c (diff)
downloadrequests-cache-c83e4e523008337a6f8f638e65ca217177d8ed5c.tar.gz
Share SQLite connection objects among threads and use lock for write operations instead of using thread-local connections
Diffstat (limited to 'requests_cache')
-rw-r--r--requests_cache/backends/sqlite.py46
1 files changed, 26 insertions, 20 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 0a892ec..2b03d4a 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -100,7 +100,7 @@ class SQLiteCache(BaseCache):
def _delete_expired(self):
"""A more efficient implementation deleting expired responses in SQL"""
- with self.responses._lock, self.responses.connection(commit=True) as con:
+ with self.responses.connection(commit=True) as con:
con.execute(
f'DELETE FROM {self.responses.table_name} WHERE expires <= ?', (round(time()),)
)
@@ -181,9 +181,10 @@ class SQLiteDict(BaseStorage):
):
super().__init__(serializer=serializer, **kwargs)
self._can_commit = True
- self._local_context = threading.local()
+ self._connection: Optional[sqlite3.Connection] = None
self._lock = threading.RLock()
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
+ self.connection_kwargs.setdefault('check_same_thread', False)
if use_memory:
self.connection_kwargs['uri'] = True
self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
@@ -195,7 +196,7 @@ class SQLiteDict(BaseStorage):
def init_db(self):
"""Initialize the database, if it hasn't already been"""
self.close()
- with self._lock, self.connection() as con:
+ with self.connection(commit=True) as con:
# Add new column to tables created before 1.0
try:
con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
@@ -214,26 +215,31 @@ class SQLiteDict(BaseStorage):
@contextmanager
def connection(self, commit=False) -> Iterator[sqlite3.Connection]:
"""Get a thread-local database connection"""
- if not getattr(self._local_context, 'con', None):
+ if not self._connection:
logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
- self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
+ self._connection = sqlite3.connect(self.db_path, **self.connection_kwargs)
if self.fast_save:
- self._local_context.con.execute('PRAGMA synchronous = 0;')
+ self._connection.execute('PRAGMA synchronous = 0;')
if self.wal:
- self._local_context.con.execute('PRAGMA journal_mode = wal')
- yield self._local_context.con
+ self._connection.execute('PRAGMA journal_mode = wal')
+
+ # Any write operations need to be run in serial
+ if commit and self._can_commit:
+ self._lock.acquire()
+ yield self._connection
if commit and self._can_commit:
- self._local_context.con.commit()
+ self._connection.commit()
+ self._lock.release()
def close(self):
"""Close any active connections"""
- if getattr(self._local_context, 'con', None):
- self._local_context.con.close()
- self._local_context.con = None
+ if self._connection:
+ self._connection.close()
+ self._connection = None
@contextmanager
def bulk_commit(self):
- """Context manager used to speed up insertion of a large number of records
+ """Insert a large number of records within a single transaction
Example:
@@ -243,13 +249,13 @@ class SQLiteDict(BaseStorage):
... d1[i] = i * 2
"""
- self._can_commit = False
- try:
- yield
- if hasattr(self._local_context, 'con'):
- self._local_context.con.commit()
- finally:
- self._can_commit = True
+ with self._lock:
+ self._can_commit = False
+ try:
+ yield
+ self._connection.commit()
+ finally:
+ self._can_commit = True
def __del__(self):
self.close()