diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/sqlite/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 161 |
1 files changed, 97 insertions, 64 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index fdcd1340b..22f003e38 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -867,6 +867,7 @@ from ... import util from ...engine import default from ...engine import processors from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import ColumnElement from ...sql import compiler @@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect): return [db[1] for db in dl if db[1] != "temp"] - @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _format_schema(self, schema, table_name): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + name = f"{qschema}.{table_name}" else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - return [row[0] for row in rs] + name = table_name + return name @reflection.cache - def get_temp_table_names(self, connection, **kw): - s = ( - "SELECT name FROM sqlite_temp_master " - "WHERE type='table' ORDER BY name " - ) - rs = connection.exec_driver_sql(s) + def get_table_names(self, connection, schema=None, **kw): + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names - return [row[0] for row in rs] + @reflection.cache + def get_temp_table_names(self, connection, **kw): + main = "sqlite_temp_master" + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_temp_view_names(self, connection, **kw): @@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect): "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + names = connection.exec_driver_sql(s).scalars().all() + return names - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) info = self._get_table_pragma( @@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema - else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + master = f"{qschema}.sqlite_master" s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % ( master, ) @@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect): result = rs.fetchall() if result: return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect): tablesql, ) ) - return columns + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() def _get_column_info( self, @@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect): "type": coltype, "nullable": nullable, "default": default, - "autoincrement": "auto", "primary_key": primary_key, } if generated: @@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] cols.sort(key=lambda col: col.get("primary_key")) - pkeys = [] - for col in cols: - if col["primary_key"]: - pkeys.append(col["name"]) + pkeys = [col["name"] for col in cols] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect): # original DDL. The referred columns of the foreign key # constraint are therefore the primary key of the referred # table. - referred_pk = self.get_pk_constraint( - connection, rtbl, schema=schema, **kw - ) - # note that if table doesn't exist, we still get back a record, - # just it has no columns in it - referred_columns = referred_pk["constrained_columns"] + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] else: # note we use this list only if this is the first column # in the constraint. for subsequent columns we ignore the @@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect): ) table_data = self._get_table_sql(connection, table_name, schema=schema) - if table_data is None: - # system tables, etc. - return [] def parse_fks(): + if table_data is None: + # system tables, etc. + return FK_PATTERN = ( r"(?:CONSTRAINT (\w+) +)?" r"FOREIGN KEY *\( *(.+?) *\) +" @@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect): # use them as is as it's extremely difficult to parse inline # constraints fkeys.extend(keys_by_signature.values()) - return fkeys + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() def _find_cols_in_sig(self, sig): for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): @@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] - unique_constraints = [] def parse_uqs(): + if table_data is None: + return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' @@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect): unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. - return unique_constraints + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] @@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. - for match in re.finditer(CHECK_PATTERN, table_data, re.I): + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): name = match.group(1) if name: @@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect): check_constraints.append({"sqltext": match.group(2), "name": name}) - return check_constraints + if check_constraints: + return check_constraints + else: + return ReflectionDefaults.check_constraints() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in list(indexes): pragma_index = self._get_table_pragma( - connection, "index_info", idx["name"] + connection, "index_info", idx["name"], schema=schema ) for row in pragma_index: @@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect): break else: idx["column_names"].append(row[2]) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): @@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect): " (SELECT * FROM %(schema)ssqlite_master UNION ALL " " SELECT * FROM %(schema)ssqlite_temp_master) " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) except exc.DBAPIError: s = ( "SELECT sql FROM %(schema)ssqlite_master " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) - return rs.scalar() + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value def _get_table_pragma(self, connection, pragma, table_name, schema=None): quote = self.identifier_preparer.quote_identifier if schema is not None: - statements = ["PRAGMA %s." % quote(schema)] + statements = [f"PRAGMA {quote(schema)}."] else: # because PRAGMA looks in all attached databases if no schema # given, need to specify "main" schema, however since we want @@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect): qtable = quote(table_name) for statement in statements: - statement = "%s%s(%s)" % (statement, pragma, qtable) + statement = f"{statement}{pragma}({qtable})" cursor = connection.exec_driver_sql(statement) if not cursor._soft_closed: # work around SQLite issue whereby cursor.description |