diff options
Diffstat (limited to 'alembic/testing/fixtures.py')
-rw-r--r-- | alembic/testing/fixtures.py | 95 |
1 files changed, 50 insertions, 45 deletions
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 4091388..ae25fd2 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -88,39 +88,9 @@ def capture_context_buffer(**kw): yield buf -def op_fixture(dialect='default', as_sql=False, naming_convention=None): - impl = _impls[dialect] - - class Impl(impl): - - def __init__(self, dialect, as_sql): - self.assertion = [] - self.dialect = dialect - self.as_sql = as_sql - # TODO: this might need to - # be more like a real connection - # as tests get more involved - if as_sql and self.dialect.name != 'default': - # act similarly to MigrationContext - def dump(construct, *multiparams, **params): - self._exec(construct) - - self.connection = create_engine( - "%s://" % self.dialect.name, - strategy="mock", executor=dump) - - else: - self.connection = mock.Mock(dialect=dialect) - - def _exec(self, construct, *args, **kw): - if isinstance(construct, string_types): - construct = text(construct) - assert construct.supports_execution - sql = text_type(construct.compile(dialect=self.dialect)) - sql = re.sub(r'[\n\t]', '', sql) - self.assertion.append( - sql - ) +def op_fixture( + dialect='default', as_sql=False, + naming_convention=None, literal_binds=False): opts = {} if naming_convention: @@ -130,32 +100,67 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None): "sqla 0.9.2 or greater") opts['target_metadata'] = MetaData(naming_convention=naming_convention) - class ctx(MigrationContext): + class buffer_(object): + def __init__(self): + self.lines = [] + + def write(self, msg): + msg = msg.strip() + msg = re.sub(r'[\n\t]', '', msg) + if as_sql: + # the impl produces soft tabs, + # so search for blocks of 4 spaces + msg = re.sub(r' ', '', msg) + msg = re.sub('\;\n*$', '', msg) + + self.lines.append(msg) + + def flush(self): + pass - def __init__(self, dialect='default', as_sql=False): - self.dialect = _get_dialect(dialect) - self.impl = Impl(self.dialect, as_sql) - self.opts = opts - self.as_sql = as_sql + buf = buffer_() + class ctx(MigrationContext): def clear_assertions(self): - self.impl.assertion[:] = [] + buf.lines[:] = [] def assert_(self, *sql): # TODO: make this more flexible about # whitespace and such - eq_(self.impl.assertion, list(sql)) + eq_(buf.lines, list(sql)) def assert_contains(self, sql): - for stmt in self.impl.assertion: + for stmt in buf.lines: if sql in stmt: return else: assert False, "Could not locate fragment %r in %r" % ( sql, - self.impl.assertion + buf.lines ) - context = ctx(dialect, as_sql) + + if as_sql: + opts['as_sql'] = as_sql + if literal_binds: + opts['literal_binds'] = literal_binds + ctx_dialect = _get_dialect(dialect) + if not as_sql: + def execute(stmt, *multiparam, **param): + if isinstance(stmt, string_types): + stmt = text(stmt) + assert stmt.supports_execution + sql = text_type(stmt.compile(dialect=ctx_dialect)) + + buf.write(sql) + + connection = mock.Mock(dialect=ctx_dialect, execute=execute) + else: + opts['output_buffer'] = buf + connection = None + context = ctx( + ctx_dialect, + connection, + opts) + alembic.op._proxy = Operations(context) return context - |