summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/exclusions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/exclusions.py')
-rw-r--r--lib/sqlalchemy/testing/exclusions.py123
1 files changed, 59 insertions, 64 deletions
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 512fffb3b..9ed9e42c3 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -16,6 +16,7 @@ from . import config
from .. import util
from ..util import decorator
+
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
@@ -70,15 +71,15 @@ class compound(object):
def matching_config_reasons(self, config):
return [
- predicate._as_string(config) for predicate
- in self.skips.union(self.fails)
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
- not self.tags.intersection(exclude_tags) and
- (not include_tags or self.tags.intersection(include_tags))
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
@@ -87,13 +88,14 @@ class compound(object):
self.tags.update(other.tags)
def __call__(self, fn):
- if hasattr(fn, '_sa_exclusion_extend'):
+ if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
+
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@@ -113,10 +115,7 @@ class compound(object):
def _do(self, cfg, fn, *args, **kw):
for skip in self.skips:
if skip(cfg):
- msg = "'%s' : %s" % (
- fn.__name__,
- skip._as_string(cfg)
- )
+ msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg))
config.skip_test(msg)
try:
@@ -127,16 +126,20 @@ class compound(object):
self._expect_success(cfg, name=fn.__name__)
return return_value
- def _expect_failure(self, config, ex, name='block'):
+ def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
- print(("%s failed as expected (%s): %s " % (
- name, fail._as_string(config), str(ex))))
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str(ex))
+ )
+ )
break
else:
util.raise_from_cause(ex)
- def _expect_success(self, config, name='block'):
+ def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
@@ -144,13 +147,12 @@ class compound(object):
break
else:
raise AssertionError(
- "Unexpected success for '%s' (%s)" %
- (
+ "Unexpected success for '%s' (%s)"
+ % (
name,
" and ".join(
- fail._as_string(config)
- for fail in self.fails
- )
+ fail._as_string(config) for fail in self.fails
+ ),
)
)
@@ -186,21 +188,24 @@ class Predicate(object):
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate],
- description)
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, util.string_types):
tokens = re.match(
- r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate)
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
if not tokens:
raise ValueError(
- "Couldn't locate DB name in predicate: %r" % predicate)
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
- if tokens.group(3) else None
+ if tokens.group(3)
+ else None
)
return SpecPredicate(db, op, spec, description=description)
@@ -215,11 +220,13 @@ class Predicate(object):
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name()
- if config else "<no driver>",
+ if config
+ else "<no driver>",
"database": config.db.url.get_backend_name()
- if config else "<no database>",
+ if config
+ else "<no database>",
"doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support"
+ "does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
@@ -246,21 +253,21 @@ class SpecPredicate(Predicate):
self.description = description
_ops = {
- '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
- dialect, driver = self.db.split('+')
+ dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
@@ -273,8 +280,9 @@ class SpecPredicate(Predicate):
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
- oper = hasattr(self.op, '__call__') and self.op \
- or self._ops[self.op]
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
return oper(version, self.spec)
else:
return True
@@ -289,17 +297,9 @@ class SpecPredicate(Predicate):
return "%s" % self.db
else:
if negate:
- return "not %s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "not %s %s %s" % (self.db, self.op, self.spec)
else:
- return "%s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
@@ -356,8 +356,9 @@ class OrPredicate(Predicate):
conjunction = " and "
else:
conjunction = " or "
- return conjunction.join(p._as_string(config, negate=negate)
- for p in self.predicates)
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
def _negation_str(self, config):
if self.description is not None:
@@ -387,7 +388,7 @@ def _server_version(engine):
# force metadata to be retrieved
conn = engine.connect()
- version = getattr(engine.dialect, 'server_version_info', None)
+ version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
@@ -395,9 +396,7 @@ def _server_version(engine):
def db_spec(*dbs):
- return OrPredicate(
- [Predicate.as_predicate(db) for db in dbs]
- )
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open():
@@ -422,11 +421,7 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
- return succeeds_if(
- OrPredicate([
- Predicate.as_predicate(db) for db in dbs
- ])
- )
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
@@ -435,8 +430,9 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
- OrPredicate([Predicate.as_predicate(db, reason)
- for db in util.to_list(dbs)])
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
)
@@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None):
def against(config, *queries):
assert queries, "no queries sent!"
- return OrPredicate([
- Predicate.as_predicate(query)
- for query in queries
- ])(config)
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )