summaryrefslogtreecommitdiff
path: root/alembic/context.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-04-30 15:47:18 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2010-04-30 15:47:18 -0400
commit4cf13f69a435f6a3405eb926f553b1aca2b32227 (patch)
tree747fc284b665a4cc388c07467a326e188c533d6c /alembic/context.py
parentbeddcdf1659d8427d0f045a4e4b000001a73c2e5 (diff)
downloadalembic-4cf13f69a435f6a3405eb926f553b1aca2b32227.tar.gz
- sqlite dialect
- SQL text mode - some methods to help with upcoming tests
Diffstat (limited to 'alembic/context.py')
-rw-r--r--alembic/context.py42
1 files changed, 33 insertions, 9 deletions
diff --git a/alembic/context.py b/alembic/context.py
index 9694190..176388b 100644
--- a/alembic/context.py
+++ b/alembic/context.py
@@ -1,6 +1,7 @@
from alembic.ddl import base
from alembic import util
-from sqlalchemy import MetaData, Table, Column, String
+from sqlalchemy import MetaData, Table, Column, String, literal_column, text
+from sqlalchemy.schema import CreateTable
import logging
log = logging.getLogger(__name__)
@@ -24,13 +25,20 @@ class DefaultContext(object):
__dialect__ = 'default'
transactional_ddl = False
+ as_sql = False
- def __init__(self, connection, fn):
+ def __init__(self, connection, fn, as_sql=False):
self.connection = connection
self._migrations_fn = fn
+ self.as_sql = as_sql
def _current_rev(self):
- _version.create(self.connection, checkfirst=True)
+ if self.as_sql:
+ if not self.connection.dialect.has_table(self.connection, 'alembic_version'):
+ self._exec(CreateTable(_version))
+ return None
+ else:
+ _version.create(self.connection, checkfirst=True)
return self.connection.scalar(_version.select())
def _update_current_rev(self, old, new):
@@ -38,17 +46,21 @@ class DefaultContext(object):
return
if new is None:
- self.connection.execute(_version.delete())
+ self._exec(_version.delete())
elif old is None:
- self.connection.execute(_version.insert(), {'version_num':new})
+ self._exec(_version.insert().values(version_num=literal_column("'%s'" % new)))
else:
- self.connection.execute(_version.update(), {'version_num':new})
+ self._exec(_version.update().values(version_num=literal_column("'%s'" % new)))
def run_migrations(self, **kw):
log.info("Context class %s.", self.__class__.__name__)
log.info("Will assume %s DDL.",
"transactional" if self.transactional_ddl
else "non-transactional")
+
+ if self.as_sql and self.transactional_ddl:
+ print "BEGIN;\n"
+
current_rev = prev_rev = rev = self._current_rev()
for change, rev in self._migrations_fn(current_rev):
log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
@@ -60,8 +72,16 @@ class DefaultContext(object):
if self.transactional_ddl:
self._update_current_rev(current_rev, rev)
+ if self.as_sql and self.transactional_ddl:
+ print "COMMIT;\n"
+
def _exec(self, construct):
- self.connection.execute(construct)
+ if isinstance(construct, basestring):
+ construct = text(construct)
+ if self.as_sql:
+ print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", " ") + ";"
+ else:
+ self.connection.execute(construct)
def execute(self, sql):
self._exec(sql)
@@ -83,10 +103,14 @@ class DefaultContext(object):
def add_constraint(self, const):
self._exec(schema.AddConstraint(const))
-
+def opts(cfg, **kw):
+ global _context_opts, config
+ _context_opts = kw
+ config = cfg
+
def configure_connection(connection):
global _context
- _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, _migration_fn)
+ _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts)
def run_migrations(**kw):
_context.run_migrations(**kw)