diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing/assertions.py | |
parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 182 |
1 files changed, 110 insertions, 72 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e42376921..73ab4556a 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -86,6 +86,7 @@ def emits_warning_on(db, *messages): were in fact seen. """ + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, assert_=False, *messages): @@ -114,12 +115,14 @@ def uses_deprecated(*messages): def decorate(fn, *args, **kw): with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) + return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages, regex=True, assert_=True, - py2konly=False): +def _expect_warnings( + exc_cls, messages, regex=True, assert_=True, py2konly=False +): if regex: filters = [re.compile(msg, re.I | re.S) for msg in messages] @@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, yield if assert_ and (not py2konly or not compat.py3k): - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) def global_cleanup_assertions(): @@ -170,6 +175,7 @@ def global_cleanup_assertions(): """ _assert_no_stray_pool_connections() + _STRAY_CONNECTION_FAILURES = 0 @@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections(): # OK, let's be somewhat forgiving. _STRAY_CONNECTION_FAILURES += 1 - print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + print( + "Encountered a stray connection in test cleanup: %s" + % str(pool._refs) + ) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. @@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections(): pool._refs.clear() _STRAY_CONNECTION_FAILURES = 0 warnings.warn( - "Stray connection refused to leave " - "after gc.collect(): %s" % err) + "Stray connection refused to leave " "after gc.collect(): %s" % err + ) elif _STRAY_CONNECTION_FAILURES > 10: assert False, "Encountered more than 10 stray connections" _STRAY_CONNECTION_FAILURES = 0 @@ -263,14 +271,16 @@ def not_in_(a, b, msg=None): def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) + a, + fragment, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) assert a == b, msg or "%r != %r" % (a, b) @@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(util.text_type(e).encode('utf-8')) + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(util.text_type(e).encode("utf-8")) + class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False, - schema_translate_map=None): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False, + schema_translate_map=None, + ): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: dialect = None else: if dialect is None: - dialect = getattr(self, '__dialect__', None) + dialect = getattr(self, "__dialect__", None) if dialect is None: dialect = config.db.dialect - elif dialect == 'default': + elif dialect == "default": dialect = default.DefaultDialect() - elif dialect == 'default_enhanced': + elif dialect == "default_enhanced": dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() @@ -325,13 +344,13 @@ class AssertsCompiledSQL(object): compile_kwargs = {} if schema_translate_map: - kw['schema_translate_map'] = schema_translate_map + kw["schema_translate_map"] = schema_translate_map if params is not None: - kw['column_keys'] = list(params) + kw["column_keys"] = list(params) if literal_binds: - compile_kwargs['literal_binds'] = True + compile_kwargs["literal_binds"] = True if isinstance(clause, orm.Query): context = clause._compile_context() @@ -343,25 +362,27 @@ class AssertsCompiledSQL(object): clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: - kw['compile_kwargs'] = compile_kwargs + kw["compile_kwargs"] = compile_kwargs c = clause.compile(dialect=dialect, **kw) - param_str = repr(getattr(c, 'params', {})) + param_str = repr(getattr(c, "params", {})) if util.py3k: - param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + param_str = param_str.encode("utf-8").decode("ascii", "ignore") print( - ("\nSQL String:\n" + - util.text_type(c) + - param_str).encode('utf-8')) + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) else: print( - "\nSQL String:\n" + - util.text_type(c).encode('utf-8') + - param_str) + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) - cc = re.sub(r'[\n\t]', '', util.text_type(c)) + cc = re.sub(r"[\n\t]", "", util.text_type(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -375,7 +396,6 @@ class AssertsCompiledSQL(object): class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -386,8 +406,10 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), \ - msg % (reflected_c.type, c.type) + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) else: self.assert_types_base(reflected_c, c) @@ -396,20 +418,22 @@ class ComparesTables(object): eq_( {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys} + {f.column.name for f in reflected_c.foreign_keys}, ) if c.server_default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + assert c1.type._compare_type_affinity(c2.type), ( + "On column %r, type '%s' doesn't correspond to type '%s'" + % (c1.name, c1.type, c2.type) + ) class AssertsExecutionResults(object): @@ -419,15 +443,19 @@ class AssertsExecutionResults(object): self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) + self.assert_( + len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) for key, value in desc.items(): if isinstance(value, tuple): if isinstance(value[1], list): @@ -435,9 +463,11 @@ class AssertsExecutionResults(object): else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. @@ -453,14 +483,19 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = {immutabledict(e) for e in expected} - for wrong in util.itertools_filterfalse(lambda o: - isinstance(o, cls), found): - fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) if len(found) != len(expected): - fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) NOVALUE = object() @@ -469,7 +504,8 @@ class AssertsExecutionResults(object): if isinstance(value, tuple): try: self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) + getattr(obj, key), value[0], *value[1] + ) except AssertionError: return False else: @@ -484,8 +520,9 @@ class AssertsExecutionResults(object): break else: fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) return True def sql_execution_asserter(self, db=None): @@ -505,9 +542,9 @@ class AssertsExecutionResults(object): newrules = [] for rule in rules: if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.CompiledSQL(k, v) for k, v in rule.items() - ]) + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) else: newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) @@ -516,7 +553,8 @@ class AssertsExecutionResults(object): def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count)) + db, callable_, assertsql.CountStatements(count) + ) def assert_multiple_sql_count(self, dbs, callable_, counts): recs = [ |