diff options
Diffstat (limited to 'lib/sqlalchemy/testing/exclusions.py')
-rw-r--r-- | lib/sqlalchemy/testing/exclusions.py | 123 |
1 files changed, 59 insertions, 64 deletions
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 512fffb3b..9ed9e42c3 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -16,6 +16,7 @@ from . import config from .. import util from ..util import decorator + def skip_if(predicate, reason=None): rule = compound() pred = _as_predicate(predicate, reason) @@ -70,15 +71,15 @@ class compound(object): def matching_config_reasons(self, config): return [ - predicate._as_string(config) for predicate - in self.skips.union(self.fails) + predicate._as_string(config) + for predicate in self.skips.union(self.fails) if predicate(config) ] def include_test(self, include_tags, exclude_tags): return bool( - not self.tags.intersection(exclude_tags) and - (not include_tags or self.tags.intersection(include_tags)) + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) ) def _extend(self, other): @@ -87,13 +88,14 @@ class compound(object): self.tags.update(other.tags) def __call__(self, fn): - if hasattr(fn, '_sa_exclusion_extend'): + if hasattr(fn, "_sa_exclusion_extend"): fn._sa_exclusion_extend._extend(self) return fn @decorator def decorate(fn, *args, **kw): return self._do(config._current, fn, *args, **kw) + decorated = decorate(fn) decorated._sa_exclusion_extend = self return decorated @@ -113,10 +115,7 @@ class compound(object): def _do(self, cfg, fn, *args, **kw): for skip in self.skips: if skip(cfg): - msg = "'%s' : %s" % ( - fn.__name__, - skip._as_string(cfg) - ) + msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg)) config.skip_test(msg) try: @@ -127,16 +126,20 @@ class compound(object): self._expect_success(cfg, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, ex, name="block"): for fail in self.fails: if fail(config): - print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(ex)))) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str(ex)) + ) + ) break else: util.raise_from_cause(ex) - def _expect_success(self, config, name='block'): + def _expect_success(self, config, name="block"): if not self.fails: return for fail in self.fails: @@ -144,13 +147,12 @@ class compound(object): break else: raise AssertionError( - "Unexpected success for '%s' (%s)" % - ( + "Unexpected success for '%s' (%s)" + % ( name, " and ".join( - fail._as_string(config) - for fail in self.fails - ) + fail._as_string(config) for fail in self.fails + ), ) ) @@ -186,21 +188,24 @@ class Predicate(object): return predicate elif isinstance(predicate, (list, set)): return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], - description) + [cls.as_predicate(pred) for pred in predicate], description + ) elif isinstance(predicate, tuple): return SpecPredicate(*predicate) elif isinstance(predicate, util.string_types): tokens = re.match( - r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate) + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) if not tokens: raise ValueError( - "Couldn't locate DB name in predicate: %r" % predicate) + "Couldn't locate DB name in predicate: %r" % predicate + ) db = tokens.group(1) op = tokens.group(2) spec = ( tuple(int(d) for d in tokens.group(3).split(".")) - if tokens.group(3) else None + if tokens.group(3) + else None ) return SpecPredicate(db, op, spec, description=description) @@ -215,11 +220,13 @@ class Predicate(object): bool_ = not negate return self.description % { "driver": config.db.url.get_driver_name() - if config else "<no driver>", + if config + else "<no driver>", "database": config.db.url.get_backend_name() - if config else "<no database>", + if config + else "<no database>", "doesnt_support": "doesn't support" if bool_ else "does support", - "does_support": "does support" if bool_ else "doesn't support" + "does_support": "does support" if bool_ else "doesn't support", } def _as_string(self, config=None, negate=False): @@ -246,21 +253,21 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], } def __call__(self, config): engine = config.db if "+" in self.db: - dialect, driver = self.db.split('+') + dialect, driver = self.db.split("+") else: dialect, driver = self.db, None @@ -273,8 +280,9 @@ class SpecPredicate(Predicate): assert driver is None, "DBAPI version specs not supported yet" version = _server_version(engine) - oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) return oper(version, self.spec) else: return True @@ -289,17 +297,9 @@ class SpecPredicate(Predicate): return "%s" % self.db else: if negate: - return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "not %s %s %s" % (self.db, self.op, self.spec) else: - return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "%s %s %s" % (self.db, self.op, self.spec) class LambdaPredicate(Predicate): @@ -356,8 +356,9 @@ class OrPredicate(Predicate): conjunction = " and " else: conjunction = " or " - return conjunction.join(p._as_string(config, negate=negate) - for p in self.predicates) + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) def _negation_str(self, config): if self.description is not None: @@ -387,7 +388,7 @@ def _server_version(engine): # force metadata to be retrieved conn = engine.connect() - version = getattr(engine.dialect, 'server_version_info', None) + version = getattr(engine.dialect, "server_version_info", None) if version is None: version = () conn.close() @@ -395,9 +396,7 @@ def _server_version(engine): def db_spec(*dbs): - return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) def open(): @@ -422,11 +421,7 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): - return succeeds_if( - OrPredicate([ - Predicate.as_predicate(db) for db in dbs - ]) - ) + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) def skip(db, reason=None): @@ -435,8 +430,9 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([Predicate.as_predicate(db, reason) - for db in util.to_list(dbs)]) + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) ) @@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" - return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) |