summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-07-07 17:28:47 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-07-07 17:28:47 -0400
commit5c31c546b5b90e7b1aecd9c442412e6c0439ce00 (patch)
tree0b6200955de081e922c5d562ae1a76b5c1cae3dd
parentaa35022ca7e21f07e4d9ebee0d4e8d4c8dc565d8 (diff)
downloadalembic-5c31c546b5b90e7b1aecd9c442412e6c0439ce00.tar.gz
- propose an AutogenContext object but the task of fixing it within
all the tests has to be worked out
-rw-r--r--alembic/autogenerate/api.py108
-rw-r--r--alembic/autogenerate/compare.py31
-rw-r--r--alembic/util/langhelpers.py32
-rw-r--r--tests/test_autogen_composition.py20
4 files changed, 158 insertions, 33 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 00fe9de..97f0868 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -5,6 +5,7 @@ from ..operations import ops
from . import render
from . import compare
from .. import util
+from sqlalchemy.engine.reflection import Inspector
def compare_metadata(context, metadata):
@@ -160,10 +161,10 @@ def render_python_code(
up_or_down_op, autogen_context))
-def _render_migration_diffs(context, template_args, imports):
+def _render_migration_diffs(context, template_args):
"""legacy, used by test_autogen_composition at the moment"""
- autogen_context = _autogen_context(context, imports=imports)
+ autogen_context = _autogen_context(context)
upgrade_ops = ops.UpgradeOps([])
compare._produce_net_changes(autogen_context, upgrade_ops)
@@ -230,7 +231,110 @@ def _autogen_context(
}
+class AutogenContext(object):
+ """Maintains configuration and state that's specific to an
+ autogenerate operation."""
+
+ metadata = None
+ """The :class:`~sqlalchemy.schema.MetaData` object
+ representing the destination.
+
+ This object is the one that is passed within ``env.py``
+ to the :paramref:`.EnvironmentContext.configure.target_metadata`
+ parameter. It represents the structure of :class:`.Table` and other
+ objects as stated in the current database model, and represents the
+ destination structure for the database being examined.
+
+ While the :class:`~sqlalchemy.schema.MetaData` object is primarily
+ known as a collection of :class:`~sqlalchemy.schema.Table` objects,
+ it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
+ that may be used by end-user schemes to store additional schema-level
+ objects that are to be compared in custom autogeneration schemes.
+
+ """
+
+ connection = None
+ """The :class:`~sqlalchemy.engine.base.Connection` object currently
+ connected to the database backend being compared.
+
+ This is obtained from the :attr:`.MigrationContext.bind` and is
+ utimately set up in the ``env.py`` script.
+
+ """
+
+ migration_context = None
+ """The :class:`.MigrationContext` established by the ``env.py`` script."""
+
+ def __init__(self, migration_context, metadata=None):
+
+ if migration_context.as_sql:
+ raise util.CommandError(
+ "autogenerate can't use as_sql=True as it prevents querying "
+ "the database for schema information")
+
+ opts = migration_context.opts
+ self.metadata = metadata = opts['target_metadata'] \
+ if metadata is None else metadata
+
+ if metadata is None:
+ raise util.CommandError(
+ "Can't proceed with --autogenerate option; environment "
+ "script %s does not provide "
+ "a MetaData object to the context." % (
+ migration_context.script.env_py_location
+ ))
+
+ include_schemas = opts.get('include_schemas', False)
+
+ include_symbol = opts.get('include_symbol', None)
+ include_object = opts.get('include_object', None)
+
+ object_filters = []
+ if include_symbol:
+ def include_symbol_filter(
+ object, name, type_, reflected, compare_to):
+ if type_ == "table":
+ return include_symbol(name, object.schema)
+ else:
+ return True
+ object_filters.append(include_symbol_filter)
+ if include_object:
+ object_filters.append(include_object)
+
+ self._object_filters = object_filters
+
+ self.connection = migration_context.bind
+ self.migration_context = migration_context
+ self._imports = set()
+ self._opts = opts
+ self.dialect = migration_context.dialect
+ self._include_schemas = include_schemas
+ self.inspector = Inspector.from_engine(self.connection)
+
+ def run_filters(self, object_, name, type_, reflected, compare_to):
+ """Run the context's object filters and return True if the targets
+ should be part of the autogenerate operation.
+
+ This method should be run for every kind of object encountered within
+ an autogenerate operation, giving the environment the chance
+ to filter what objects should be included in the comparison.
+ The filters here are produced directly via the
+ :paramref:`.EnvironmentContext.configure.include_object`
+ and :paramref:`.EnvironmentContext.configure.include_symbol`
+ functions, if present.
+
+ """
+ for fn in self._object_filters:
+ if not fn(object_, name, type_, reflected, compare_to):
+ return False
+ else:
+ return True
+
+
class RevisionContext(object):
+ """Maintains configuration and state that's specific to a revision
+ file generation operation."""
+
def __init__(self, config, script_directory, command_args):
self.config = config
self.script_directory = script_directory
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 4001453..267a108 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -21,13 +21,10 @@ def _populate_migration_script(autogen_context, migration_script):
def _produce_net_changes(autogen_context, upgrade_ops):
- metadata = autogen_context['metadata']
connection = autogen_context['connection']
- object_filters = autogen_context.get('object_filters', ())
include_schemas = autogen_context.get('include_schemas', False)
inspector = Inspector.from_engine(connection)
- conn_table_names = set()
default_schema = connection.dialect.default_schema_name
if include_schemas:
@@ -40,6 +37,26 @@ def _produce_net_changes(autogen_context, upgrade_ops):
else:
schemas = [None]
+ _autogen_for_tables(autogen_context, schemas, upgrade_ops)
+
+
+def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
+ for fn in object_filters:
+ if not fn(object_, name, type_, reflected, compare_to):
+ return False
+ else:
+ return True
+
+
+def _autogen_for_tables(autogen_context, schemas, upgrade_ops):
+ connection = autogen_context['connection']
+ inspector = Inspector.from_engine(connection)
+
+ metadata = autogen_context['metadata']
+ object_filters = autogen_context.get('object_filters', ())
+
+ conn_table_names = set()
+
version_table_schema = autogen_context['context'].version_table_schema
version_table = autogen_context['context'].version_table
@@ -60,14 +77,6 @@ def _produce_net_changes(autogen_context, upgrade_ops):
inspector, metadata, upgrade_ops, autogen_context)
-def _run_filters(object_, name, type_, reflected, compare_to, object_filters):
- for fn in object_filters:
- if not fn(object_, name, type_, reflected, compare_to):
- return False
- else:
- return True
-
-
def _compare_tables(conn_table_names, metadata_table_names,
object_filters,
inspector, metadata, upgrade_ops, autogen_context):
diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py
index 904848c..efa86b3 100644
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -246,30 +246,50 @@ def _with_legacy_names(translations):
class Dispatcher(object):
- def __init__(self):
+ def __init__(self, uselist=False):
self._registry = {}
+ self.uselist = uselist
def dispatch_for(self, target, qualifier='default'):
def decorate(fn):
assert isinstance(target, type)
- assert target not in self._registry
- self._registry[(target, qualifier)] = fn
+ if self.uselist:
+ assert target not in self._registry
+ self._registry.setdefault((target, qualifier), []).append(fn)
+ else:
+ assert target not in self._registry
+ self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj, qualifier='default'):
for spcls in type(obj).__mro__:
if qualifier != 'default' and (spcls, qualifier) in self._registry:
- return self._registry[(spcls, qualifier)]
+ return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, 'default') in self._registry:
- return self._registry[(spcls, 'default')]
+ return self._fn_or_list(self._registry[(spcls, 'default')])
else:
raise ValueError("no dispatch function for object: %s" % obj)
+ def _fn_or_list(self, fn_or_list):
+ if self.uselist:
+ def go(*arg, **kw):
+ for fn in fn_or_list:
+ fn(*arg, **kw)
+ return go
+ else:
+ return fn_or_list
+
def branch(self):
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
- d._registry.update(self._registry)
+ if self.uselist:
+ d._registry.update(
+ (k, [fn for fn in self._registry[k]])
+ for k in self._registry
+ )
+ else:
+ d._registry.update(self._registry)
return d
diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py
index ff516f6..6d1f55b 100644
--- a/tests/test_autogen_composition.py
+++ b/tests/test_autogen_composition.py
@@ -23,7 +23,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(context, template_args, set())
+ autogenerate._render_migration_diffs(context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -50,10 +50,8 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(
- context, template_args, set(),
+ autogenerate._render_migration_diffs(context, template_args)
- )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
pass
@@ -67,8 +65,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
"""test a full render including indentation"""
template_args = {}
- autogenerate._render_migration_diffs(
- self.context, template_args, set())
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
op.create_table('item',
@@ -135,8 +132,7 @@ nullable=True))
template_args = {}
self.context.opts['render_as_batch'] = True
- autogenerate._render_migration_diffs(
- self.context, template_args, set())
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
@@ -229,10 +225,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
}
)
template_args = {}
- autogenerate._render_migration_diffs(
- context, template_args, set(),
+ autogenerate._render_migration_diffs(context, template_args)
- )
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###
pass
@@ -250,9 +244,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
'include_object': _default_include_object,
'include_schemas': True
})
- autogenerate._render_migration_diffs(
- self.context, template_args, set()
- )
+ autogenerate._render_migration_diffs(self.context, template_args)
eq_(re.sub(r"u'", "'", template_args['upgrades']),
"""### commands auto generated by Alembic - please adjust! ###