diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 182 |
1 files changed, 110 insertions, 72 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e42376921..73ab4556a 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -86,6 +86,7 @@ def emits_warning_on(db, *messages): were in fact seen. """ + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, assert_=False, *messages): @@ -114,12 +115,14 @@ def uses_deprecated(*messages): def decorate(fn, *args, **kw): with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) + return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages, regex=True, assert_=True, - py2konly=False): +def _expect_warnings( + exc_cls, messages, regex=True, assert_=True, py2konly=False +): if regex: filters = [re.compile(msg, re.I | re.S) for msg in messages] @@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, yield if assert_ and (not py2konly or not compat.py3k): - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) def global_cleanup_assertions(): @@ -170,6 +175,7 @@ def global_cleanup_assertions(): """ _assert_no_stray_pool_connections() + _STRAY_CONNECTION_FAILURES = 0 @@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections(): # OK, let's be somewhat forgiving. _STRAY_CONNECTION_FAILURES += 1 - print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + print( + "Encountered a stray connection in test cleanup: %s" + % str(pool._refs) + ) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. @@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections(): pool._refs.clear() _STRAY_CONNECTION_FAILURES = 0 warnings.warn( - "Stray connection refused to leave " - "after gc.collect(): %s" % err) + "Stray connection refused to leave " "after gc.collect(): %s" % err + ) elif _STRAY_CONNECTION_FAILURES > 10: assert False, "Encountered more than 10 stray connections" _STRAY_CONNECTION_FAILURES = 0 @@ -263,14 +271,16 @@ def not_in_(a, b, msg=None): def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) + a, + fragment, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) assert a == b, msg or "%r != %r" % (a, b) @@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(util.text_type(e).encode('utf-8')) + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(util.text_type(e).encode("utf-8")) + class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False, - schema_translate_map=None): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False, + schema_translate_map=None, + ): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: dialect = None else: if dialect is None: - dialect = getattr(self, '__dialect__', None) + dialect = getattr(self, "__dialect__", None) if dialect is None: dialect = config.db.dialect - elif dialect == 'default': + elif dialect == "default": dialect = default.DefaultDialect() - elif dialect == 'default_enhanced': + elif dialect == "default_enhanced": dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() @@ -325,13 +344,13 @@ class AssertsCompiledSQL(object): compile_kwargs = {} if schema_translate_map: - kw['schema_translate_map'] = schema_translate_map + kw["schema_translate_map"] = schema_translate_map if params is not None: - kw['column_keys'] = list(params) + kw["column_keys"] = list(params) if literal_binds: - compile_kwargs['literal_binds'] = True + compile_kwargs["literal_binds"] = True if isinstance(clause, orm.Query): context = clause._compile_context() @@ -343,25 +362,27 @@ class AssertsCompiledSQL(object): clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: - kw['compile_kwargs'] = compile_kwargs + kw["compile_kwargs"] = compile_kwargs c = clause.compile(dialect=dialect, **kw) - param_str = repr(getattr(c, 'params', {})) + param_str = repr(getattr(c, "params", {})) if util.py3k: - param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + param_str = param_str.encode("utf-8").decode("ascii", "ignore") print( - ("\nSQL String:\n" + - util.text_type(c) + - param_str).encode('utf-8')) + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) else: print( - "\nSQL String:\n" + - util.text_type(c).encode('utf-8') + - param_str) + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) - cc = re.sub(r'[\n\t]', '', util.text_type(c)) + cc = re.sub(r"[\n\t]", "", util.text_type(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -375,7 +396,6 @@ class AssertsCompiledSQL(object): class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -386,8 +406,10 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), \ - msg % (reflected_c.type, c.type) + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) else: self.assert_types_base(reflected_c, c) @@ -396,20 +418,22 @@ class ComparesTables(object): eq_( {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys} + {f.column.name for f in reflected_c.foreign_keys}, ) if c.server_default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + assert c1.type._compare_type_affinity(c2.type), ( + "On column %r, type '%s' doesn't correspond to type '%s'" + % (c1.name, c1.type, c2.type) + ) class AssertsExecutionResults(object): @@ -419,15 +443,19 @@ class AssertsExecutionResults(object): self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) + self.assert_( + len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) for key, value in desc.items(): if isinstance(value, tuple): if isinstance(value[1], list): @@ -435,9 +463,11 @@ class AssertsExecutionResults(object): else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. @@ -453,14 +483,19 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = {immutabledict(e) for e in expected} - for wrong in util.itertools_filterfalse(lambda o: - isinstance(o, cls), found): - fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) if len(found) != len(expected): - fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) NOVALUE = object() @@ -469,7 +504,8 @@ class AssertsExecutionResults(object): if isinstance(value, tuple): try: self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) + getattr(obj, key), value[0], *value[1] + ) except AssertionError: return False else: @@ -484,8 +520,9 @@ class AssertsExecutionResults(object): break else: fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) return True def sql_execution_asserter(self, db=None): @@ -505,9 +542,9 @@ class AssertsExecutionResults(object): newrules = [] for rule in rules: if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.CompiledSQL(k, v) for k, v in rule.items() - ]) + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) else: newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) @@ -516,7 +553,8 @@ class AssertsExecutionResults(object): def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count)) + db, callable_, assertsql.CountStatements(count) + ) def assert_multiple_sql_count(self, dbs, callable_, counts): recs = [ |