diff options
Diffstat (limited to 'alembic/command.py')
-rw-r--r-- | alembic/command.py | 41 |
1 files changed, 24 insertions, 17 deletions
diff --git a/alembic/command.py b/alembic/command.py index 18b421c..ed7b830 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,7 +1,7 @@ from alembic.script import ScriptDirectory -from alembic import util, ddl, autogenerate as autogen, environment +from alembic.environment import EnvironmentContext +from alembic import util, ddl, autogenerate as autogen import os -import functools def list_templates(config): """List available templates""" @@ -44,14 +44,14 @@ def init(config, directory, template='generic'): if os.access(config_file, os.F_OK): util.msg("File %s already exists, skipping" % config_file) else: - script.generate_template( + script._generate_template( os.path.join(template_dir, file_), config_file, script_location=directory ) else: output_file = os.path.join(directory, file_) - script.copy_file( + script._copy_file( os.path.join(template_dir, file_), output_file ) @@ -68,18 +68,18 @@ def revision(config, message=None, autogenerate=False): if autogenerate: util.requires_07("autogenerate") def retrieve_migrations(rev, context): - if script._get_rev(rev) is not script._get_rev("head"): + if script.get_revision(rev) is not script.get_revision("head"): raise util.CommandError("Target database is not up to date.") - autogen.produce_migration_diffs(context, template_args, imports) + autogen._produce_migration_diffs(context, template_args, imports) return [] - with environment.configure( + with EnvironmentContext( config, script, fn = retrieve_migrations ): script.run_env() - script.generate_rev(util.rev_id(), message, **template_args) + script.generate_revision(util.rev_id(), message, **template_args) def upgrade(config, revision, sql=False, tag=None): @@ -92,10 +92,14 @@ def upgrade(config, revision, sql=False, tag=None): if not sql: raise util.CommandError("Range revision not allowed") starting_rev, revision = revision.split(':', 2) - with environment.configure( + + def upgrade(rev, context): + return script._upgrade_revs(revision, rev) + + with EnvironmentContext( config, script, - fn = functools.partial(script.upgrade_from, revision), + fn = upgrade, as_sql = sql, starting_rev = starting_rev, destination_rev = revision, @@ -114,10 +118,13 @@ def downgrade(config, revision, sql=False, tag=None): raise util.CommandError("Range revision not allowed") starting_rev, revision = revision.split(':', 2) - with environment.configure( + def downgrade(rev, context): + return script._downgrade_revs(revision, rev) + + with EnvironmentContext( config, script, - fn = functools.partial(script.downgrade_to, revision), + fn = downgrade, as_sql = sql, starting_rev = starting_rev, destination_rev = revision, @@ -143,7 +150,7 @@ def branches(config): for rev in sc.nextrev: print "%s -> %s" % ( " " * len(str(sc.down_revision)), - script._get_rev(rev) + script.get_revision(rev) ) def current(config): @@ -154,10 +161,10 @@ def current(config): print "Current revision for %s: %s" % ( util.obfuscate_url_pw( context.connection.engine.url), - script._get_rev(rev)) + script.get_revision(rev)) return [] - with environment.configure( + with EnvironmentContext( config, script, fn = display_version @@ -174,12 +181,12 @@ def stamp(config, revision, sql=False, tag=None): current = False else: current = context._current_rev() - dest = script._get_rev(revision) + dest = script.get_revision(revision) if dest is not None: dest = dest.revision context._update_current_rev(current, dest) return [] - with environment.configure( + with EnvironmentContext( config, script, fn = do_stamp, |