summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/command.py30
-rw-r--r--alembic/config.py5
-rw-r--r--alembic/context.py42
-rw-r--r--alembic/ddl/sqlite.py5
-rw-r--r--alembic/script.py28
5 files changed, 83 insertions, 27 deletions
diff --git a/alembic/command.py b/alembic/command.py
index 050111a..3a4bff4 100644
--- a/alembic/command.py
+++ b/alembic/command.py
@@ -22,12 +22,12 @@ def init(config, directory, template='generic'):
"""Initialize a new scripts directory."""
if os.access(directory, os.F_OK):
- raise util.CommandException("Directory %s already exists" % directory)
+ raise util.CommandError("Directory %s already exists" % directory)
template_dir = os.path.join(config.get_template_directory(),
template)
if not os.access(template_dir, os.F_OK):
- raise util.CommandException("No such template %r" % template)
+ raise util.CommandError("No such template %r" % template)
util.status("Creating directory %s" % os.path.abspath(directory),
os.makedirs, directory)
@@ -65,20 +65,26 @@ def revision(config, message=None):
script = ScriptDirectory.from_config(config)
script.generate_rev(util.rev_id(), message)
-def upgrade(config, revision):
+def upgrade(config, revision, sql=False):
"""Upgrade to a later version."""
script = ScriptDirectory.from_config(config)
- context._migration_fn = functools.partial(script.upgrade_from, revision)
- context.config = config
+ context.opts(
+ config,
+ fn = functools.partial(script.upgrade_from, revision),
+ as_sql = sql
+ )
script.run_env()
-def downgrade(config, revision):
+def downgrade(config, revision, sql=False):
"""Revert to a previous version."""
script = ScriptDirectory.from_config(config)
- context._migration_fn = functools.partial(script.downgrade_to, revision)
- context.config = config
+ context.opts(
+ config,
+ fn = functools.partial(script.downgrade_to, revision),
+ as_sql = sql,
+ )
script.run_env()
def history(config):
@@ -107,9 +113,11 @@ def current(config):
context.get_context().connection.engine.url),
script._get_rev(rev))
return []
-
- context._migration_fn = display_version
- context.config = config
+
+ context.opts(
+ config,
+ fn = display_version
+ )
script.run_env()
def splice(config, parent, child):
diff --git a/alembic/config.py b/alembic/config.py
index d244193..f867338 100644
--- a/alembic/config.py
+++ b/alembic/config.py
@@ -79,7 +79,7 @@ def main(argv):
format_opt(cmd)
for cmd in commands.values()
])) +
- "\n\n<revision> is a hex revision id or 'head'"
+ "\n\n<revision> is a hex revision id, 'head' or 'base'."
)
parser.add_option("-c", "--config",
@@ -93,6 +93,9 @@ def main(argv):
parser.add_option("-m", "--message",
type="string",
help="Message string to use with 'revision'")
+ parser.add_option("--sql",
+ action="store_true",
+ help="Dump output to a SQL file")
cmd_line_options, cmd_line_args = parser.parse_args(argv[1:])
diff --git a/alembic/context.py b/alembic/context.py
index 9694190..176388b 100644
--- a/alembic/context.py
+++ b/alembic/context.py
@@ -1,6 +1,7 @@
from alembic.ddl import base
from alembic import util
-from sqlalchemy import MetaData, Table, Column, String
+from sqlalchemy import MetaData, Table, Column, String, literal_column, text
+from sqlalchemy.schema import CreateTable
import logging
log = logging.getLogger(__name__)
@@ -24,13 +25,20 @@ class DefaultContext(object):
__dialect__ = 'default'
transactional_ddl = False
+ as_sql = False
- def __init__(self, connection, fn):
+ def __init__(self, connection, fn, as_sql=False):
self.connection = connection
self._migrations_fn = fn
+ self.as_sql = as_sql
def _current_rev(self):
- _version.create(self.connection, checkfirst=True)
+ if self.as_sql:
+ if not self.connection.dialect.has_table(self.connection, 'alembic_version'):
+ self._exec(CreateTable(_version))
+ return None
+ else:
+ _version.create(self.connection, checkfirst=True)
return self.connection.scalar(_version.select())
def _update_current_rev(self, old, new):
@@ -38,17 +46,21 @@ class DefaultContext(object):
return
if new is None:
- self.connection.execute(_version.delete())
+ self._exec(_version.delete())
elif old is None:
- self.connection.execute(_version.insert(), {'version_num':new})
+ self._exec(_version.insert().values(version_num=literal_column("'%s'" % new)))
else:
- self.connection.execute(_version.update(), {'version_num':new})
+ self._exec(_version.update().values(version_num=literal_column("'%s'" % new)))
def run_migrations(self, **kw):
log.info("Context class %s.", self.__class__.__name__)
log.info("Will assume %s DDL.",
"transactional" if self.transactional_ddl
else "non-transactional")
+
+ if self.as_sql and self.transactional_ddl:
+ print "BEGIN;\n"
+
current_rev = prev_rev = rev = self._current_rev()
for change, rev in self._migrations_fn(current_rev):
log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
@@ -60,8 +72,16 @@ class DefaultContext(object):
if self.transactional_ddl:
self._update_current_rev(current_rev, rev)
+ if self.as_sql and self.transactional_ddl:
+ print "COMMIT;\n"
+
def _exec(self, construct):
- self.connection.execute(construct)
+ if isinstance(construct, basestring):
+ construct = text(construct)
+ if self.as_sql:
+ print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", " ") + ";"
+ else:
+ self.connection.execute(construct)
def execute(self, sql):
self._exec(sql)
@@ -83,10 +103,14 @@ class DefaultContext(object):
def add_constraint(self, const):
self._exec(schema.AddConstraint(const))
-
+def opts(cfg, **kw):
+ global _context_opts, config
+ _context_opts = kw
+ config = cfg
+
def configure_connection(connection):
global _context
- _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, _migration_fn)
+ _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts)
def run_migrations(**kw):
_context.run_migrations(**kw)
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
new file mode 100644
index 0000000..20ec1eb
--- /dev/null
+++ b/alembic/ddl/sqlite.py
@@ -0,0 +1,5 @@
+from alembic.context import DefaultContext
+
+class SQLiteContext(DefaultContext):
+ __dialect__ = 'sqlite'
+ transactional_ddl = True
diff --git a/alembic/script.py b/alembic/script.py
index d0de004..541b32f 100644
--- a/alembic/script.py
+++ b/alembic/script.py
@@ -82,7 +82,7 @@ class ScriptDirectory(object):
def _revision_map(self):
map_ = {}
for file_ in os.listdir(self.versions):
- script = Script.from_path(self.versions, file_)
+ script = Script.from_filename(self.versions, file_)
if script is None:
continue
if script.revision in map_:
@@ -100,6 +100,18 @@ class ScriptDirectory(object):
map_[None] = None
return map_
+ def rev_path(self, rev_id):
+ filename = "%s.py" % rev_id
+ return os.path.join(self.versions, filename)
+
+ def refresh(self, rev_id):
+ script = Script.from_path(self.rev_path(rev_id))
+ old = self._revision_map[script.revision]
+ if old.down_revision != script.down_revision:
+ raise Exception("Can't change down_revision on a refresh operation.")
+ self._revision_map[script.revision] = script
+ script.nextrev = old.nextrev
+
def _current_head(self):
current_heads = self._get_heads()
if len(current_heads) > 1:
@@ -139,16 +151,16 @@ class ScriptDirectory(object):
def generate_rev(self, revid, message):
current_head = self._current_head()
- filename = "%s.py" % revid
+ path = self.rev_path(revid)
self.generate_template(
os.path.join(self.dir, "script.py.mako"),
- os.path.join(self.versions, filename),
+ path,
up_revision=str(revid),
down_revision=current_head,
create_date=datetime.datetime.now(),
message=message if message is not None else ("empty message")
)
- script = Script.from_path(self.versions, filename)
+ script = Script.from_path(path)
self._revision_map[script.revision] = script
if script.down_revision:
self._revision_map[script.down_revision].add_nextrev(script.revision)
@@ -186,11 +198,15 @@ class Script(object):
self.doc)
@classmethod
- def from_path(cls, dir_, filename):
+ def from_path(cls, path):
+ dir_, filename = os.path.split(path)
+ return cls.from_filename(dir_, filename)
+
+ @classmethod
+ def from_filename(cls, dir_, filename):
m = _rev_file.match(filename)
if not m:
return None
-
module = util.load_python_file(dir_, filename)
return Script(module, m.group(1))
\ No newline at end of file