diff options
-rw-r--r-- | alembic/autogenerate/__init__.py | 1 | ||||
-rw-r--r-- | alembic/autogenerate/generate.py | 14 | ||||
-rw-r--r-- | alembic/command.py | 4 |
3 files changed, 15 insertions, 4 deletions
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py index 2d75912..661970a 100644 --- a/alembic/autogenerate/__init__.py +++ b/alembic/autogenerate/__init__.py @@ -1,2 +1,3 @@ from .api import compare_metadata, _produce_migration_diffs, \ _produce_net_changes +from .generate import RevisionContext diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py index fcd3533..f399e41 100644 --- a/alembic/autogenerate/generate.py +++ b/alembic/autogenerate/generate.py @@ -11,11 +11,14 @@ class GeneratedRevision(object): self.head = self.revision_context.command_args['head'] self.splice = self.revision_context.command_args['splice'] - self.branch_labels = \ - self.revision_context.command_args['branch_labels'] + self.branch_label = \ + self.revision_context.command_args['branch_label'] self.version_path = self.revision_context.command_args['version_path'] def to_script(self): + for k, v in self.revision_context.template_args.items(): + self.template_args.setdefault(k, v) + return self.revision_context.script_directory.generate_revision( self.rev_id, self.revision_context.command_args['message'], @@ -28,9 +31,14 @@ class GeneratedRevision(object): class RevisionContext(object): - def __init__(self, script_directory, command_args): + def __init__(self, config, script_directory, command_args): + self.config = config self.script_directory = script_directory self.command_args = command_args + self.template_args = { + 'config': config # Let templates use config for + # e.g. multiple databases + } self.generated_revisions = [ GeneratedRevision(self) ] diff --git a/alembic/command.py b/alembic/command.py index d15bf90..9819204 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -78,7 +78,8 @@ def revision( sql=sql, head=head, splice=splice, branch_label=branch_label, version_path=version_path, rev_id=rev_id ) - revision_context = autogen.RevisionContext(script_directory, command_args) + revision_context = autogen.RevisionContext( + config, script_directory, command_args) environment = util.asbool( config.get_main_option("revision_environment") @@ -109,6 +110,7 @@ def revision( script_directory, fn=retrieve_migrations, as_sql=sql, + template_args=revision_context.template_args, revision_context=revision_context ): script_directory.run_env() |