summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/sqlite/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/sqlite/base.py')
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py161
1 files changed, 97 insertions, 64 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index fdcd1340b..22f003e38 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -867,6 +867,7 @@ from ... import util
from ...engine import default
from ...engine import processors
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import ColumnElement
from ...sql import compiler
@@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect):
return [db[1] for db in dl if db[1] != "temp"]
- @reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
+ def _format_schema(self, schema, table_name):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ name = f"{qschema}.{table_name}"
else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
- return [row[0] for row in rs]
+ name = table_name
+ return name
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
- s = (
- "SELECT name FROM sqlite_temp_master "
- "WHERE type='table' ORDER BY name "
- )
- rs = connection.exec_driver_sql(s)
+ def get_table_names(self, connection, schema=None, **kw):
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- return [row[0] for row in rs]
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ main = "sqlite_temp_master"
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_temp_view_names(self, connection, **kw):
@@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect):
"SELECT name FROM sqlite_temp_master "
"WHERE type='view' ORDER BY name "
)
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
info = self._get_table_pragma(
@@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect):
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
- else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ master = f"{qschema}.sqlite_master"
s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
master,
)
@@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect):
result = rs.fetchall()
if result:
return result[0].sql
+ else:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
@@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect):
tablesql,
)
)
- return columns
+ if columns:
+ return columns
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.columns()
def _get_column_info(
self,
@@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect):
"type": coltype,
"nullable": nullable,
"default": default,
- "autoincrement": "auto",
"primary_key": primary_key,
}
if generated:
@@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect):
constraint_name = result.group(1) if result else None
cols = self.get_columns(connection, table_name, schema, **kw)
+ # consider only pk columns. This also avoids sorting the cached
+ # value returned by get_columns
+ cols = [col for col in cols if col.get("primary_key", 0) > 0]
cols.sort(key=lambda col: col.get("primary_key"))
- pkeys = []
- for col in cols:
- if col["primary_key"]:
- pkeys.append(col["name"])
+ pkeys = [col["name"] for col in cols]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect):
# original DDL. The referred columns of the foreign key
# constraint are therefore the primary key of the referred
# table.
- referred_pk = self.get_pk_constraint(
- connection, rtbl, schema=schema, **kw
- )
- # note that if table doesn't exist, we still get back a record,
- # just it has no columns in it
- referred_columns = referred_pk["constrained_columns"]
+ try:
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+ referred_columns = referred_pk["constrained_columns"]
+ except exc.NoSuchTableError:
+ # ignore not existing parents
+ referred_columns = []
else:
# note we use this list only if this is the first column
# in the constraint. for subsequent columns we ignore the
@@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect):
)
table_data = self._get_table_sql(connection, table_name, schema=schema)
- if table_data is None:
- # system tables, etc.
- return []
def parse_fks():
+ if table_data is None:
+ # system tables, etc.
+ return
FK_PATTERN = (
r"(?:CONSTRAINT (\w+) +)?"
r"FOREIGN KEY *\( *(.+?) *\) +"
@@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect):
# use them as is as it's extremely difficult to parse inline
# constraints
fkeys.extend(keys_by_signature.values())
- return fkeys
+ if fkeys:
+ return fkeys
+ else:
+ return ReflectionDefaults.foreign_keys()
def _find_cols_in_sig(self, sig):
for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
@@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
-
unique_constraints = []
def parse_uqs():
+ if table_data is None:
+ return
UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
INLINE_UNIQUE_PATTERN = (
r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
@@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect):
unique_constraints.append(parsed_constraint)
# NOTE: auto_index_by_sig might not be empty here,
# the PRIMARY KEY may have an entry.
- return unique_constraints
+ if unique_constraints:
+ return unique_constraints
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
check_constraints = []
@@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect):
# necessarily makes assumptions as to how the CREATE TABLE
# was emitted.
- for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ for match in re.finditer(CHECK_PATTERN, table_data or "", re.I):
name = match.group(1)
if name:
@@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect):
check_constraints.append({"sqltext": match.group(2), "name": name})
- return check_constraints
+ if check_constraints:
+ return check_constraints
+ else:
+ return ReflectionDefaults.check_constraints()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect):
# loop thru unique indexes to get the column names.
for idx in list(indexes):
pragma_index = self._get_table_pragma(
- connection, "index_info", idx["name"]
+ connection, "index_info", idx["name"], schema=schema
)
for row in pragma_index:
@@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect):
break
else:
idx["column_names"].append(row[2])
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if indexes:
+ return indexes
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.indexes()
+
+ def _is_sys_table(self, table_name):
+ return table_name in {
+ "sqlite_schema",
+ "sqlite_master",
+ "sqlite_temp_schema",
+ "sqlite_temp_master",
+ }
@reflection.cache
def _get_table_sql(self, connection, table_name, schema=None, **kw):
@@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect):
" (SELECT * FROM %(schema)ssqlite_master UNION ALL "
" SELECT * FROM %(schema)ssqlite_temp_master) "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
except exc.DBAPIError:
s = (
"SELECT sql FROM %(schema)ssqlite_master "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
- return rs.scalar()
+ value = rs.scalar()
+ if value is None and not self._is_sys_table(table_name):
+ raise exc.NoSuchTableError(f"{schema_expr}{table_name}")
+ return value
def _get_table_pragma(self, connection, pragma, table_name, schema=None):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
- statements = ["PRAGMA %s." % quote(schema)]
+ statements = [f"PRAGMA {quote(schema)}."]
else:
# because PRAGMA looks in all attached databases if no schema
# given, need to specify "main" schema, however since we want
@@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect):
qtable = quote(table_name)
for statement in statements:
- statement = "%s%s(%s)" % (statement, pragma, qtable)
+ statement = f"{statement}{pragma}({qtable})"
cursor = connection.exec_driver_sql(statement)
if not cursor._soft_closed:
# work around SQLite issue whereby cursor.description