summaryrefslogtreecommitdiff
path: root/test/testlib/testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r--test/testlib/testing.py41
1 files changed, 36 insertions, 5 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index 3b7b2992e..5c37f95bd 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -7,10 +7,10 @@ from cStringIO import StringIO
import testlib.config as config
from testlib.compat import *
-sql, MetaData, clear_mappers, Session, util = None, None, None, None, None
+sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
sa_exceptions = None
-__all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest')
+__all__ = ('PersistTest', 'AssertMixin', 'ComparesTables', 'ORMTest', 'SQLCompileTest')
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -485,10 +485,41 @@ class SQLCompileTest(PersistTest):
if checkparams is not None:
self.assertEquals(c.construct_params(params), checkparams)
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table):
+ global sqltypes, schema
+ if sqltypes is None:
+ import sqlalchemy.types as sqltypes
+ if schema is None:
+ import sqlalchemy.schema as schema
+ base_mro = sqltypes.TypeEngine.__mro__
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ self.assertEquals(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ self.assertEquals(c.primary_key, reflected_c.primary_key)
+ self.assertEquals(c.nullable, reflected_c.nullable)
+ assert len(
+ set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
+ set(type(c.type).__mro__).difference(base_mro)
+ )
+ ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+
+ if isinstance(c.type, sqltypes.String):
+ self.assertEquals(c.type.length, reflected_c.type.length)
+
+ self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ if c.default:
+ assert isinstance(reflected_c.default, schema.PassiveDefault)
+ elif not c.primary_key or not against('postgres'):
+ assert reflected_c.default is None
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name]
+
+
class AssertMixin(PersistTest):
- """given a list-based structure of keys/properties which represent information within an object structure, and
- a list of actual objects, asserts that the list of objects corresponds to the structure."""
-
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)