diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-06-11 15:24:22 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-06-11 15:24:22 -0400 |
commit | 7491dd7850924550374dd399253af5145b04283c (patch) | |
tree | a705614dc07664bf1b265bbfe590d44d2e98c845 | |
parent | 541cbd26b5b86ce445f2065b60d28fdcbbb299a9 (diff) | |
download | alembic-7491dd7850924550374dd399253af5145b04283c.tar.gz |
- initial sketch of extension
-rw-r--r-- | alembic/autogenerate/generate.py | 55 | ||||
-rw-r--r-- | alembic/command.py | 40 |
2 files changed, 78 insertions, 17 deletions
diff --git a/alembic/autogenerate/generate.py b/alembic/autogenerate/generate.py new file mode 100644 index 0000000..fcd3533 --- /dev/null +++ b/alembic/autogenerate/generate.py @@ -0,0 +1,55 @@ +from .. import util +from .api import _produce_migration_diffs + + +class GeneratedRevision(object): + def __init__(self, revision_context): + self.revision_context = revision_context + self.template_args = {} + self.imports = set() + self.rev_id = revision_context.command_args['rev_id'] or util.rev_id() + + self.head = self.revision_context.command_args['head'] + self.splice = self.revision_context.command_args['splice'] + self.branch_labels = \ + self.revision_context.command_args['branch_labels'] + self.version_path = self.revision_context.command_args['version_path'] + + def to_script(self): + return self.revision_context.script_directory.generate_revision( + self.rev_id, + self.revision_context.command_args['message'], + refresh=True, + head=self.head, + splice=self.splice, + branch_labels=self.branch_label, + version_path=self.version_path, + **self.template_args) + + +class RevisionContext(object): + def __init__(self, script_directory, command_args): + self.script_directory = script_directory + self.command_args = command_args + self.generated_revisions = [ + GeneratedRevision(self) + ] + + def run_autogenerate(self, rev, context): + if self.command_args['sql']: + raise util.CommandError( + "Using --sql with --autogenerate does not make any sense") + if set(self.script_directory.get_revisions(rev)) != \ + set(self.script_directory.get_revisions("heads")): + raise util.CommandError("Target database is not up to date.") + for generated_revision in self.generated_revisions: + _produce_migration_diffs( + context, + generated_revision.template_args, generated_revision.imports) + + def run_no_autogenerate(self, rev, context): + pass + + def generate_scripts(self): + for generated_revision in self.generated_revisions: + yield generated_revision.to_script() diff --git a/alembic/command.py b/alembic/command.py index 5ba6d6a..d15bf90 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -70,12 +70,15 @@ def revision( version_path=None, rev_id=None): """Create a new revision file.""" - script = ScriptDirectory.from_config(config) - template_args = { - 'config': config # Let templates use config for - # e.g. multiple databases - } - imports = set() + script_directory = ScriptDirectory.from_config(config) + + command_args = dict( + message=message, + autogenerate=autogenerate, + sql=sql, head=head, splice=splice, branch_label=branch_label, + version_path=version_path, rev_id=rev_id + ) + revision_context = autogen.RevisionContext(script_directory, command_args) environment = util.asbool( config.get_main_option("revision_environment") @@ -89,13 +92,11 @@ def revision( "Using --sql with --autogenerate does not make any sense") def retrieve_migrations(rev, context): - if set(script.get_revisions(rev)) != \ - set(script.get_revisions("heads")): - raise util.CommandError("Target database is not up to date.") - autogen._produce_migration_diffs(context, template_args, imports) + revision_context.run_autogenerate(rev, context) return [] elif environment: def retrieve_migrations(rev, context): + revision_context.run_no_autogenerate(rev, context) return [] elif sql: raise util.CommandError( @@ -105,16 +106,21 @@ def revision( if environment: with EnvironmentContext( config, - script, + script_directory, fn=retrieve_migrations, as_sql=sql, - template_args=template_args, + revision_context=revision_context ): - script.run_env() - return script.generate_revision( - rev_id or util.rev_id(), message, refresh=True, - head=head, splice=splice, branch_labels=branch_label, - version_path=version_path, **template_args) + script_directory.run_env() + + scripts = [ + script for script in + revision_context.generate_scripts() + ] + if len(scripts) == 1: + return scripts[0] + else: + return scripts def merge(config, revisions, message=None, branch_label=None, rev_id=None): |