diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 131 |
1 files changed, 128 insertions, 3 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index df35e7128..40af06252 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -45,6 +45,8 @@ from .interfaces import CacheStats from .interfaces import DBAPICursor from .interfaces import Dialect from .interfaces import ExecutionContext +from .reflection import ObjectKind +from .reflection import ObjectScope from .. import event from .. import exc from .. import pool @@ -508,15 +510,22 @@ class DefaultDialect(Dialect): """ return type_api.adapt_type(typeobj, self.colspecs) - def has_index(self, connection, table_name, index_name, schema=None): - if not self.has_table(connection, table_name, schema=schema): + def has_index(self, connection, table_name, index_name, schema=None, **kw): + if not self.has_table(connection, table_name, schema=schema, **kw): return False - for idx in self.get_indexes(connection, table_name, schema=schema): + for idx in self.get_indexes( + connection, table_name, schema=schema, **kw + ): if idx["name"] == index_name: return True else: return False + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + return schema_name in self.get_schema_names(connection, **kw) + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( @@ -769,6 +778,122 @@ class DefaultDialect(Dialect): def get_driver_connection(self, connection): return connection + def _overrides_default(self, method): + return ( + getattr(type(self), method).__code__ + is not getattr(DefaultDialect, method).__code__ + ) + + def _default_multi_reflect( + self, + single_tbl_method, + connection, + kind, + schema, + filter_names, + scope, + **kw, + ): + + names_fns = [] + temp_names_fns = [] + if ObjectKind.TABLE in kind: + names_fns.append(self.get_table_names) + temp_names_fns.append(self.get_temp_table_names) + if ObjectKind.VIEW in kind: + names_fns.append(self.get_view_names) + temp_names_fns.append(self.get_temp_view_names) + if ObjectKind.MATERIALIZED_VIEW in kind: + names_fns.append(self.get_materialized_view_names) + # no temp materialized view at the moment + # temp_names_fns.append(self.get_temp_materialized_view_names) + + unreflectable = kw.pop("unreflectable", {}) + + if ( + filter_names + and scope is ObjectScope.ANY + and kind is ObjectKind.ANY + ): + # if names are given and no qualification on type of table + # (i.e. the Table(..., autoload) case), take the names as given, + # don't run names queries. If a table does not exit + # NoSuchTableError is raised and it's skipped + + # this also suits the case for mssql where we can reflect + # individual temp tables but there's no temp_names_fn + names = filter_names + else: + names = [] + name_kw = {"schema": schema, **kw} + fns = [] + if ObjectScope.DEFAULT in scope: + fns.extend(names_fns) + if ObjectScope.TEMPORARY in scope: + fns.extend(temp_names_fns) + + for fn in fns: + try: + names.extend(fn(connection, **name_kw)) + except NotImplementedError: + pass + + if filter_names: + filter_names = set(filter_names) + + # iterate over all the tables/views and call the single table method + for table in names: + if not filter_names or table in filter_names: + key = (schema, table) + try: + yield ( + key, + single_tbl_method( + connection, table, schema=schema, **kw + ), + ) + except exc.UnreflectableTableError as err: + if key not in unreflectable: + unreflectable[key] = err + except exc.NoSuchTableError: + pass + + def get_multi_table_options(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_options, connection, **kw + ) + + def get_multi_columns(self, connection, **kw): + return self._default_multi_reflect(self.get_columns, connection, **kw) + + def get_multi_pk_constraint(self, connection, **kw): + return self._default_multi_reflect( + self.get_pk_constraint, connection, **kw + ) + + def get_multi_foreign_keys(self, connection, **kw): + return self._default_multi_reflect( + self.get_foreign_keys, connection, **kw + ) + + def get_multi_indexes(self, connection, **kw): + return self._default_multi_reflect(self.get_indexes, connection, **kw) + + def get_multi_unique_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_unique_constraints, connection, **kw + ) + + def get_multi_check_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_check_constraints, connection, **kw + ) + + def get_multi_table_comment(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_comment, connection, **kw + ) + class StrCompileDialect(DefaultDialect): |