summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <JWCook@users.noreply.github.com>2021-04-21 10:28:42 -0500
committerGitHub <noreply@github.com>2021-04-21 10:28:42 -0500
commit683eb8f6d45b259be60d5470e84c256283d037cc (patch)
tree3024011456fc1d5064562f19917dd981f10c30dc
parentc6d9351ffe605df24182b4c204f6c0000f5dfbb7 (diff)
parenta08c54d524e166d913c7e395e6a36cca76243df4 (diff)
downloadrequests-cache-683eb8f6d45b259be60d5470e84c256283d037cc.tar.gz
Merge pull request #240 from jsemric/use-thread-local-connections-for-sqlite
Use thread local connections for sqlite
-rw-r--r--requests_cache/backends/sqlite.py62
-rw-r--r--tests/integration/test_sqlite.py15
2 files changed, 36 insertions, 41 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index e80497c..8cc301b 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -64,42 +64,25 @@ class DbDict(BaseStorage):
self.fast_save = fast_save
self.table_name = table_name
- self._bulk_commit = False
self._can_commit = True
- self._pending_connection = None
- self._lock = threading.RLock()
+ self._local_context = threading.local()
with self.connection() as con:
con.execute("create table if not exists `%s` (key PRIMARY KEY, value)" % self.table_name)
@contextmanager
def connection(self, commit_on_success=False):
- logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
- with self._lock:
- if self._bulk_commit:
- if self._pending_connection is None:
- self._pending_connection = sqlite3.connect(self.db_path, **self.connection_kwargs)
- con = self._pending_connection
- else:
- con = sqlite3.connect(self.db_path, **self.connection_kwargs)
- try:
- if self.fast_save:
- con.execute("PRAGMA synchronous = 0;")
- yield con
- if commit_on_success and self._can_commit:
- con.commit()
- finally:
- if not self._bulk_commit:
- con.close()
-
- def commit(self, force=False):
- """
- Commits pending transaction if :attr:`can_commit` or `force` is `True`
-
- :param force: force commit, ignore :attr:`can_commit`
- """
- if force or self._can_commit:
- if self._pending_connection is not None:
- self._pending_connection.commit()
+ if not hasattr(self._local_context, "con"):
+ logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
+ self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
+ if self.fast_save:
+ self._local_context.con.execute("PRAGMA synchronous = 0;")
+ yield self._local_context.con
+ if commit_on_success and self._can_commit:
+ self._local_context.con.commit()
+
+ def __del__(self):
+ if hasattr(self._local_context, "con"):
+ self._local_context.con.close()
@contextmanager
def bulk_commit(self):
@@ -113,24 +96,21 @@ class DbDict(BaseStorage):
... d1[i] = i * 2
"""
- self._bulk_commit = True
self._can_commit = False
try:
yield
- self.commit(True)
+ if hasattr(self._local_context, "con"):
+ self._local_context.con.commit()
finally:
- self._bulk_commit = False
self._can_commit = True
- if self._pending_connection is not None:
- self._pending_connection.close()
- self._pending_connection = None
def __getitem__(self, key):
with self.connection() as con:
row = con.execute("select value from `%s` where key=?" % self.table_name, (key,)).fetchone()
- if not row:
- raise KeyError
- return row[0]
+ # raise error after the with block, otherwise the connection will be locked
+ if not row:
+ raise KeyError
+ return row[0]
def __setitem__(self, key, item):
with self.connection(True) as con:
@@ -142,8 +122,8 @@ class DbDict(BaseStorage):
def __delitem__(self, key):
with self.connection(True) as con:
cur = con.execute("delete from `%s` where key=?" % self.table_name, (key,))
- if not cur.rowcount:
- raise KeyError
+ if not cur.rowcount:
+ raise KeyError
def __iter__(self):
with self.connection() as con:
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 18ddefd..a6fd50d 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -77,6 +77,21 @@ class SQLiteTestCase(BaseStorageTestCase):
do_test_for(self.storage_class(self.NAMESPACE, fast_save=True))
do_test_for(self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True))
+ def test_noop(self):
+ def do_noop_bulk(d):
+ with d.bulk_commit():
+ pass
+ del d
+
+ d = self.storage_class(self.NAMESPACE)
+ t = Thread(target=do_noop_bulk, args=(d,))
+ t.start()
+ t.join()
+
+ # make sure connection is not closed by the thread
+ d[0] = 0
+ assert str(d) == "{0: 0}"
+
class DbDictTestCase(SQLiteTestCase, unittest.TestCase):
def __init__(self, *args, **kwargs):