diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-04-30 15:47:18 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-04-30 15:47:18 -0400 |
commit | 4cf13f69a435f6a3405eb926f553b1aca2b32227 (patch) | |
tree | 747fc284b665a4cc388c07467a326e188c533d6c /alembic/context.py | |
parent | beddcdf1659d8427d0f045a4e4b000001a73c2e5 (diff) | |
download | alembic-4cf13f69a435f6a3405eb926f553b1aca2b32227.tar.gz |
- sqlite dialect
- SQL text mode
- some methods to help with upcoming tests
Diffstat (limited to 'alembic/context.py')
-rw-r--r-- | alembic/context.py | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/alembic/context.py b/alembic/context.py index 9694190..176388b 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -1,6 +1,7 @@ from alembic.ddl import base from alembic import util -from sqlalchemy import MetaData, Table, Column, String +from sqlalchemy import MetaData, Table, Column, String, literal_column, text +from sqlalchemy.schema import CreateTable import logging log = logging.getLogger(__name__) @@ -24,13 +25,20 @@ class DefaultContext(object): __dialect__ = 'default' transactional_ddl = False + as_sql = False - def __init__(self, connection, fn): + def __init__(self, connection, fn, as_sql=False): self.connection = connection self._migrations_fn = fn + self.as_sql = as_sql def _current_rev(self): - _version.create(self.connection, checkfirst=True) + if self.as_sql: + if not self.connection.dialect.has_table(self.connection, 'alembic_version'): + self._exec(CreateTable(_version)) + return None + else: + _version.create(self.connection, checkfirst=True) return self.connection.scalar(_version.select()) def _update_current_rev(self, old, new): @@ -38,17 +46,21 @@ class DefaultContext(object): return if new is None: - self.connection.execute(_version.delete()) + self._exec(_version.delete()) elif old is None: - self.connection.execute(_version.insert(), {'version_num':new}) + self._exec(_version.insert().values(version_num=literal_column("'%s'" % new))) else: - self.connection.execute(_version.update(), {'version_num':new}) + self._exec(_version.update().values(version_num=literal_column("'%s'" % new))) def run_migrations(self, **kw): log.info("Context class %s.", self.__class__.__name__) log.info("Will assume %s DDL.", "transactional" if self.transactional_ddl else "non-transactional") + + if self.as_sql and self.transactional_ddl: + print "BEGIN;\n" + current_rev = prev_rev = rev = self._current_rev() for change, rev in self._migrations_fn(current_rev): log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) @@ -60,8 +72,16 @@ class DefaultContext(object): if self.transactional_ddl: self._update_current_rev(current_rev, rev) + if self.as_sql and self.transactional_ddl: + print "COMMIT;\n" + def _exec(self, construct): - self.connection.execute(construct) + if isinstance(construct, basestring): + construct = text(construct) + if self.as_sql: + print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", " ") + ";" + else: + self.connection.execute(construct) def execute(self, sql): self._exec(sql) @@ -83,10 +103,14 @@ class DefaultContext(object): def add_constraint(self, const): self._exec(schema.AddConstraint(const)) - +def opts(cfg, **kw): + global _context_opts, config + _context_opts = kw + config = cfg + def configure_connection(connection): global _context - _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, _migration_fn) + _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts) def run_migrations(**kw): _context.run_migrations(**kw) |