diff options
author | Jordan Cook <JWCook@users.noreply.github.com> | 2021-04-21 10:28:42 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-21 10:28:42 -0500 |
commit | 683eb8f6d45b259be60d5470e84c256283d037cc (patch) | |
tree | 3024011456fc1d5064562f19917dd981f10c30dc | |
parent | c6d9351ffe605df24182b4c204f6c0000f5dfbb7 (diff) | |
parent | a08c54d524e166d913c7e395e6a36cca76243df4 (diff) | |
download | requests-cache-683eb8f6d45b259be60d5470e84c256283d037cc.tar.gz |
Merge pull request #240 from jsemric/use-thread-local-connections-for-sqlite
Use thread local connections for sqlite
-rw-r--r-- | requests_cache/backends/sqlite.py | 62 | ||||
-rw-r--r-- | tests/integration/test_sqlite.py | 15 |
2 files changed, 36 insertions, 41 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index e80497c..8cc301b 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -64,42 +64,25 @@ class DbDict(BaseStorage): self.fast_save = fast_save self.table_name = table_name - self._bulk_commit = False self._can_commit = True - self._pending_connection = None - self._lock = threading.RLock() + self._local_context = threading.local() with self.connection() as con: con.execute("create table if not exists `%s` (key PRIMARY KEY, value)" % self.table_name) @contextmanager def connection(self, commit_on_success=False): - logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') - with self._lock: - if self._bulk_commit: - if self._pending_connection is None: - self._pending_connection = sqlite3.connect(self.db_path, **self.connection_kwargs) - con = self._pending_connection - else: - con = sqlite3.connect(self.db_path, **self.connection_kwargs) - try: - if self.fast_save: - con.execute("PRAGMA synchronous = 0;") - yield con - if commit_on_success and self._can_commit: - con.commit() - finally: - if not self._bulk_commit: - con.close() - - def commit(self, force=False): - """ - Commits pending transaction if :attr:`can_commit` or `force` is `True` - - :param force: force commit, ignore :attr:`can_commit` - """ - if force or self._can_commit: - if self._pending_connection is not None: - self._pending_connection.commit() + if not hasattr(self._local_context, "con"): + logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') + self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs) + if self.fast_save: + self._local_context.con.execute("PRAGMA synchronous = 0;") + yield self._local_context.con + if commit_on_success and self._can_commit: + self._local_context.con.commit() + + def __del__(self): + if hasattr(self._local_context, "con"): + self._local_context.con.close() @contextmanager def bulk_commit(self): @@ -113,24 +96,21 @@ class DbDict(BaseStorage): ... d1[i] = i * 2 """ - self._bulk_commit = True self._can_commit = False try: yield - self.commit(True) + if hasattr(self._local_context, "con"): + self._local_context.con.commit() finally: - self._bulk_commit = False self._can_commit = True - if self._pending_connection is not None: - self._pending_connection.close() - self._pending_connection = None def __getitem__(self, key): with self.connection() as con: row = con.execute("select value from `%s` where key=?" % self.table_name, (key,)).fetchone() - if not row: - raise KeyError - return row[0] + # raise error after the with block, otherwise the connection will be locked + if not row: + raise KeyError + return row[0] def __setitem__(self, key, item): with self.connection(True) as con: @@ -142,8 +122,8 @@ class DbDict(BaseStorage): def __delitem__(self, key): with self.connection(True) as con: cur = con.execute("delete from `%s` where key=?" % self.table_name, (key,)) - if not cur.rowcount: - raise KeyError + if not cur.rowcount: + raise KeyError def __iter__(self): with self.connection() as con: diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index 18ddefd..a6fd50d 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -77,6 +77,21 @@ class SQLiteTestCase(BaseStorageTestCase): do_test_for(self.storage_class(self.NAMESPACE, fast_save=True)) do_test_for(self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True)) + def test_noop(self): + def do_noop_bulk(d): + with d.bulk_commit(): + pass + del d + + d = self.storage_class(self.NAMESPACE) + t = Thread(target=do_noop_bulk, args=(d,)) + t.start() + t.join() + + # make sure connection is not closed by the thread + d[0] = 0 + assert str(d) == "{0: 0}" + class DbDictTestCase(SQLiteTestCase, unittest.TestCase): def __init__(self, *args, **kwargs): |