path: root/lib/sqlalchemy/testing/assertions.py
author      Mike Bayer <mike_mp@zzzcomputing.com>  2019-01-06 01:14:26 -0500
committer   mike bayer <mike_mp@zzzcomputing.com>  2019-01-06 17:34:50 +0000
commit      1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree        28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing/assertions.py
parent      404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
download    sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as-is, with no manual edits applied. The black run formats the code consistently; however, for some patterns that are prevalent in SQLAlchemy code it produces too-long lines. Those lines will be resolved in the following commit, which addresses all remaining flake8 issues: shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues.

Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
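For reference, the change corresponds to running black with a 79-character line limit (the "-l 79" in the title). The snippet below is an illustrative sketch of the kind of rewrap black applies at that limit; the function name and parameters are made up for illustration and do not come from this commit:

    # before: a hypothetical signature that overflows 79 characters when
    # written with hanging-indent continuation
    def assert_widget(self, clause, result, params=None, checkparams=None,
                      dialect=None):
        ...

    # after black -l 79: the parameter list moves to its own indented line
    # inside the parentheses, the style seen throughout the hunks below
    def assert_widget(
        self, clause, result, params=None, checkparams=None, dialect=None
    ):
        ...

When even the indented parameter line cannot fit within 79 characters, black instead explodes the arguments one per line with a trailing comma, as seen in the assert_compile() hunk further down.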
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--  lib/sqlalchemy/testing/assertions.py  182
1 file changed, 110 insertions, 72 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e42376921..73ab4556a 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -86,6 +86,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
@@ -114,12 +115,14 @@ def uses_deprecated(*messages):
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
+
return decorate
@contextlib.contextmanager
-def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
- py2konly=False):
+def _expect_warnings(
+ exc_cls, messages, regex=True, assert_=True, py2konly=False
+):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
@@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
yield
if assert_ and (not py2konly or not compat.py3k):
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
def global_cleanup_assertions():
@@ -170,6 +175,7 @@ def global_cleanup_assertions():
"""
_assert_no_stray_pool_connections()
+
_STRAY_CONNECTION_FAILURES = 0
@@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections():
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
- print("Encountered a stray connection in test cleanup: %s"
- % str(pool._refs))
+ print(
+ "Encountered a stray connection in test cleanup: %s"
+ % str(pool._refs)
+ )
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections():
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
- "Stray connection refused to leave "
- "after gc.collect(): %s" % err)
+ "Stray connection refused to leave " "after gc.collect(): %s" % err
+ )
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
@@ -263,14 +271,16 @@ def not_in_(a, b, msg=None):
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
+ a,
+ fragment,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(util.text_type(e).encode('utf-8'))
+ assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(util.text_type(e).encode("utf-8"))
+
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- check_prefetch=None,
- use_default_dialect=False,
- allow_dialect_select=False,
- literal_binds=False,
- schema_translate_map=None):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ literal_binds=False,
+ schema_translate_map=None,
+ ):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
- dialect = getattr(self, '__dialect__', None)
+ dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
- elif dialect == 'default':
+ elif dialect == "default":
dialect = default.DefaultDialect()
- elif dialect == 'default_enhanced':
+ elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
@@ -325,13 +344,13 @@ class AssertsCompiledSQL(object):
compile_kwargs = {}
if schema_translate_map:
- kw['schema_translate_map'] = schema_translate_map
+ kw["schema_translate_map"] = schema_translate_map
if params is not None:
- kw['column_keys'] = list(params)
+ kw["column_keys"] = list(params)
if literal_binds:
- compile_kwargs['literal_binds'] = True
+ compile_kwargs["literal_binds"] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
@@ -343,25 +362,27 @@ class AssertsCompiledSQL(object):
clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
- kw['compile_kwargs'] = compile_kwargs
+ kw["compile_kwargs"] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
- param_str = repr(getattr(c, 'params', {}))
+ param_str = repr(getattr(c, "params", {}))
if util.py3k:
- param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(
- ("\nSQL String:\n" +
- util.text_type(c) +
- param_str).encode('utf-8'))
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
else:
print(
- "\nSQL String:\n" +
- util.text_type(c).encode('utf-8') +
- param_str)
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
- cc = re.sub(r'[\n\t]', '', util.text_type(c))
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -375,7 +396,6 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
-
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -386,8 +406,10 @@ class ComparesTables(object):
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
- assert isinstance(reflected_c.type, type(c.type)), \
- msg % (reflected_c.type, c.type)
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
else:
self.assert_types_base(reflected_c, c)
@@ -396,20 +418,22 @@ class ComparesTables(object):
eq_(
{f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys}
+ {f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
+ assert c1.type._compare_type_affinity(c2.type), (
+ "On column %r, type '%s' doesn't correspond to type '%s'"
+ % (c1.name, c1.type, c2.type)
+ )
class AssertsExecutionResults(object):
@@ -419,15 +443,19 @@ class AssertsExecutionResults(object):
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
+ self.assert_(
+ len(result) == len(list),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
@@ -435,9 +463,11 @@ class AssertsExecutionResults(object):
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object):
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
- for wrong in util.itertools_filterfalse(lambda o:
- isinstance(o, cls), found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
NOVALUE = object()
@@ -469,7 +504,8 @@ class AssertsExecutionResults(object):
if isinstance(value, tuple):
try:
self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
+ getattr(obj, key), value[0], *value[1]
+ )
except AssertionError:
return False
else:
@@ -484,8 +520,9 @@ class AssertsExecutionResults(object):
break
else:
fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
return True
def sql_execution_asserter(self, db=None):
@@ -505,9 +542,9 @@ class AssertsExecutionResults(object):
newrules = []
for rule in rules:
if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.CompiledSQL(k, v) for k, v in rule.items()
- ])
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
@@ -516,7 +553,8 @@ class AssertsExecutionResults(object):
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
- db, callable_, assertsql.CountStatements(count))
+ db, callable_, assertsql.CountStatements(count)
+ )
def assert_multiple_sql_count(self, dbs, callable_, counts):
recs = [