diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-01-18 12:27:38 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-01-18 14:47:35 -0500 |
commit | f86c338d7fc2cacaf479f5e75b1476452d381ebb (patch) | |
tree | 4bd62dd12fe0a63fbcd781da00f2cfa339b1301a /alembic/autogenerate | |
parent | 131eace6aea202484ac2a5ee1a8082c851affbf3 (diff) | |
download | alembic-f86c338d7fc2cacaf479f5e75b1476452d381ebb.tar.gz |
implement include_name hook
Added new hook :paramref:`.EnvironmentContext.configure.include_name`,
which complements the
:paramref:`.EnvironmentContext.configure.include_object` hook by providing
a means of preventing objects of a certain name from being autogenerated
**before** the SQLAlchemy reflection process takes place, and notably
includes explicit support for passing each schema name when
:paramref:`.EnvironmentContext.configure.include_schemas` is set to True.
This is especially important for environments that make use of
:paramref:`.EnvironmentContext.configure.include_schemas` where schemas are
actually databases (e.g. MySQL) in order to prevent reflection sweeps of
the entire server.
The long deprecated
:paramref:`.EnvironmentContext.configure.include_symbol` hook is removed.
The :paramref:`.EnvironmentContext.configure.include_object`
and :paramref:`.EnvironmentContext.configure.include_name`
hooks both achieve the goals of this hook.
Change-Id: Idd44a357088a79be94488fdd7a7841bf118d47e2
Fixes: #650
Diffstat (limited to 'alembic/autogenerate')
-rw-r--r-- | alembic/autogenerate/api.py | 58 | ||||
-rw-r--r-- | alembic/autogenerate/compare.py | 84 |
2 files changed, 106 insertions, 36 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index db5fe12..7411f9b 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -286,25 +286,18 @@ class AutogenContext(object): % (migration_context.script.env_py_location) ) - include_symbol = opts.get("include_symbol", None) include_object = opts.get("include_object", None) + include_name = opts.get("include_name", None) object_filters = [] - if include_symbol: - - def include_symbol_filter( - object_, name, type_, reflected, compare_to - ): - if type_ == "table": - return include_symbol(name, object_.schema) - else: - return True - - object_filters.append(include_symbol_filter) + name_filters = [] if include_object: object_filters.append(include_object) + if include_name: + name_filters.append(include_name) self._object_filters = object_filters + self._name_filters = name_filters self.migration_context = migration_context if self.migration_context is not None: @@ -325,7 +318,40 @@ class AutogenContext(object): yield self._has_batch = False - def run_filters(self, object_, name, type_, reflected, compare_to): + def run_name_filters(self, name, type_, parent_names): + """Run the context's name filters and return True if the targets + should be part of the autogenerate operation. + + This method should be run for every kind of name encountered within the + reflection side of an autogenerate operation, giving the environment + the chance to filter what names should be reflected as database + objects. The filters here are produced directly via the + :paramref:`.EnvironmentContext.configure.include_name` parameter. 
+ + """ + + if "schema_name" in parent_names: + if type_ == "table": + table_name = name + else: + table_name = parent_names["table_name"] + schema_name = parent_names["schema_name"] + if schema_name: + parent_names["schema_qualified_table_name"] = "%s.%s" % ( + schema_name, + table_name, + ) + else: + parent_names["schema_qualified_table_name"] = table_name + + for fn in self._name_filters: + + if not fn(name, type_, parent_names): + return False + else: + return True + + def run_object_filters(self, object_, name, type_, reflected, compare_to): """Run the context's object filters and return True if the targets should be part of the autogenerate operation. @@ -333,9 +359,7 @@ class AutogenContext(object): an autogenerate operation, giving the environment the chance to filter what objects should be included in the comparison. The filters here are produced directly via the - :paramref:`.EnvironmentContext.configure.include_object` - and :paramref:`.EnvironmentContext.configure.include_symbol` - functions, if present. + :paramref:`.EnvironmentContext.configure.include_object` parameter. """ for fn in self._object_filters: @@ -344,6 +368,8 @@ class AutogenContext(object): else: return True + run_filters = run_object_filters + @util.memoized_property def sorted_tables(self): """Return an aggregate of the :attr:`.MetaData.sorted_tables` collection(s). 
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index b82225d..1b36565 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -47,6 +47,10 @@ def _produce_net_changes(autogen_context, upgrade_ops): else: schemas = [None] + schemas = { + s for s in schemas if autogen_context.run_name_filters(s, "schema", {}) + } + comparators.dispatch("schema", autogen_context.dialect.name)( autogen_context, upgrade_ops, schemas ) @@ -63,13 +67,20 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas): ) version_table = autogen_context.migration_context.version_table - for s in schemas: - tables = set(inspector.get_table_names(schema=s)) - if s == version_table_schema: + for schema_name in schemas: + tables = set(inspector.get_table_names(schema=schema_name)) + if schema_name == version_table_schema: tables = tables.difference( [autogen_context.migration_context.version_table] ) - conn_table_names.update(zip([s] * len(tables), tables)) + + conn_table_names.update( + (schema_name, tname) + for tname in tables + if autogen_context.run_name_filters( + tname, "table", {"schema_name": schema_name} + ) + ) metadata_table_names = OrderedSet( [(table.schema, table.name) for table in autogen_context.sorted_tables] @@ -125,7 +136,7 @@ def _compare_tables( for s, tname in metadata_table_names.difference(conn_table_names): name = "%s.%s" % (s, tname) if s else tname metadata_table = tname_to_table[(s, tname)] - if autogen_context.run_filters( + if autogen_context.run_object_filters( metadata_table, tname, "table", False, None ): upgrade_ops.ops.append( @@ -162,7 +173,7 @@ def _compare_tables( # fmt: on ) sqla_compat._reflect_table(inspector, t, None) - if autogen_context.run_filters(t, tname, "table", True, None): + if autogen_context.run_object_filters(t, tname, "table", True, None): modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) @@ -201,7 +212,7 @@ def _compare_tables( metadata_table = tname_to_table[(s, 
tname)] conn_table = existing_metadata.tables[name] - if autogen_context.run_filters( + if autogen_context.run_object_filters( metadata_table, tname, "table", False, conn_table ): @@ -286,11 +297,17 @@ def _compare_columns( metadata_cols_by_name = dict( (c.name, c) for c in metadata_table.c if not c.system ) - conn_col_names = dict((c.name, c) for c in conn_table.c) + conn_col_names = dict( + (c.name, c) + for c in conn_table.c + if autogen_context.run_name_filters( + c.name, "column", {"table_name": tname, "schema_name": schema} + ) + ) metadata_col_names = OrderedSet(sorted(metadata_cols_by_name)) for cname in metadata_col_names.difference(conn_col_names): - if autogen_context.run_filters( + if autogen_context.run_object_filters( metadata_cols_by_name[cname], cname, "column", False, None ): modify_table_ops.ops.append( @@ -303,7 +320,7 @@ def _compare_columns( for colname in metadata_col_names.intersection(conn_col_names): metadata_col = metadata_cols_by_name[colname] conn_col = conn_table.c[colname] - if not autogen_context.run_filters( + if not autogen_context.run_object_filters( metadata_col, colname, "column", False, conn_col ): continue @@ -325,7 +342,7 @@ def _compare_columns( yield for cname in set(conn_col_names).difference(metadata_col_names): - if autogen_context.run_filters( + if autogen_context.run_object_filters( conn_table.c[cname], cname, "column", True, None ): modify_table_ops.ops.append( @@ -471,6 +488,15 @@ def _compare_indexes_and_uniques( # not being present pass else: + conn_uniques = [ + uq + for uq in conn_uniques + if autogen_context.run_name_filters( + uq["name"], + "unique_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] for uq in conn_uniques: if uq.get("duplicates_index"): unique_constraints_duplicate_unique_indexes = True @@ -478,6 +504,16 @@ def _compare_indexes_and_uniques( conn_indexes = inspector.get_indexes(tname, schema=schema) except NotImplementedError: pass + else: + conn_indexes = [ + ix + for ix in 
conn_indexes + if autogen_context.run_name_filters( + ix["name"], + "index", + {"table_name": tname, "schema_name": schema}, + ) + ] # 2. convert conn-level objects from raw inspector records # into schema objects @@ -578,7 +614,7 @@ def _compare_indexes_and_uniques( def obj_added(obj): if obj.is_index: - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "index", False, None ): modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const)) @@ -595,7 +631,7 @@ def _compare_indexes_and_uniques( if is_create_table or is_drop_table: # unique constraints are created inline with table defs return - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "unique_constraint", False, None ): modify_ops.ops.append( @@ -615,7 +651,7 @@ def _compare_indexes_and_uniques( # be sure what we're doing here return - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "index", True, None ): modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const)) @@ -627,7 +663,7 @@ def _compare_indexes_and_uniques( # if the whole table is being dropped, we don't need to # consider unique constraint separately return - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "unique_constraint", True, None ): modify_ops.ops.append( @@ -641,7 +677,7 @@ def _compare_indexes_and_uniques( def obj_changed(old, new, msg): if old.is_index: - if autogen_context.run_filters( + if autogen_context.run_object_filters( new.const, new.name, "index", False, old.const ): log.info( @@ -653,7 +689,7 @@ def _compare_indexes_and_uniques( modify_ops.ops.append(ops.DropIndexOp.from_index(old.const)) modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const)) else: - if autogen_context.run_filters( + if autogen_context.run_object_filters( new.const, new.name, "unique_constraint", False, old.const ): log.info( @@ -1128,7 +1164,15 @@ def _compare_foreign_keys( 
if isinstance(fk, sa_schema.ForeignKeyConstraint) ) - conn_fks = inspector.get_foreign_keys(tname, schema=schema) + conn_fks = [ + fk + for fk in inspector.get_foreign_keys(tname, schema=schema) + if autogen_context.run_name_filters( + fk["name"], + "foreign_key_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] backend_reflects_fk_options = conn_fks and "options" in conn_fks[0] @@ -1161,7 +1205,7 @@ def _compare_foreign_keys( ) def _add_fk(obj, compare_to): - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "foreign_key_constraint", False, compare_to ): modify_table_ops.ops.append( @@ -1177,7 +1221,7 @@ def _compare_foreign_keys( ) def _remove_fk(obj, compare_to): - if autogen_context.run_filters( + if autogen_context.run_object_filters( obj.const, obj.name, "foreign_key_constraint", True, compare_to ): modify_table_ops.ops.append( |