diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-01-24 13:42:43 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-01-24 13:42:43 -0500 |
commit | a851daaa1b2a4efa5990f55c3c97282cafdab9e1 (patch) | |
tree | 920819fdc141dd794f3a45af1f2c4d592d054d48 /alembic/migration.py | |
parent | 72416bcde500f48a66310eafc86c071ee2672d09 (diff) | |
download | alembic-a851daaa1b2a4efa5990f55c3c97282cafdab9e1.tar.gz |
this is all tests passing with the refactor, which IMHO is
miraculous
Diffstat (limited to 'alembic/migration.py')
-rw-r--r-- | alembic/migration.py | 72 |
1 files changed, 55 insertions, 17 deletions
diff --git a/alembic/migration.py b/alembic/migration.py index 69e6930..733727d 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -4,7 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \ from sqlalchemy import create_engine from alembic import ddl import sys -from contextlib import contextmanager +from sqlalchemy.engine import url as sqla_url import logging log = logging.getLogger(__name__) @@ -21,22 +21,19 @@ class MigrationContext(object): Mediates the relationship between an ``env.py`` environment script, a :class:`.ScriptDirectory` instance, and a :class:`.DefaultImpl` instance. - The :class:`.Context` is available directly via the :func:`.get_context` function, + The :class:`.MigrationContext` is available directly via the :func:`.get_context` function, though usually it is referenced behind the scenes by the various module level functions within the :mod:`alembic.context` module. """ - def __init__(self, dialect, script, connection, - opts, - as_sql=False, - output_buffer=None, - transactional_ddl=None, - starting_rev=None, - compare_type=False, - compare_server_default=False): + def __init__(self, dialect, connection, opts): + self.opts = opts self.dialect = dialect - # TODO: need this ? - self.script = script + self.script = opts.get('script') + + as_sql=opts.get('as_sql', False) + transactional_ddl=opts.get("transactional_ddl") + if as_sql: self.connection = self._stdout_connection(connection) assert self.connection is not None @@ -44,12 +41,12 @@ class MigrationContext(object): self.connection = connection self._migrations_fn = opts.get('fn') self.as_sql = as_sql - self.output_buffer = output_buffer if output_buffer else sys.stdout + self.output_buffer = opts.get("output_buffer", sys.stdout) - self._user_compare_type = compare_type - self._user_compare_server_default = compare_server_default + self._user_compare_type = opts.get('compare_type', False) + self._user_compare_server_default = opts.get('compare_server_default', False) - self._start_from_rev = starting_rev + self._start_from_rev = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( dialect, self.connection, self.as_sql, transactional_ddl, @@ -63,6 +60,46 @@ class MigrationContext(object): "transactional" if self.impl.transactional_ddl else "non-transactional") + @classmethod + def configure(cls, + connection=None, + url=None, + dialect_name=None, + opts=None, + ): + """Create a new :class:`.MigrationContext`. + + This is a factory method usually called + by :meth:`.EnvironmentContext.configure`. + + :param connection: a :class:`~sqlalchemy.engine.base.Connection` to use + for SQL execution in "online" mode. When present, is also used to + determine the type of dialect in use. + :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if ``connection`` is + not passed. + :param dialect_name: string name of a dialect, such as "postgresql", "mssql", etc. + The type of dialect to be used will be derived from this if ``connection`` + and ``url`` are not passed. + :param opts: dictionary of options. Most other options + accepted by :meth:`.EnvironmentContext.configure` are passed via + this dictionary. + + """ + if connection: + dialect = connection.dialect + elif url: + url = sqla_url.make_url(url) + dialect = url.get_dialect()() + elif dialect_name: + url = sqla_url.make_url("%s://" % dialect_name) + dialect = url.get_dialect()() + else: + raise Exception("Connection, url, or dialect_name is required.") + + return MigrationContext(dialect, connection, opts) + + def _current_rev(self): if self.as_sql: return self._start_from_rev @@ -93,7 +130,8 @@ class MigrationContext(object): current_rev = rev = False self.impl.start_migrations() for change, prev_rev, rev in self._migrations_fn( - self._current_rev()): + self._current_rev(), + self): if current_rev is False: current_rev = prev_rev if self.as_sql and not current_rev: |