summaryrefslogtreecommitdiff
path: root/alembic/migration.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-01-24 13:42:43 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-01-24 13:42:43 -0500
commita851daaa1b2a4efa5990f55c3c97282cafdab9e1 (patch)
tree920819fdc141dd794f3a45af1f2c4d592d054d48 /alembic/migration.py
parent72416bcde500f48a66310eafc86c071ee2672d09 (diff)
downloadalembic-a851daaa1b2a4efa5990f55c3c97282cafdab9e1.tar.gz
this is all tests passing with the refactor, which IMHO is
miraculous
Diffstat (limited to 'alembic/migration.py')
-rw-r--r--alembic/migration.py72
1 files changed, 55 insertions, 17 deletions
diff --git a/alembic/migration.py b/alembic/migration.py
index 69e6930..733727d 100644
--- a/alembic/migration.py
+++ b/alembic/migration.py
@@ -4,7 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \
from sqlalchemy import create_engine
from alembic import ddl
import sys
-from contextlib import contextmanager
+from sqlalchemy.engine import url as sqla_url
import logging
log = logging.getLogger(__name__)
@@ -21,22 +21,19 @@ class MigrationContext(object):
Mediates the relationship between an ``env.py`` environment script,
a :class:`.ScriptDirectory` instance, and a :class:`.DefaultImpl` instance.
- The :class:`.Context` is available directly via the :func:`.get_context` function,
+ The :class:`.MigrationContext` is available directly via the :func:`.get_context` function,
though usually it is referenced behind the scenes by the various module level functions
within the :mod:`alembic.context` module.
"""
- def __init__(self, dialect, script, connection,
- opts,
- as_sql=False,
- output_buffer=None,
- transactional_ddl=None,
- starting_rev=None,
- compare_type=False,
- compare_server_default=False):
+ def __init__(self, dialect, connection, opts):
+ self.opts = opts
self.dialect = dialect
- # TODO: need this ?
- self.script = script
+ self.script = opts.get('script')
+
+ as_sql=opts.get('as_sql', False)
+ transactional_ddl=opts.get("transactional_ddl")
+
if as_sql:
self.connection = self._stdout_connection(connection)
assert self.connection is not None
@@ -44,12 +41,12 @@ class MigrationContext(object):
self.connection = connection
self._migrations_fn = opts.get('fn')
self.as_sql = as_sql
- self.output_buffer = output_buffer if output_buffer else sys.stdout
+ self.output_buffer = opts.get("output_buffer", sys.stdout)
- self._user_compare_type = compare_type
- self._user_compare_server_default = compare_server_default
+ self._user_compare_type = opts.get('compare_type', False)
+ self._user_compare_server_default = opts.get('compare_server_default', False)
- self._start_from_rev = starting_rev
+ self._start_from_rev = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
dialect, self.connection, self.as_sql,
transactional_ddl,
@@ -63,6 +60,46 @@ class MigrationContext(object):
"transactional" if self.impl.transactional_ddl
else "non-transactional")
+ @classmethod
+ def configure(cls,
+ connection=None,
+ url=None,
+ dialect_name=None,
+ opts=None,
+ ):
+ """Create a new :class:`.MigrationContext`.
+
+ This is a factory method usually called
+ by :meth:`.EnvironmentContext.configure`.
+
+ :param connection: a :class:`~sqlalchemy.engine.base.Connection` to use
+ for SQL execution in "online" mode. When present, is also used to
+ determine the type of dialect in use.
+ :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object.
+ The type of dialect to be used will be derived from this if ``connection`` is
+ not passed.
+ :param dialect_name: string name of a dialect, such as "postgresql", "mssql", etc.
+ The type of dialect to be used will be derived from this if ``connection``
+ and ``url`` are not passed.
+ :param opts: dictionary of options. Most other options
+ accepted by :meth:`.EnvironmentContext.configure` are passed via
+ this dictionary.
+
+ """
+ if connection:
+ dialect = connection.dialect
+ elif url:
+ url = sqla_url.make_url(url)
+ dialect = url.get_dialect()()
+ elif dialect_name:
+ url = sqla_url.make_url("%s://" % dialect_name)
+ dialect = url.get_dialect()()
+ else:
+ raise Exception("Connection, url, or dialect_name is required.")
+
+ return MigrationContext(dialect, connection, opts)
+
+
def _current_rev(self):
if self.as_sql:
return self._start_from_rev
@@ -93,7 +130,8 @@ class MigrationContext(object):
current_rev = rev = False
self.impl.start_migrations()
for change, prev_rev, rev in self._migrations_fn(
- self._current_rev()):
+ self._current_rev(),
+ self):
if current_rev is False:
current_rev = prev_rev
if self.as_sql and not current_rev: