summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNed Batchelder <ned@nedbatchelder.com>2022-11-27 13:38:08 -0500
committerNed Batchelder <ned@nedbatchelder.com>2022-11-27 13:38:08 -0500
commit9fa587276840246f20622debcbf9b8a7cd0e7960 (patch)
tree2f90a708ceffd36addcfba35650f2ae98ec6e88e
parent771e299c153ee20181cbb286a30dfa1450ed9e99 (diff)
downloadpython-coveragepy-git-9fa587276840246f20622debcbf9b8a7cd0e7960.tar.gz
refactor: ensure all sqlite cursors are closed
-rw-r--r--coverage/sqldata.py224
1 files changed, 131 insertions, 93 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py
index f1b192d1..4caa13d2 100644
--- a/coverage/sqldata.py
+++ b/coverage/sqldata.py
@@ -4,6 +4,7 @@
"""SQLite coverage data."""
import collections
+import contextlib
import datetime
import functools
import glob
@@ -287,19 +288,21 @@ class CoverageData(AutoReprMixin):
)
)
- for row in db.execute("select value from meta where key = 'has_arcs'"):
- self._has_arcs = bool(int(row[0]))
- self._has_lines = not self._has_arcs
+ with db.execute("select value from meta where key = 'has_arcs'") as cur:
+ for row in cur:
+ self._has_arcs = bool(int(row[0]))
+ self._has_lines = not self._has_arcs
- for file_id, path in db.execute("select id, path from file"):
- self._file_map[path] = file_id
+ with db.execute("select id, path from file") as cur:
+ for file_id, path in cur:
+ self._file_map[path] = file_id
def _init_db(self, db):
"""Write the initial contents of the database."""
if self._debug.should("dataio"):
self._debug.write(f"Initing data file {self._filename!r}")
db.executescript(SCHEMA)
- db.execute("insert into coverage_schema (version) values (?)", (SCHEMA_VERSION,))
+ db.execute_void("insert into coverage_schema (version) values (?)", (SCHEMA_VERSION,))
# When writing metadata, avoid information that will needlessly change
# the hash of the data file, unless we're debugging processes.
@@ -311,7 +314,7 @@ class CoverageData(AutoReprMixin):
("sys_argv", str(getattr(sys, "argv", None))),
("when", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
])
- db.executemany("insert or ignore into meta (key, value) values (?, ?)", meta_data)
+ db.executemany_void("insert or ignore into meta (key, value) values (?, ?)", meta_data)
def _connect(self):
"""Get the SqliteDb object to use."""
@@ -324,8 +327,8 @@ class CoverageData(AutoReprMixin):
return False
try:
with self._connect() as con:
- rows = con.execute("select * from file limit 1")
- return bool(list(rows))
+ with con.execute("select * from file limit 1") as cur:
+ return bool(list(cur))
except CoverageException:
return False
@@ -475,11 +478,12 @@ class CoverageData(AutoReprMixin):
linemap = nums_to_numbits(linenos)
file_id = self._file_id(filename, add=True)
query = "select numbits from line_bits where file_id = ? and context_id = ?"
- existing = list(con.execute(query, (file_id, self._current_context_id)))
+ with con.execute(query, (file_id, self._current_context_id)) as cur:
+ existing = list(cur)
if existing:
linemap = numbits_union(linemap, existing[0][0])
- con.execute(
+ con.execute_void(
"insert or replace into line_bits " +
" (file_id, context_id, numbits) values (?, ?, ?)",
(file_id, self._current_context_id, linemap),
@@ -508,7 +512,7 @@ class CoverageData(AutoReprMixin):
for filename, arcs in arc_data.items():
file_id = self._file_id(filename, add=True)
data = [(file_id, self._current_context_id, fromno, tono) for fromno, tono in arcs]
- con.executemany(
+ con.executemany_void(
"insert or ignore into arc " +
"(file_id, context_id, fromno, tono) values (?, ?, ?, ?)",
data,
@@ -530,7 +534,7 @@ class CoverageData(AutoReprMixin):
self._has_lines = lines
self._has_arcs = arcs
with self._connect() as con:
- con.execute(
+ con.execute_void(
"insert or ignore into meta (key, value) values (?, ?)",
("has_arcs", str(int(arcs)))
)
@@ -564,7 +568,7 @@ class CoverageData(AutoReprMixin):
)
)
elif plugin_name:
- con.execute(
+ con.execute_void(
"insert into tracer (file_id, tracer) values (?, ?)",
(file_id, plugin_name)
)
@@ -622,48 +626,46 @@ class CoverageData(AutoReprMixin):
other_data.read()
with other_data._connect() as con:
# Get files data.
- cur = con.execute("select path from file")
- files = {path: aliases.map(path) for (path,) in cur}
- cur.close()
+ with con.execute("select path from file") as cur:
+ files = {path: aliases.map(path) for (path,) in cur}
# Get contexts data.
- cur = con.execute("select context from context")
- contexts = [context for (context,) in cur]
- cur.close()
+ with con.execute("select context from context") as cur:
+ contexts = [context for (context,) in cur]
# Get arc data.
- cur = con.execute(
+ with con.execute(
"select file.path, context.context, arc.fromno, arc.tono " +
"from arc " +
"inner join file on file.id = arc.file_id " +
"inner join context on context.id = arc.context_id"
- )
- arcs = [(files[path], context, fromno, tono) for (path, context, fromno, tono) in cur]
- cur.close()
+ ) as cur:
+ arcs = [
+ (files[path], context, fromno, tono)
+ for (path, context, fromno, tono) in cur
+ ]
# Get line data.
- cur = con.execute(
+ with con.execute(
"select file.path, context.context, line_bits.numbits " +
"from line_bits " +
"inner join file on file.id = line_bits.file_id " +
"inner join context on context.id = line_bits.context_id"
- )
- lines = {}
- for path, context, numbits in cur:
- key = (files[path], context)
- if key in lines:
- numbits = numbits_union(lines[key], numbits)
- lines[key] = numbits
- cur.close()
+ ) as cur:
+ lines = {}
+ for path, context, numbits in cur:
+ key = (files[path], context)
+ if key in lines:
+ numbits = numbits_union(lines[key], numbits)
+ lines[key] = numbits
# Get tracer data.
- cur = con.execute(
+ with con.execute(
"select file.path, tracer " +
"from tracer " +
"inner join file on file.id = tracer.file_id"
- )
- tracers = {files[path]: tracer for (path, tracer) in cur}
- cur.close()
+ ) as cur:
+ tracers = {files[path]: tracer for (path, tracer) in cur}
with self._connect() as con:
con.con.isolation_level = "IMMEDIATE"
@@ -672,33 +674,31 @@ class CoverageData(AutoReprMixin):
# to have an empty string tracer. Since Sqlite does not support
# full outer joins, we have to make two queries to fill the
# dictionary.
- this_tracers = {path: "" for path, in con.execute("select path from file")}
- this_tracers.update({
- aliases.map(path): tracer
- for path, tracer in con.execute(
- "select file.path, tracer from tracer " +
- "inner join file on file.id = tracer.file_id"
- )
- })
+ with con.execute("select path from file") as cur:
+ this_tracers = {path: "" for path, in cur}
+ with con.execute(
+ "select file.path, tracer from tracer " +
+ "inner join file on file.id = tracer.file_id"
+ ) as cur:
+ this_tracers.update({
+ aliases.map(path): tracer
+ for path, tracer in cur
+ })
# Create all file and context rows in the DB.
- con.executemany(
+ con.executemany_void(
"insert or ignore into file (path) values (?)",
((file,) for file in files.values())
)
- file_ids = {
- path: id
- for id, path in con.execute("select id, path from file")
- }
+ with con.execute("select id, path from file") as cur:
+ file_ids = {path: id for id, path in cur}
self._file_map.update(file_ids)
- con.executemany(
+ con.executemany_void(
"insert or ignore into context (context) values (?)",
((context,) for context in contexts)
)
- context_ids = {
- context: id
- for id, context in con.execute("select id, context from context")
- }
+ with con.execute("select id, context from context") as cur:
+ context_ids = {context: id for id, context in cur}
# Prepare tracers and fail, if a conflict is found.
# tracer_paths is used to ensure consistency over the tracer data
@@ -725,24 +725,23 @@ class CoverageData(AutoReprMixin):
)
# Get line data.
- cur = con.execute(
+ with con.execute(
"select file.path, context.context, line_bits.numbits " +
"from line_bits " +
"inner join file on file.id = line_bits.file_id " +
"inner join context on context.id = line_bits.context_id"
- )
- for path, context, numbits in cur:
- key = (aliases.map(path), context)
- if key in lines:
- numbits = numbits_union(lines[key], numbits)
- lines[key] = numbits
- cur.close()
+ ) as cur:
+ for path, context, numbits in cur:
+ key = (aliases.map(path), context)
+ if key in lines:
+ numbits = numbits_union(lines[key], numbits)
+ lines[key] = numbits
if arcs:
self._choose_lines_or_arcs(arcs=True)
# Write the combined data.
- con.executemany(
+ con.executemany_void(
"insert or ignore into arc " +
"(file_id, context_id, fromno, tono) values (?, ?, ?, ?)",
arc_rows
@@ -750,8 +749,8 @@ class CoverageData(AutoReprMixin):
if lines:
self._choose_lines_or_arcs(lines=True)
- con.execute("delete from line_bits")
- con.executemany(
+ con.execute_void("delete from line_bits")
+ con.executemany_void(
"insert into line_bits " +
"(file_id, context_id, numbits) values (?, ?, ?)",
[
@@ -759,7 +758,7 @@ class CoverageData(AutoReprMixin):
for (file, context), numbits in lines.items()
]
)
- con.executemany(
+ con.executemany_void(
"insert or ignore into tracer (file_id, tracer) values (?, ?)",
((file_ids[filename], tracer) for filename, tracer in tracer_map.items())
)
@@ -828,7 +827,8 @@ class CoverageData(AutoReprMixin):
"""
self._start_using()
with self._connect() as con:
- contexts = {row[0] for row in con.execute("select distinct(context) from context")}
+ with con.execute("select distinct(context) from context") as cur:
+ contexts = {row[0] for row in cur}
return contexts
def file_tracer(self, filename):
@@ -862,8 +862,8 @@ class CoverageData(AutoReprMixin):
"""
self._start_using()
with self._connect() as con:
- cur = con.execute("select id from context where context = ?", (context,))
- self._query_context_ids = [row[0] for row in cur.fetchall()]
+ with con.execute("select id from context where context = ?", (context,)) as cur:
+ self._query_context_ids = [row[0] for row in cur.fetchall()]
def set_query_contexts(self, contexts):
"""Set a number of contexts for subsequent querying.
@@ -881,8 +881,8 @@ class CoverageData(AutoReprMixin):
if contexts:
with self._connect() as con:
context_clause = " or ".join(["context regexp ?"] * len(contexts))
- cur = con.execute("select id from context where " + context_clause, contexts)
- self._query_context_ids = [row[0] for row in cur.fetchall()]
+ with con.execute("select id from context where " + context_clause, contexts) as cur:
+ self._query_context_ids = [row[0] for row in cur.fetchall()]
else:
self._query_context_ids = None
@@ -914,7 +914,8 @@ class CoverageData(AutoReprMixin):
ids_array = ", ".join("?" * len(self._query_context_ids))
query += " and context_id in (" + ids_array + ")"
data += self._query_context_ids
- bitmaps = list(con.execute(query, data))
+ with con.execute(query, data) as cur:
+ bitmaps = list(cur)
nums = set()
for row in bitmaps:
nums.update(numbits_to_nums(row[0]))
@@ -949,8 +950,8 @@ class CoverageData(AutoReprMixin):
ids_array = ", ".join("?" * len(self._query_context_ids))
query += " and context_id in (" + ids_array + ")"
data += self._query_context_ids
- arcs = con.execute(query, data)
- return list(arcs)
+ with con.execute(query, data) as cur:
+ return list(cur)
def contexts_by_lineno(self, filename):
"""Get the contexts for each line in a file.
@@ -979,11 +980,12 @@ class CoverageData(AutoReprMixin):
ids_array = ", ".join("?" * len(self._query_context_ids))
query += " and arc.context_id in (" + ids_array + ")"
data += self._query_context_ids
- for fromno, tono, context in con.execute(query, data):
- if fromno > 0:
- lineno_contexts_map[fromno].add(context)
- if tono > 0:
- lineno_contexts_map[tono].add(context)
+ with con.execute(query, data) as cur:
+ for fromno, tono, context in cur:
+ if fromno > 0:
+ lineno_contexts_map[fromno].add(context)
+ if tono > 0:
+ lineno_contexts_map[tono].add(context)
else:
query = (
"select l.numbits, c.context from line_bits l, context c " +
@@ -995,9 +997,10 @@ class CoverageData(AutoReprMixin):
ids_array = ", ".join("?" * len(self._query_context_ids))
query += " and l.context_id in (" + ids_array + ")"
data += self._query_context_ids
- for numbits, context in con.execute(query, data):
- for lineno in numbits_to_nums(numbits):
- lineno_contexts_map[lineno].add(context)
+ with con.execute(query, data) as cur:
+ for numbits, context in cur:
+ for lineno in numbits_to_nums(numbits):
+ lineno_contexts_map[lineno].add(context)
return {lineno: list(contexts) for lineno, contexts in lineno_contexts_map.items()}
@@ -1009,8 +1012,10 @@ class CoverageData(AutoReprMixin):
"""
with SqliteDb(":memory:", debug=NoDebugging()) as db:
- temp_store = [row[0] for row in db.execute("pragma temp_store")]
- copts = [row[0] for row in db.execute("pragma compile_options")]
+ with db.execute("pragma temp_store") as cur:
+ temp_store = [row[0] for row in cur]
+ with db.execute("pragma compile_options") as cur:
+ copts = [row[0] for row in cur]
copts = textwrap.wrap(", ".join(copts), width=75)
return [
@@ -1078,9 +1083,9 @@ class SqliteDb(AutoReprMixin):
# This pragma makes writing faster. It disables rollbacks, but we never need them.
# PyPy needs the .close() calls here, or sqlite gets twisted up:
# https://bitbucket.org/pypy/pypy/issues/2872/default-isolation-mode-is-different-on
- self.execute("pragma journal_mode=off").close()
+ self.execute_void("pragma journal_mode=off")
# This pragma makes writing faster.
- self.execute("pragma synchronous=off").close()
+ self.execute_void("pragma synchronous=off")
def close(self):
"""If needed, close the connection."""
@@ -1106,7 +1111,7 @@ class SqliteDb(AutoReprMixin):
self.debug.write(f"EXCEPTION from __exit__: {exc}")
raise DataError(f"Couldn't end data file {self.filename!r}: {exc}") from exc
- def execute(self, sql, parameters=()):
+ def _execute(self, sql, parameters):
"""Same as :meth:`python:sqlite3.Connection.execute`."""
if self.debug.should("sql"):
tail = f" with {parameters!r}" if parameters else ""
@@ -1137,10 +1142,26 @@ class SqliteDb(AutoReprMixin):
self.debug.write(f"EXCEPTION from execute: {msg}")
raise DataError(f"Couldn't use data file {self.filename!r}: {msg}") from exc
+ @contextlib.contextmanager
+ def execute(self, sql, parameters=()):
+ """Context managed :meth:`python:sqlite3.Connection.execute`.
+
+ Use with a ``with`` statement to auto-close the returned cursor.
+ """
+ cur = self._execute(sql, parameters)
+ try:
+ yield cur
+ finally:
+ cur.close()
+
+ def execute_void(self, sql, parameters=()):
+ """Same as :meth:`python:sqlite3.Connection.execute` when you don't need the cursor."""
+ self._execute(sql, parameters).close()
+
def execute_for_rowid(self, sql, parameters=()):
"""Like execute, but returns the lastrowid."""
- con = self.execute(sql, parameters)
- rowid = con.lastrowid
+ with self.execute(sql, parameters) as cur:
+ rowid = cur.lastrowid
if self.debug.should("sqldata"):
self.debug.write(f"Row id result: {rowid!r}")
return rowid
@@ -1154,7 +1175,8 @@ class SqliteDb(AutoReprMixin):
Returns a row, or None if there were no rows.
"""
- rows = list(self.execute(sql, parameters))
+ with self.execute(sql, parameters) as cur:
+ rows = list(cur)
if len(rows) == 0:
return None
elif len(rows) == 1:
@@ -1162,7 +1184,7 @@ class SqliteDb(AutoReprMixin):
else:
raise AssertionError(f"SQL {sql!r} shouldn't return {len(rows)} rows")
- def executemany(self, sql, data):
+ def _executemany(self, sql, data):
"""Same as :meth:`python:sqlite3.Connection.executemany`."""
if self.debug.should("sql"):
data = list(data)
@@ -1179,13 +1201,29 @@ class SqliteDb(AutoReprMixin):
# https://github.com/nedbat/coveragepy/issues/1010
return self.con.executemany(sql, data)
+ @contextlib.contextmanager
+ def executemany(self, sql, data):
+ """Context managed :meth:`python:sqlite3.Connection.executemany`.
+
+ Use with a ``with`` statement to auto-close the returned cursor.
+ """
+ cur = self._executemany(sql, data)
+ try:
+ yield cur
+ finally:
+ cur.close()
+
+ def executemany_void(self, sql, data):
+ """Same as :meth:`python:sqlite3.Connection.executemany` when you don't need the cursor."""
+ self._executemany(sql, data).close()
+
def executescript(self, script):
"""Same as :meth:`python:sqlite3.Connection.executescript`."""
if self.debug.should("sql"):
self.debug.write("Executing script with {} chars: {}".format(
len(script), clipped_repr(script, 100),
))
- self.con.executescript(script)
+ self.con.executescript(script).close()
def dump(self):
"""Return a multi-line string, the SQL dump of the database."""