diff options
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r-- | test/testlib/testing.py | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 3b7b2992e..5c37f95bd 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -7,10 +7,10 @@ from cStringIO import StringIO import testlib.config as config from testlib.compat import * -sql, MetaData, clear_mappers, Session, util = None, None, None, None, None +sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None sa_exceptions = None -__all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest') +__all__ = ('PersistTest', 'AssertMixin', 'ComparesTables', 'ORMTest', 'SQLCompileTest') _ops = { '<': operator.lt, '>': operator.gt, @@ -485,10 +485,41 @@ class SQLCompileTest(PersistTest): if checkparams is not None: self.assertEquals(c.construct_params(params), checkparams) +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table): + global sqltypes, schema + if sqltypes is None: + import sqlalchemy.types as sqltypes + if schema is None: + import sqlalchemy.schema as schema + base_mro = sqltypes.TypeEngine.__mro__ + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + self.assertEquals(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + self.assertEquals(c.primary_key, reflected_c.primary_key) + self.assertEquals(c.nullable, reflected_c.nullable) + assert len( + set(type(reflected_c.type).__mro__).difference(base_mro).intersection( + set(type(c.type).__mro__).difference(base_mro) + ) + ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + + if isinstance(c.type, sqltypes.String): + self.assertEquals(c.type.length, reflected_c.type.length) + + self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.default: + assert isinstance(reflected_c.default, schema.PassiveDefault) + elif not c.primary_key or not against('postgres'): + assert reflected_c.default is None + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] + + class AssertMixin(PersistTest): - """given a list-based structure of keys/properties which represent information within an object structure, and - a list of actual objects, asserts that the list of objects corresponds to the structure.""" - def assert_result(self, result, class_, *objects): result = list(result) print repr(result) |