summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--lib/sqlalchemy/testing/assertions.py182
1 files changed, 110 insertions, 72 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e42376921..73ab4556a 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -86,6 +86,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
@@ -114,12 +115,14 @@ def uses_deprecated(*messages):
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
+
return decorate
@contextlib.contextmanager
-def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
- py2konly=False):
+def _expect_warnings(
+ exc_cls, messages, regex=True, assert_=True, py2konly=False
+):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
@@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
yield
if assert_ and (not py2konly or not compat.py3k):
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
def global_cleanup_assertions():
@@ -170,6 +175,7 @@ def global_cleanup_assertions():
"""
_assert_no_stray_pool_connections()
+
_STRAY_CONNECTION_FAILURES = 0
@@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections():
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
- print("Encountered a stray connection in test cleanup: %s"
- % str(pool._refs))
+ print(
+ "Encountered a stray connection in test cleanup: %s"
+ % str(pool._refs)
+ )
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections():
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
- "Stray connection refused to leave "
- "after gc.collect(): %s" % err)
+ "Stray connection refused to leave " "after gc.collect(): %s" % err
+ )
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
@@ -263,14 +271,16 @@ def not_in_(a, b, msg=None):
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
+ a,
+ fragment,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(util.text_type(e).encode('utf-8'))
+ assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(util.text_type(e).encode("utf-8"))
+
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- check_prefetch=None,
- use_default_dialect=False,
- allow_dialect_select=False,
- literal_binds=False,
- schema_translate_map=None):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ literal_binds=False,
+ schema_translate_map=None,
+ ):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
- dialect = getattr(self, '__dialect__', None)
+ dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
- elif dialect == 'default':
+ elif dialect == "default":
dialect = default.DefaultDialect()
- elif dialect == 'default_enhanced':
+ elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
@@ -325,13 +344,13 @@ class AssertsCompiledSQL(object):
compile_kwargs = {}
if schema_translate_map:
- kw['schema_translate_map'] = schema_translate_map
+ kw["schema_translate_map"] = schema_translate_map
if params is not None:
- kw['column_keys'] = list(params)
+ kw["column_keys"] = list(params)
if literal_binds:
- compile_kwargs['literal_binds'] = True
+ compile_kwargs["literal_binds"] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
@@ -343,25 +362,27 @@ class AssertsCompiledSQL(object):
clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
- kw['compile_kwargs'] = compile_kwargs
+ kw["compile_kwargs"] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
- param_str = repr(getattr(c, 'params', {}))
+ param_str = repr(getattr(c, "params", {}))
if util.py3k:
- param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(
- ("\nSQL String:\n" +
- util.text_type(c) +
- param_str).encode('utf-8'))
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
else:
print(
- "\nSQL String:\n" +
- util.text_type(c).encode('utf-8') +
- param_str)
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
- cc = re.sub(r'[\n\t]', '', util.text_type(c))
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -375,7 +396,6 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
-
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -386,8 +406,10 @@ class ComparesTables(object):
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
- assert isinstance(reflected_c.type, type(c.type)), \
- msg % (reflected_c.type, c.type)
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
else:
self.assert_types_base(reflected_c, c)
@@ -396,20 +418,22 @@ class ComparesTables(object):
eq_(
{f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys}
+ {f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
+ assert c1.type._compare_type_affinity(c2.type), (
+ "On column %r, type '%s' doesn't correspond to type '%s'"
+ % (c1.name, c1.type, c2.type)
+ )
class AssertsExecutionResults(object):
@@ -419,15 +443,19 @@ class AssertsExecutionResults(object):
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
+ self.assert_(
+ len(result) == len(list),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
@@ -435,9 +463,11 @@ class AssertsExecutionResults(object):
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object):
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
- for wrong in util.itertools_filterfalse(lambda o:
- isinstance(o, cls), found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
NOVALUE = object()
@@ -469,7 +504,8 @@ class AssertsExecutionResults(object):
if isinstance(value, tuple):
try:
self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
+ getattr(obj, key), value[0], *value[1]
+ )
except AssertionError:
return False
else:
@@ -484,8 +520,9 @@ class AssertsExecutionResults(object):
break
else:
fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
return True
def sql_execution_asserter(self, db=None):
@@ -505,9 +542,9 @@ class AssertsExecutionResults(object):
newrules = []
for rule in rules:
if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.CompiledSQL(k, v) for k, v in rule.items()
- ])
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
@@ -516,7 +553,8 @@ class AssertsExecutionResults(object):
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
- db, callable_, assertsql.CountStatements(count))
+ db, callable_, assertsql.CountStatements(count)
+ )
def assert_multiple_sql_count(self, dbs, callable_, counts):
recs = [