summaryrefslogtreecommitdiff
path: root/alembic/testing/fixtures.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/testing/fixtures.py')
-rw-r--r--alembic/testing/fixtures.py95
1 files changed, 50 insertions, 45 deletions
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py
index 4091388..ae25fd2 100644
--- a/alembic/testing/fixtures.py
+++ b/alembic/testing/fixtures.py
@@ -88,39 +88,9 @@ def capture_context_buffer(**kw):
yield buf
-def op_fixture(dialect='default', as_sql=False, naming_convention=None):
- impl = _impls[dialect]
-
- class Impl(impl):
-
- def __init__(self, dialect, as_sql):
- self.assertion = []
- self.dialect = dialect
- self.as_sql = as_sql
- # TODO: this might need to
- # be more like a real connection
- # as tests get more involved
- if as_sql and self.dialect.name != 'default':
- # act similarly to MigrationContext
- def dump(construct, *multiparams, **params):
- self._exec(construct)
-
- self.connection = create_engine(
- "%s://" % self.dialect.name,
- strategy="mock", executor=dump)
-
- else:
- self.connection = mock.Mock(dialect=dialect)
-
- def _exec(self, construct, *args, **kw):
- if isinstance(construct, string_types):
- construct = text(construct)
- assert construct.supports_execution
- sql = text_type(construct.compile(dialect=self.dialect))
- sql = re.sub(r'[\n\t]', '', sql)
- self.assertion.append(
- sql
- )
+def op_fixture(
+ dialect='default', as_sql=False,
+ naming_convention=None, literal_binds=False):
opts = {}
if naming_convention:
@@ -130,32 +100,67 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None):
"sqla 0.9.2 or greater")
opts['target_metadata'] = MetaData(naming_convention=naming_convention)
- class ctx(MigrationContext):
+ class buffer_(object):
+ def __init__(self):
+ self.lines = []
+
+ def write(self, msg):
+ msg = msg.strip()
+ msg = re.sub(r'[\n\t]', '', msg)
+ if as_sql:
+ # the impl produces soft tabs,
+ # so search for blocks of 4 spaces
+ msg = re.sub(r' ', '', msg)
+ msg = re.sub('\;\n*$', '', msg)
+
+ self.lines.append(msg)
+
+ def flush(self):
+ pass
- def __init__(self, dialect='default', as_sql=False):
- self.dialect = _get_dialect(dialect)
- self.impl = Impl(self.dialect, as_sql)
- self.opts = opts
- self.as_sql = as_sql
+ buf = buffer_()
+ class ctx(MigrationContext):
def clear_assertions(self):
- self.impl.assertion[:] = []
+ buf.lines[:] = []
def assert_(self, *sql):
# TODO: make this more flexible about
# whitespace and such
- eq_(self.impl.assertion, list(sql))
+ eq_(buf.lines, list(sql))
def assert_contains(self, sql):
- for stmt in self.impl.assertion:
+ for stmt in buf.lines:
if sql in stmt:
return
else:
assert False, "Could not locate fragment %r in %r" % (
sql,
- self.impl.assertion
+ buf.lines
)
- context = ctx(dialect, as_sql)
+
+ if as_sql:
+ opts['as_sql'] = as_sql
+ if literal_binds:
+ opts['literal_binds'] = literal_binds
+ ctx_dialect = _get_dialect(dialect)
+ if not as_sql:
+ def execute(stmt, *multiparam, **param):
+ if isinstance(stmt, string_types):
+ stmt = text(stmt)
+ assert stmt.supports_execution
+ sql = text_type(stmt.compile(dialect=ctx_dialect))
+
+ buf.write(sql)
+
+ connection = mock.Mock(dialect=ctx_dialect, execute=execute)
+ else:
+ opts['output_buffer'] = buf
+ connection = None
+ context = ctx(
+ ctx_dialect,
+ connection,
+ opts)
+
alembic.op._proxy = Operations(context)
return context
-