diff options
-rw-r--r-- | alembic/command.py | 30 | ||||
-rw-r--r-- | alembic/config.py | 5 | ||||
-rw-r--r-- | alembic/context.py | 42 | ||||
-rw-r--r-- | alembic/ddl/sqlite.py | 5 | ||||
-rw-r--r-- | alembic/script.py | 28 |
5 files changed, 83 insertions, 27 deletions
diff --git a/alembic/command.py b/alembic/command.py index 050111a..3a4bff4 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -22,12 +22,12 @@ def init(config, directory, template='generic'): """Initialize a new scripts directory.""" if os.access(directory, os.F_OK): - raise util.CommandException("Directory %s already exists" % directory) + raise util.CommandError("Directory %s already exists" % directory) template_dir = os.path.join(config.get_template_directory(), template) if not os.access(template_dir, os.F_OK): - raise util.CommandException("No such template %r" % template) + raise util.CommandError("No such template %r" % template) util.status("Creating directory %s" % os.path.abspath(directory), os.makedirs, directory) @@ -65,20 +65,26 @@ def revision(config, message=None): script = ScriptDirectory.from_config(config) script.generate_rev(util.rev_id(), message) -def upgrade(config, revision): +def upgrade(config, revision, sql=False): """Upgrade to a later version.""" script = ScriptDirectory.from_config(config) - context._migration_fn = functools.partial(script.upgrade_from, revision) - context.config = config + context.opts( + config, + fn = functools.partial(script.upgrade_from, revision), + as_sql = sql + ) script.run_env() -def downgrade(config, revision): +def downgrade(config, revision, sql=False): """Revert to a previous version.""" script = ScriptDirectory.from_config(config) - context._migration_fn = functools.partial(script.downgrade_to, revision) - context.config = config + context.opts( + config, + fn = functools.partial(script.downgrade_to, revision), + as_sql = sql, + ) script.run_env() def history(config): @@ -107,9 +113,11 @@ def current(config): context.get_context().connection.engine.url), script._get_rev(rev)) return [] - - context._migration_fn = display_version - context.config = config + + context.opts( + config, + fn = display_version + ) script.run_env() def splice(config, parent, child): diff --git a/alembic/config.py b/alembic/config.py index d244193..f867338 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -79,7 +79,7 @@ def main(argv): format_opt(cmd) for cmd in commands.values() ])) + - "\n\n<revision> is a hex revision id or 'head'" + "\n\n<revision> is a hex revision id, 'head' or 'base'." ) parser.add_option("-c", "--config", @@ -93,6 +93,9 @@ def main(argv): parser.add_option("-m", "--message", type="string", help="Message string to use with 'revision'") + parser.add_option("--sql", + action="store_true", + help="Dump output to a SQL file") cmd_line_options, cmd_line_args = parser.parse_args(argv[1:]) diff --git a/alembic/context.py b/alembic/context.py index 9694190..176388b 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -1,6 +1,7 @@ from alembic.ddl import base from alembic import util -from sqlalchemy import MetaData, Table, Column, String +from sqlalchemy import MetaData, Table, Column, String, literal_column, text +from sqlalchemy.schema import CreateTable import logging log = logging.getLogger(__name__) @@ -24,13 +25,20 @@ class DefaultContext(object): __dialect__ = 'default' transactional_ddl = False + as_sql = False - def __init__(self, connection, fn): + def __init__(self, connection, fn, as_sql=False): self.connection = connection self._migrations_fn = fn + self.as_sql = as_sql def _current_rev(self): - _version.create(self.connection, checkfirst=True) + if self.as_sql: + if not self.connection.dialect.has_table(self.connection, 'alembic_version'): + self._exec(CreateTable(_version)) + return None + else: + _version.create(self.connection, checkfirst=True) return self.connection.scalar(_version.select()) def _update_current_rev(self, old, new): @@ -38,17 +46,21 @@ class DefaultContext(object): return if new is None: - self.connection.execute(_version.delete()) + self._exec(_version.delete()) elif old is None: - self.connection.execute(_version.insert(), {'version_num':new}) + self._exec(_version.insert().values(version_num=literal_column("'%s'" % new))) else: - self.connection.execute(_version.update(), {'version_num':new}) + self._exec(_version.update().values(version_num=literal_column("'%s'" % new))) def run_migrations(self, **kw): log.info("Context class %s.", self.__class__.__name__) log.info("Will assume %s DDL.", "transactional" if self.transactional_ddl else "non-transactional") + + if self.as_sql and self.transactional_ddl: + print "BEGIN;\n" + current_rev = prev_rev = rev = self._current_rev() for change, rev in self._migrations_fn(current_rev): log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) @@ -60,8 +72,16 @@ class DefaultContext(object): if self.transactional_ddl: self._update_current_rev(current_rev, rev) + if self.as_sql and self.transactional_ddl: + print "COMMIT;\n" + def _exec(self, construct): - self.connection.execute(construct) + if isinstance(construct, basestring): + construct = text(construct) + if self.as_sql: + print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", " ") + ";" + else: + self.connection.execute(construct) def execute(self, sql): self._exec(sql) @@ -83,10 +103,14 @@ class DefaultContext(object): def add_constraint(self, const): self._exec(schema.AddConstraint(const)) - +def opts(cfg, **kw): + global _context_opts, config + _context_opts = kw + config = cfg + def configure_connection(connection): global _context - _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, _migration_fn) + _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts) def run_migrations(**kw): _context.run_migrations(**kw) diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py new file mode 100644 index 0000000..20ec1eb --- /dev/null +++ b/alembic/ddl/sqlite.py @@ -0,0 +1,5 @@ +from alembic.context import DefaultContext + +class SQLiteContext(DefaultContext): + __dialect__ = 'sqlite' + transactional_ddl = True diff --git a/alembic/script.py b/alembic/script.py index d0de004..541b32f 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -82,7 +82,7 @@ class ScriptDirectory(object): def _revision_map(self): map_ = {} for file_ in os.listdir(self.versions): - script = Script.from_path(self.versions, file_) + script = Script.from_filename(self.versions, file_) if script is None: continue if script.revision in map_: @@ -100,6 +100,18 @@ class ScriptDirectory(object): map_[None] = None return map_ + def rev_path(self, rev_id): + filename = "%s.py" % rev_id + return os.path.join(self.versions, filename) + + def refresh(self, rev_id): + script = Script.from_path(self.rev_path(rev_id)) + old = self._revision_map[script.revision] + if old.down_revision != script.down_revision: + raise Exception("Can't change down_revision on a refresh operation.") + self._revision_map[script.revision] = script + script.nextrev = old.nextrev + def _current_head(self): current_heads = self._get_heads() if len(current_heads) > 1: @@ -139,16 +151,16 @@ class ScriptDirectory(object): def generate_rev(self, revid, message): current_head = self._current_head() - filename = "%s.py" % revid + path = self.rev_path(revid) self.generate_template( os.path.join(self.dir, "script.py.mako"), - os.path.join(self.versions, filename), + path, up_revision=str(revid), down_revision=current_head, create_date=datetime.datetime.now(), message=message if message is not None else ("empty message") ) - script = Script.from_path(self.versions, filename) + script = Script.from_path(path) self._revision_map[script.revision] = script if script.down_revision: self._revision_map[script.down_revision].add_nextrev(script.revision) @@ -186,11 +198,15 @@ class Script(object): self.doc) @classmethod - def from_path(cls, dir_, filename): + def from_path(cls, path): + dir_, filename = os.path.split(path) + return cls.from_filename(dir_, filename) + + @classmethod + def from_filename(cls, dir_, filename): m = _rev_file.match(filename) if not m: return None - module = util.load_python_file(dir_, filename) return Script(module, m.group(1))
\ No newline at end of file |