author     Jordan Cook <jordan.cook.git@proton.me>  2023-02-28 19:34:02 -0600
committer  Jordan Cook <jordan.cook.git@proton.me>  2023-03-01 15:22:18 -0600
commit     4ceda12dba80865b1efd2318daf040fe24a41ab3 (patch)
tree       40a8b5a6c0773afe4e08bec93743a9650f555dd5
parent     dd56ee17d77523426daac086ae02554368e3d8f5 (diff)
parent     c83e4e523008337a6f8f638e65ca217177d8ed5c (diff)
download   requests-cache-4ceda12dba80865b1efd2318daf040fe24a41ab3.tar.gz
Merge pull request #796 from requests-cache/sqlite-locks
Share SQLite connection objects among threads, and lock write operations
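
The gist of the change: instead of opening one SQLite connection per thread, the backend now opens a single connection with check_same_thread=False, shares it among all threads, and serializes write transactions with a lock. A minimal sketch of that pattern, using a hypothetical KVStore class rather than the library's actual API:

    import sqlite3
    import threading
    from contextlib import contextmanager

    class KVStore:
        """Toy key-value store with one shared, lock-guarded SQLite connection"""

        def __init__(self, db_path: str):
            # check_same_thread=False lets a single connection be shared among
            # threads; it's then our job to serialize write transactions
            self._connection = sqlite3.connect(db_path, check_same_thread=False)
            self._lock = threading.RLock()
            self._connection.execute(
                'CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value TEXT)'
            )
            self._connection.commit()

        @contextmanager
        def connection(self, commit: bool = False):
            # Reads may run concurrently; a write holds the lock until committed
            if commit:
                with self._lock:
                    yield self._connection
                    self._connection.commit()
            else:
                yield self._connection

        def __setitem__(self, key: str, value: str):
            with self.connection(commit=True) as con:
                con.execute('INSERT OR REPLACE INTO kv VALUES (?, ?)', (key, value))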
-rw-r--r--  HISTORY.md                               2
-rw-r--r--  requests_cache/backends/base.py          9
-rw-r--r--  requests_cache/backends/sqlite.py       46
-rw-r--r--  tests/conftest.py                        2
-rw-r--r--  tests/integration/base_cache_test.py    30
-rw-r--r--  tests/integration/base_storage_test.py  14
-rw-r--r--  tests/integration/test_sqlite.py         2
7 files changed, 73 insertions, 32 deletions
diff --git a/HISTORY.md b/HISTORY.md
index 6ffdfbe..4c294d2 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -55,6 +55,7 @@
 * Add `ttl_offset` argument to add a delay between cache expiration and deletion
 * **SQLite**:
   * Improve performance for removing expired responses with `delete()`
+  * Improve performance (slightly) with a large number of threads and high request rate
   * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering)
   * Add `size()` method to get estimated size of the database (including in-memory databases)
   * Add `sorted()` method with sorting and other query options
@@ -110,6 +111,7 @@
 
 🪲 **Bugfixes:**
 * Fix usage of memory backend with `install_cache()`
+* Fix an uncommon `OperationalError: database is locked` in SQLite backend
 * Fix issue on Windows with occasional missing `CachedResponse.created_at` timestamp
 * Add `CachedRequest.path_url` property for compatibility with `RequestEncodingMixin`
 * Fix potential `AttributeError` due to undetected imports when requests-cache is bundled in a PyInstaller package
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 6193483..d64abc1 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -173,7 +173,14 @@ class BaseCache:
                 delete_keys.append(response.cache_key)
 
         logger.debug(f'Deleting up to {len(delete_keys)} responses')
-        self.responses.bulk_delete(delete_keys)
+        # For some backends, we don't want to use bulk_delete if there's only one key
+        if len(delete_keys) == 1:
+            try:
+                del self.responses[delete_keys[0]]
+            except KeyError:
+                pass
+        else:
+            self.responses.bulk_delete(delete_keys)
         self._prune_redirects()
 
     def _prune_redirects(self):
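
Context for the hunk above: `del` on a dict-like storage class raises KeyError for a missing key (hence the try/except), while bulk_delete is expected to skip missing keys silently. A toy illustration of the two delete paths (hypothetical ToyStorage class; the real backends subclass BaseStorage):

    # Toy dict-like store showing the two delete paths used above
    class ToyStorage(dict):
        def bulk_delete(self, keys):
            # Batched delete; silently skips missing keys
            for key in keys:
                self.pop(key, None)

    store = ToyStorage({'k1': 'v1', 'k2': 'v2'})
    try:
        del store['k1']  # single-key path: raises KeyError if already gone
    except KeyError:
        pass
    store.bulk_delete(['k2', 'missing'])  # batched path: tolerant of missing keys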
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 0a892ec..2b03d4a 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -100,7 +100,7 @@ class SQLiteCache(BaseCache):
 
     def _delete_expired(self):
         """A more efficient implementation that deletes expired responses in SQL"""
-        with self.responses._lock, self.responses.connection(commit=True) as con:
+        with self.responses.connection(commit=True) as con:
             con.execute(
                 f'DELETE FROM {self.responses.table_name} WHERE expires <= ?', (round(time()),)
             )
@@ -181,9 +181,10 @@ class SQLiteDict(BaseStorage):
     ):
         super().__init__(serializer=serializer, **kwargs)
         self._can_commit = True
-        self._local_context = threading.local()
+        self._connection: Optional[sqlite3.Connection] = None
         self._lock = threading.RLock()
         self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
+        self.connection_kwargs.setdefault('check_same_thread', False)
         if use_memory:
             self.connection_kwargs['uri'] = True
         self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
@@ -195,7 +196,7 @@ class SQLiteDict(BaseStorage):
     def init_db(self):
         """Initialize the database, if it hasn't already been"""
         self.close()
-        with self._lock, self.connection() as con:
+        with self.connection(commit=True) as con:
             # Add new column to tables created before 1.0
             try:
                 con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
@@ -214,26 +215,31 @@
 
     @contextmanager
     def connection(self, commit=False) -> Iterator[sqlite3.Connection]:
-        """Get a thread-local database connection"""
-        if not getattr(self._local_context, 'con', None):
+        """Get a database connection, shared among all threads"""
+        if not self._connection:
             logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
-            self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
+            self._connection = sqlite3.connect(self.db_path, **self.connection_kwargs)
             if self.fast_save:
-                self._local_context.con.execute('PRAGMA synchronous = 0;')
+                self._connection.execute('PRAGMA synchronous = 0;')
             if self.wal:
-                self._local_context.con.execute('PRAGMA journal_mode = wal')
-        yield self._local_context.con
+                self._connection.execute('PRAGMA journal_mode = wal')
+
+        # Any write operations need to be run in serial
+        if commit and self._can_commit:
+            self._lock.acquire()
+        yield self._connection
         if commit and self._can_commit:
-            self._local_context.con.commit()
+            self._connection.commit()
+            self._lock.release()
 
     def close(self):
         """Close any active connections"""
-        if getattr(self._local_context, 'con', None):
-            self._local_context.con.close()
-            self._local_context.con = None
+        if self._connection:
+            self._connection.close()
+            self._connection = None
 
     @contextmanager
     def bulk_commit(self):
-        """Context manager used to speed up insertion of a large number of records
+        """Insert a large number of records within a single transaction
 
         Example:
@@ -243,13 +249,13 @@
             ...     for i in range(1000):
             ...         d1[i] = i * 2
         """
-        self._can_commit = False
-        try:
-            yield
-            if hasattr(self._local_context, 'con'):
-                self._local_context.con.commit()
-        finally:
-            self._can_commit = True
+        with self._lock:
+            self._can_commit = False
+            try:
+                yield
+                self._connection.commit()
+            finally:
+                self._can_commit = True
 
     def __del__(self):
         self.close()
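
One effect of the bulk_commit() change above: the write lock is now held for the entire batch, so another thread's commit can no longer land mid-batch. Usage, following the docstring example in this file:

    from requests_cache.backends.sqlite import SQLiteDict

    d1 = SQLiteDict('test')
    with d1.bulk_commit():
        # Intermediate commits are suppressed while _can_commit is False;
        # everything is committed once, when the context exits
        for i in range(1000):
            d1[i] = i * 2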
diff --git a/tests/conftest.py b/tests/conftest.py
index c91d6db..c53d866 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,7 +43,7 @@ logger = getLogger(__name__)
 
 # Allow running longer stress tests with an environment variable
 STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
+N_WORKERS = 5 * STRESS_TEST_MULTIPLIER
 N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
 N_REQUESTS_PER_ITERATION = 10 + 10 * STRESS_TEST_MULTIPLIER
 
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index cdc3a98..d37bfa7 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse
 
 import pytest
 import requests
-from requests import PreparedRequest, Session
+from requests import ConnectionError, PreparedRequest, Session
 
 from requests_cache import ALL_METHODS, CachedResponse, CachedSession
 from requests_cache.backends import BaseCache
@@ -36,6 +36,7 @@ from tests.conftest import (
 )
 
 logger = getLogger(__name__)
 
+
 VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
@@ -355,10 +356,14 @@
         """
         start = time()
         url = httpbin('anything')
-        self.init_session(clear=True)
 
-        session_factory = partial(self.init_session, clear=False)
-        request_func = partial(_send_request, session_factory, url)
+        # For multithreading, we can share a session object, but we can't for multiprocessing
+        session = self.init_session(clear=True, expire_after=1)
+        if executor_class is ProcessPoolExecutor:
+            session = None
+        session_factory = partial(self.init_session, clear=False, expire_after=1)
+
+        request_func = partial(_send_request, session, session_factory, url)
 
         with executor_class(max_workers=N_WORKERS) as executor:
             _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
@@ -373,7 +378,7 @@
     )
 
 
-def _send_request(session_factory, url, _=None):
+def _send_request(session, session_factory, url, _=None):
     """Concurrent request function for stress tests. Defined in module scope so it can be serialized
     to multiple processes.
     """
@@ -381,6 +386,15 @@
     n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
     i = randint(1, n_unique_responses)
 
-    session = session_factory()
-    sleep(0.1)
-    return session.get(url, params={f'key_{i}': f'value_{i}'})
+    # Threads can share a session object, but processes will create their own session because it
+    # can't be serialized
+    if session is None:
+        session = session_factory()
+
+    sleep(0.01)
+    try:
+        return session.get(url, params={f'key_{i}': f'value_{i}'})
+    # Sometimes the local http server is the bottleneck here; just retry once
+    except ConnectionError:
+        sleep(0.1)
+        return session.get(url, params={f'key_{i}': f'value_{i}'})
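
The session handling above reflects a general rule: threads can share one CachedSession, but each process needs its own, since sessions don't survive pickling. A standalone sketch under those assumptions (cache name and URL are hypothetical):

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from functools import partial

    from requests_cache import CachedSession

    URL = 'https://example.com/api'  # hypothetical endpoint

    def fetch(session, i):
        # Threads reuse the shared session; processes receive None and build their own
        if session is None:
            session = CachedSession('demo_cache')
        return session.get(URL, params={'page': i}).status_code

    if __name__ == '__main__':
        shared = CachedSession('demo_cache')
        with ThreadPoolExecutor(max_workers=5) as executor:
            print(list(executor.map(partial(fetch, shared), range(10))))

        # Sessions can't be pickled, so pass None and let each worker create one
        with ProcessPoolExecutor(max_workers=5) as executor:
            print(list(executor.map(partial(fetch, None), range(10))))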
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index ce881e8..4870c02 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -1,4 +1,5 @@
 """Common tests to run for all backends (BaseStorage subclasses)"""
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Dict, Type
 
@@ -7,7 +8,7 @@ from attrs import define, field
 
 from requests_cache.backends import BaseStorage
 from requests_cache.models import CachedResponse
-from tests.conftest import CACHE_NAME
+from tests.conftest import CACHE_NAME, N_ITERATIONS, N_REQUESTS_PER_ITERATION, N_WORKERS
 
 
 class BaseStorageTest:
@@ -158,6 +159,17 @@
         for i in range(10):
             assert f'key_{i}' in str(cache)
 
+    def test_concurrency(self):
+        """Test a large number of concurrent write operations for each backend"""
+        cache = self.init_cache()
+
+        def write(i):
+            cache[f'key_{i}'] = f'value_{i}'
+
+        n_iterations = N_ITERATIONS * N_REQUESTS_PER_ITERATION * 10
+        with ThreadPoolExecutor(max_workers=N_WORKERS * 2) as executor:
+            _ = list(executor.map(write, range(n_iterations)))
+
 
 @define
 class BasicDataclass:
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index af94610..04e4b61 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -32,7 +32,7 @@ class TestSQLiteDict(BaseStorageTest):
     def test_connection_kwargs(self, mock_sqlite):
         """A spot check to make sure optional connection kwargs get passed to the connection"""
         cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???')
-        mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5)
+        mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5, check_same_thread=False)
 
     def test_use_cache_dir(self):
         relative_path = self.storage_class(CACHE_NAME).db_path
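
For context on the check_same_thread=False assertion above: without that flag, sqlite3 refuses to use a connection from any thread other than the one that created it, which is why the shared-connection design depends on it. A self-contained demonstration:

    import sqlite3
    import threading

    con = sqlite3.connect(':memory:')  # default: check_same_thread=True

    def query():
        try:
            con.execute('SELECT 1')
        except sqlite3.ProgrammingError as e:
            # 'SQLite objects created in a thread can only be used in that same thread...'
            print(e)

    t = threading.Thread(target=query)
    t.start()
    t.join()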