summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py131
1 files changed, 128 insertions, 3 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index df35e7128..40af06252 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -45,6 +45,8 @@ from .interfaces import CacheStats
from .interfaces import DBAPICursor
from .interfaces import Dialect
from .interfaces import ExecutionContext
+from .reflection import ObjectKind
+from .reflection import ObjectScope
from .. import event
from .. import exc
from .. import pool
@@ -508,15 +510,22 @@ class DefaultDialect(Dialect):
"""
return type_api.adapt_type(typeobj, self.colspecs)
- def has_index(self, connection, table_name, index_name, schema=None):
- if not self.has_table(connection, table_name, schema=schema):
+ def has_index(self, connection, table_name, index_name, schema=None, **kw):
+ if not self.has_table(connection, table_name, schema=schema, **kw):
return False
- for idx in self.get_indexes(connection, table_name, schema=schema):
+ for idx in self.get_indexes(
+ connection, table_name, schema=schema, **kw
+ ):
if idx["name"] == index_name:
return True
else:
return False
+ def has_schema(
+ self, connection: Connection, schema_name: str, **kw: Any
+ ) -> bool:
+ return schema_name in self.get_schema_names(connection, **kw)
+
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
@@ -769,6 +778,122 @@ class DefaultDialect(Dialect):
def get_driver_connection(self, connection):
return connection
+ def _overrides_default(self, method):
+ return (
+ getattr(type(self), method).__code__
+ is not getattr(DefaultDialect, method).__code__
+ )
+
+ def _default_multi_reflect(
+ self,
+ single_tbl_method,
+ connection,
+ kind,
+ schema,
+ filter_names,
+ scope,
+ **kw,
+ ):
+
+ names_fns = []
+ temp_names_fns = []
+ if ObjectKind.TABLE in kind:
+ names_fns.append(self.get_table_names)
+ temp_names_fns.append(self.get_temp_table_names)
+ if ObjectKind.VIEW in kind:
+ names_fns.append(self.get_view_names)
+ temp_names_fns.append(self.get_temp_view_names)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ names_fns.append(self.get_materialized_view_names)
+ # no temp materialized view at the moment
+ # temp_names_fns.append(self.get_temp_materialized_view_names)
+
+ unreflectable = kw.pop("unreflectable", {})
+
+ if (
+ filter_names
+ and scope is ObjectScope.ANY
+ and kind is ObjectKind.ANY
+ ):
+ # if names are given and no qualification on type of table
+ # (i.e. the Table(..., autoload) case), take the names as given,
+ # don't run names queries. If a table does not exit
+ # NoSuchTableError is raised and it's skipped
+
+ # this also suits the case for mssql where we can reflect
+ # individual temp tables but there's no temp_names_fn
+ names = filter_names
+ else:
+ names = []
+ name_kw = {"schema": schema, **kw}
+ fns = []
+ if ObjectScope.DEFAULT in scope:
+ fns.extend(names_fns)
+ if ObjectScope.TEMPORARY in scope:
+ fns.extend(temp_names_fns)
+
+ for fn in fns:
+ try:
+ names.extend(fn(connection, **name_kw))
+ except NotImplementedError:
+ pass
+
+ if filter_names:
+ filter_names = set(filter_names)
+
+ # iterate over all the tables/views and call the single table method
+ for table in names:
+ if not filter_names or table in filter_names:
+ key = (schema, table)
+ try:
+ yield (
+ key,
+ single_tbl_method(
+ connection, table, schema=schema, **kw
+ ),
+ )
+ except exc.UnreflectableTableError as err:
+ if key not in unreflectable:
+ unreflectable[key] = err
+ except exc.NoSuchTableError:
+ pass
+
+ def get_multi_table_options(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_options, connection, **kw
+ )
+
+ def get_multi_columns(self, connection, **kw):
+ return self._default_multi_reflect(self.get_columns, connection, **kw)
+
+ def get_multi_pk_constraint(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_pk_constraint, connection, **kw
+ )
+
+ def get_multi_foreign_keys(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_foreign_keys, connection, **kw
+ )
+
+ def get_multi_indexes(self, connection, **kw):
+ return self._default_multi_reflect(self.get_indexes, connection, **kw)
+
+ def get_multi_unique_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_unique_constraints, connection, **kw
+ )
+
+ def get_multi_check_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_check_constraints, connection, **kw
+ )
+
+ def get_multi_table_comment(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_comment, connection, **kw
+ )
+
class StrCompileDialect(DefaultDialect):