summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2017-03-28 11:00:37 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2017-06-05 11:27:00 -0400
commitbb6a1f690d4a749df44a1ef329b66f71205968fe (patch)
tree90aac9e592df3a769f5397f84a14b911e4cb52f1 /lib
parent6bb97495baa640c6f03d1b50affd664cb903dee3 (diff)
downloadsqlalchemy-bb6a1f690d4a749df44a1ef329b66f71205968fe.tar.gz
selectin polymorphic loading
Added a new style of mapper-level inheritance loading "polymorphic selectin". This style of loading emits queries for each subclass in an inheritance hierarchy subsequent to the load of the base object type, using IN to specify the desired primary key values. Fixes: #3948 Change-Id: I59e071c6142354a3f95730046e3dcdfc0e2c4de5
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/ext/baked.py11
-rw-r--r--lib/sqlalchemy/orm/__init__.py2
-rw-r--r--lib/sqlalchemy/orm/loading.py53
-rw-r--r--lib/sqlalchemy/orm/mapper.py111
-rw-r--r--lib/sqlalchemy/orm/strategies.py18
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py108
-rw-r--r--lib/sqlalchemy/orm/util.py21
-rw-r--r--lib/sqlalchemy/testing/assertions.py5
-rw-r--r--lib/sqlalchemy/testing/assertsql.py48
9 files changed, 340 insertions, 37 deletions
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
index ba3c2aed0..c0fe963ac 100644
--- a/lib/sqlalchemy/ext/baked.py
+++ b/lib/sqlalchemy/ext/baked.py
@@ -154,7 +154,7 @@ class BakedQuery(object):
self._spoiled = True
return self
- def _add_lazyload_options(self, options, effective_path):
+ def _add_lazyload_options(self, options, effective_path, cache_path=None):
"""Used by per-state lazy loaders to add options to the
"lazy load" query from a parent query.
@@ -166,13 +166,16 @@ class BakedQuery(object):
key = ()
- if effective_path.path[0].is_aliased_class:
+ if not cache_path:
+ cache_path = effective_path
+
+ if cache_path.path[0].is_aliased_class:
# paths that are against an AliasedClass are unsafe to cache
# with since the AliasedClass is an ad-hoc object.
self.spoil()
else:
for opt in options:
- cache_key = opt._generate_cache_key(effective_path)
+ cache_key = opt._generate_cache_key(cache_path)
if cache_key is False:
self.spoil()
elif cache_key is not None:
@@ -181,7 +184,7 @@ class BakedQuery(object):
self.add_criteria(
lambda q: q._with_current_path(effective_path).
_conditional_options(*options),
- effective_path.path, key
+ cache_path.path, key
)
def _retrieve_baked_query(self, session):
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index adfe2360a..7ecd5b67e 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -246,6 +246,7 @@ immediateload = strategy_options.immediateload._unbound_fn
noload = strategy_options.noload._unbound_fn
raiseload = strategy_options.raiseload._unbound_fn
defaultload = strategy_options.defaultload._unbound_fn
+selectin_polymorphic = strategy_options.selectin_polymorphic._unbound_fn
from .strategy_options import Load
@@ -268,6 +269,7 @@ def __go(lcls):
from .. import util as sa_util
from . import dynamic
from . import events
+ from . import loading
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 7feec660d..48c0db851 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -19,6 +19,7 @@ from . import attributes, exc as orm_exc
from ..sql import util as sql_util
from . import strategy_options
from . import path_registry
+from .. import sql
from .util import _none_set, state_str
from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE
@@ -353,6 +354,27 @@ def _instance_processor(
session_id = context.session.hash_key
version_check = context.version_check
runid = context.runid
+
+ if not refresh_state and _polymorphic_from is not None:
+ key = ('loader', path.path)
+ if (
+ key in context.attributes and
+ context.attributes[key].strategy ==
+ (('selectinload_polymorphic', True), ) and
+ mapper in context.attributes[key].local_opts['mappers']
+ ) or mapper.polymorphic_load == 'selectin':
+
+ # only_load_props goes w/ refresh_state only, and in a refresh
+ # we are a single row query for the exact entity; polymorphic
+ # loading does not apply
+ assert only_load_props is None
+
+ callable_ = _load_subclass_via_in(context, path, mapper)
+
+ PostLoad.callable_for_path(
+ context, load_path, mapper,
+ callable_, mapper)
+
post_load = PostLoad.for_context(context, load_path, only_load_props)
if refresh_state:
@@ -501,6 +523,37 @@ def _instance_processor(
return _instance
+@util.dependencies("sqlalchemy.ext.baked")
+def _load_subclass_via_in(baked, context, path, mapper):
+
+ zero_idx = len(mapper.base_mapper.primary_key) == 1
+
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in
+
+ def do_load(context, path, states, load_only, effective_entity):
+ orig_query = context.query
+
+ q._add_lazyload_options(
+ (enable_opt, ) + orig_query._with_options + (disable_opt, ),
+ path.parent, cache_path=path
+ )
+
+ if orig_query._populate_existing:
+ q.add_criteria(
+ lambda q: q.populate_existing()
+ )
+
+ q(context.session).params(
+ primary_keys=[
+ state.key[1][0] if zero_idx else state.key[1]
+ for state, load_attrs in states
+ if state.mapper.isa(mapper)
+ ]
+ ).all()
+
+ return do_load
+
+
def _populate_full(
context, row, state, dict_, isnew, load_path,
loaded_instance, populate_existing, populators):
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 6bf86d0ef..1042442c0 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -106,6 +106,7 @@ class Mapper(InspectionAttr):
polymorphic_identity=None,
concrete=False,
with_polymorphic=None,
+ polymorphic_load=None,
allow_partial_pks=True,
batch=True,
column_prefix=None,
@@ -381,6 +382,27 @@ class Mapper(InspectionAttr):
:paramref:`.mapper.passive_deletes` - supporting ON DELETE
CASCADE for joined-table inheritance mappers
+ :param polymorphic_load: Specifies "polymorphic loading" behavior
+ for a subclass in an inheritance hierarchy (joined and single
+ table inheritance only). Valid values are:
+
+ * "'inline'" - specifies this class should be part of the
+ "with_polymorphic" mappers, e.g. its columns will be included
+ in a SELECT query against the base.
+
+ * "'selectin'" - specifies that when instances of this class
+ are loaded, an additional SELECT will be emitted to retrieve
+ the columns specific to this subclass. The SELECT uses
+ IN to fetch multiple subclasses at once.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`with_polymorphic_mapper_config`
+
+ :ref:`polymorphic_selectin`
+
:param polymorphic_on: Specifies the column, attribute, or
SQL expression used to determine the target class for an
incoming row, when inheriting classes are present.
@@ -622,8 +644,6 @@ class Mapper(InspectionAttr):
else:
self.confirm_deleted_rows = confirm_deleted_rows
- self._set_with_polymorphic(with_polymorphic)
-
if isinstance(self.local_table, expression.SelectBase):
raise sa_exc.InvalidRequestError(
"When mapping against a select() construct, map against "
@@ -632,11 +652,8 @@ class Mapper(InspectionAttr):
"SELECT from a subquery that does not have an alias."
)
- if self.with_polymorphic and \
- isinstance(self.with_polymorphic[1],
- expression.SelectBase):
- self.with_polymorphic = (self.with_polymorphic[0],
- self.with_polymorphic[1].alias())
+ self._set_with_polymorphic(with_polymorphic)
+ self.polymorphic_load = polymorphic_load
# our 'polymorphic identity', a string name that when located in a
# result set row indicates this Mapper should be used to construct
@@ -1037,6 +1054,19 @@ class Mapper(InspectionAttr):
)
self.polymorphic_map[self.polymorphic_identity] = self
+ if self.polymorphic_load and self.concrete:
+ raise exc.ArgumentError(
+ "polymorphic_load is not currently supported "
+ "with concrete table inheritance")
+ if self.polymorphic_load == 'inline':
+ self.inherits._add_with_polymorphic_subclass(self)
+ elif self.polymorphic_load == 'selectin':
+ pass
+ elif self.polymorphic_load is not None:
+ raise sa_exc.ArgumentError(
+ "unknown argument for polymorphic_load: %r" %
+ self.polymorphic_load)
+
else:
self._all_tables = set()
self.base_mapper = self
@@ -1077,9 +1107,22 @@ class Mapper(InspectionAttr):
expression.SelectBase):
self.with_polymorphic = (self.with_polymorphic[0],
self.with_polymorphic[1].alias())
+
if self.configured:
self._expire_memoizations()
+ def _add_with_polymorphic_subclass(self, mapper):
+ subcl = mapper.class_
+ if self.with_polymorphic is None:
+ self._set_with_polymorphic((subcl,))
+ elif self.with_polymorphic[0] != '*':
+ self._set_with_polymorphic(
+ (
+ self.with_polymorphic[0] + (subcl, ),
+ self.with_polymorphic[1]
+ )
+ )
+
def _set_concrete_base(self, mapper):
"""Set the given :class:`.Mapper` as the 'inherits' for this
:class:`.Mapper`, assuming this :class:`.Mapper` is concrete
@@ -2663,6 +2706,60 @@ class Mapper(InspectionAttr):
cols.extend(props[key].columns)
return sql.select(cols, cond, use_labels=True)
+ @_memoized_configured_property
+ @util.dependencies(
+ "sqlalchemy.ext.baked",
+ "sqlalchemy.orm.strategy_options")
+ def _subclass_load_via_in(self, baked, strategy_options):
+ """Assemble a BakedQuery that can load the columns local to
+ this subclass as a SELECT with IN.
+
+ """
+ assert self.inherits
+
+ polymorphic_prop = self._columntoproperty[
+ self.polymorphic_on]
+ keep_props = set(
+ [polymorphic_prop] + self._identity_key_props)
+
+ disable_opt = strategy_options.Load(self)
+ enable_opt = strategy_options.Load(self)
+
+ for prop in self.attrs:
+ if prop.parent is self or prop in keep_props:
+ # "enable" options, to turn on the properties that we want to
+ # load by default (subject to options from the query)
+ enable_opt.set_generic_strategy(
+ (prop.key, ),
+ dict(prop.strategy_key)
+ )
+ else:
+ # "disable" options, to turn off the properties from the
+ # superclass that we *don't* want to load, applied after
+ # the options from the query to override them
+ disable_opt.set_generic_strategy(
+ (prop.key, ),
+ {"do_nothing": True}
+ )
+
+ if len(self.primary_key) > 1:
+ in_expr = sql.tuple_(*self.primary_key)
+ else:
+ in_expr = self.primary_key[0]
+
+ q = baked.BakedQuery(
+ self._compiled_cache,
+ lambda session: session.query(self),
+ (self, )
+ )
+ q += lambda q: q.filter(
+ in_expr.in_(
+ sql.bindparam('primary_keys', expanding=True)
+ )
+ ).order_by(*self.primary_key)
+
+ return q, enable_opt, disable_opt
+
def cascade_iterator(self, type_, state, halt_on=None):
"""Iterate each element and its mapper in an object graph,
for all relationships that meet the given cascade rule.
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index dc69ae99d..e48462d35 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -196,6 +196,7 @@ class ColumnLoader(LoaderStrategy):
@log.class_logger
@properties.ColumnProperty.strategy_for(deferred=True, instrument=True)
+@properties.ColumnProperty.strategy_for(do_nothing=True)
class DeferredColumnLoader(LoaderStrategy):
"""Provide loading behavior for a deferred :class:`.ColumnProperty`."""
@@ -336,6 +337,18 @@ class AbstractRelationshipLoader(LoaderStrategy):
@log.class_logger
+@properties.RelationshipProperty.strategy_for(do_nothing=True)
+class DoNothingLoader(LoaderStrategy):
+ """Relationship loader that makes no change to the object's state.
+
+ Compared to NoLoader, this loader does not initialize the
+ collection/attribute to empty/none; the usual default LazyLoader will
+ take effect.
+
+ """
+
+
+@log.class_logger
@properties.RelationshipProperty.strategy_for(lazy="noload")
@properties.RelationshipProperty.strategy_for(lazy=None)
class NoLoader(AbstractRelationshipLoader):
@@ -711,6 +724,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
self, context, path, loadopt,
mapper, result, adapter, populators):
key = self.key
+
if not self.is_class_level:
# we are not the primary manager for this attribute
# on this class - set up a
@@ -1804,6 +1818,9 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
selectin_path = (
context.query._current_path or orm_util.PathRegistry.root) + path
+ if not orm_util._entity_isa(path[-1], self.parent):
+ return
+
if loading.PostLoad.path_exists(context, selectin_path, self.key):
return
@@ -1914,6 +1931,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
}
for key, state, overwrite in chunk:
+
if not overwrite and self.key in state.dict:
continue
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index df13f05db..d3f456969 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -13,7 +13,7 @@ from .attributes import QueryableAttribute
from .. import util
from ..sql.base import _generative, Generative
from .. import exc as sa_exc, inspect
-from .base import _is_aliased_class, _class_to_mapper
+from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class
from . import util as orm_util
from .path_registry import PathRegistry, TokenRegistry, \
_WILDCARD_TOKEN, _DEFAULT_TOKEN
@@ -63,6 +63,7 @@ class Load(Generative, MapperOption):
self.context = util.OrderedDict()
self.local_opts = {}
self._of_type = None
+ self.is_class_strategy = False
@classmethod
def for_existing_path(cls, path):
@@ -127,6 +128,7 @@ class Load(Generative, MapperOption):
return cloned
is_opts_only = False
+ is_class_strategy = False
strategy = None
propagate_to_loaders = False
@@ -148,6 +150,7 @@ class Load(Generative, MapperOption):
def _generate_path(self, path, attr, wildcard_key, raiseerr=True):
self._of_type = None
+
if raiseerr and not path.has_entity:
if isinstance(path, TokenRegistry):
raise sa_exc.ArgumentError(
@@ -187,6 +190,14 @@ class Load(Generative, MapperOption):
attr = attr.property
path = path[attr]
+ elif _is_mapped_class(attr):
+ if not attr.common_parent(path.mapper):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Attribute '%s' does not "
+ "link from element '%s'" % (attr, path.entity))
+ else:
+ return None
else:
prop = attr.property
@@ -246,6 +257,7 @@ class Load(Generative, MapperOption):
self, attr, strategy, propagate_to_loaders=True):
strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
self.propagate_to_loaders = propagate_to_loaders
# if the path is a wildcard, this will set propagate_to_loaders=False
self._generate_path(self.path, attr, "relationship")
@@ -257,6 +269,7 @@ class Load(Generative, MapperOption):
def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False):
strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
for attr in attrs:
cloned = self._generate()
cloned.strategy = strategy
@@ -267,6 +280,31 @@ class Load(Generative, MapperOption):
if opts_only:
cloned.is_opts_only = True
cloned._set_path_strategy()
+ self.is_class_strategy = False
+
+ @_generative
+ def set_generic_strategy(self, attrs, strategy):
+ strategy = self._coerce_strat(strategy)
+
+ for attr in attrs:
+ path = self._generate_path(self.path, attr, None)
+ cloned = self._generate()
+ cloned.strategy = strategy
+ cloned.path = path
+ cloned.propagate_to_loaders = True
+ cloned._set_path_strategy()
+
+ @_generative
+ def set_class_strategy(self, strategy, opts):
+ strategy = self._coerce_strat(strategy)
+ cloned = self._generate()
+ cloned.is_class_strategy = True
+ path = cloned._generate_path(self.path, None, None)
+ cloned.strategy = strategy
+ cloned.path = path
+ cloned.propagate_to_loaders = True
+ cloned._set_path_strategy()
+ cloned.local_opts.update(opts)
def _set_for_path(self, context, path, replace=True, merge_opts=False):
if merge_opts or not replace:
@@ -284,7 +322,7 @@ class Load(Generative, MapperOption):
self.local_opts.update(existing.local_opts)
def _set_path_strategy(self):
- if self.path.has_entity:
+ if not self.is_class_strategy and self.path.has_entity:
effective_path = self.path.parent
else:
effective_path = self.path
@@ -367,7 +405,10 @@ class _UnboundLoad(Load):
if attr == _DEFAULT_TOKEN:
self.propagate_to_loaders = False
attr = "%s:%s" % (wildcard_key, attr)
- path = path + (attr, )
+ if path and _is_mapped_class(path[-1]) and not self.is_class_strategy:
+ path = path[0:-1]
+ if attr:
+ path = path + (attr, )
self.path = path
return path
@@ -502,7 +543,12 @@ class _UnboundLoad(Load):
(User, User.orders.property, Order, Order.items.property))
"""
+
start_path = self.path
+
+ if self.is_class_strategy and current_path:
+ start_path += (entities[0], )
+
# _current_path implies we're in a
# secondary load with an existing path
@@ -517,7 +563,8 @@ class _UnboundLoad(Load):
token = start_path[0]
if isinstance(token, util.string_types):
- entity = self._find_entity_basestring(entities, token, raiseerr)
+ entity = self._find_entity_basestring(
+ entities, token, raiseerr)
elif isinstance(token, PropComparator):
prop = token.property
entity = self._find_entity_prop_comparator(
@@ -525,7 +572,10 @@ class _UnboundLoad(Load):
prop.key,
token._parententity,
raiseerr)
-
+ elif self.is_class_strategy and _is_mapped_class(token):
+ entity = inspect(token)
+ if entity not in entities:
+ entity = None
else:
raise sa_exc.ArgumentError(
"mapper option expects "
@@ -541,7 +591,6 @@ class _UnboundLoad(Load):
# we just located, then go through the rest of our path
# tokens and populate into the Load().
loader = Load(path_element)
-
if context is not None:
loader.context = context
else:
@@ -549,16 +598,19 @@ class _UnboundLoad(Load):
loader.strategy = self.strategy
loader.is_opts_only = self.is_opts_only
+ loader.is_class_strategy = self.is_class_strategy
path = loader.path
- for token in start_path:
- if not loader._generate_path(
- loader.path, token, None, raiseerr):
- return
+
+ if not loader.is_class_strategy:
+ for token in start_path:
+ if not loader._generate_path(
+ loader.path, token, None, raiseerr):
+ return
loader.local_opts.update(self.local_opts)
- if loader.path.has_entity:
+ if not loader.is_class_strategy and loader.path.has_entity:
effective_path = loader.path.parent
else:
effective_path = loader.path
@@ -1289,3 +1341,37 @@ def undefer_group(loadopt, name):
@undefer_group._add_unbound_fn
def undefer_group(name):
return _UnboundLoad().undefer_group(name)
+
+
+@loader_option()
+def selectin_polymorphic(loadopt, classes):
+ """Indicate an eager load should take place for all attributes
+ specific to a subclass.
+
+ This uses an additional SELECT with IN against all matched primary
+ key values, and is the per-query analogue to the ``"selectin"``
+ setting on the :paramref:`.mapper.polymorphic_load` parameter.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`inheritance_polymorphic_load`
+
+ """
+ loadopt.set_class_strategy(
+ {"selectinload_polymorphic": True},
+ opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))}
+ )
+ return loadopt
+
+
+@selectin_polymorphic._add_unbound_fn
+def selectin_polymorphic(base_cls, classes):
+ ul = _UnboundLoad()
+ ul.is_class_strategy = True
+ ul.path = (inspect(base_cls), )
+ ul.selectin_polymorphic(
+ classes
+ )
+ return ul
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 9a397ccf3..4267b79fb 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1043,7 +1043,13 @@ def was_deleted(object):
state = attributes.instance_state(object)
return state.was_deleted
+
def _entity_corresponds_to(given, entity):
+ """determine if 'given' corresponds to 'entity', in terms
+ of an entity passed to Query that would match the same entity
+ being referred to elsewhere in the query.
+
+ """
if entity.is_aliased_class:
if given.is_aliased_class:
if entity._base_alias is given._base_alias:
@@ -1057,6 +1063,21 @@ def _entity_corresponds_to(given, entity):
return entity.common_parent(given)
+
+def _entity_isa(given, mapper):
+ """determine if 'given' "is a" mapper, in terms of the given
+ would load rows of type 'mapper'.
+
+ """
+ if given.is_aliased_class:
+ return mapper in given.with_polymorphic_mappers or \
+ given.mapper.isa(mapper)
+ elif given.with_polymorphic_mappers:
+ return mapper in given.with_polymorphic_mappers
+ else:
+ return given.isa(mapper)
+
+
def randomize_unitofwork():
"""Use random-ordering sets within the unit of work in order
to detect unit of work sorting issues.
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index dfea33dc7..c0854ea55 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -497,8 +497,9 @@ class AssertsExecutionResults(object):
def assert_sql_execution(self, db, callable_, *rules):
with self.sql_execution_asserter(db) as asserter:
- callable_()
+ result = callable_()
asserter.assert_(*rules)
+ return result
def assert_sql(self, db, callable_, rules):
@@ -512,7 +513,7 @@ class AssertsExecutionResults(object):
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
- self.assert_sql_execution(db, callable_, *newrules)
+ return self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index e39b6315d..86d850733 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -282,6 +282,32 @@ class AllOf(AssertRule):
self.errormessage = list(self.rules)[0].errormessage
+class EachOf(AssertRule):
+
+ def __init__(self, *rules):
+ self.rules = list(rules)
+
+ def process_statement(self, execute_observed):
+ while self.rules:
+ rule = self.rules[0]
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.pop(0)
+ elif rule.errormessage:
+ self.errormessage = rule.errormessage
+ if rule.consume_statement:
+ break
+
+ if not self.rules:
+ self.is_consumed = True
+
+ def no_more_statements(self):
+ if self.rules and not self.rules[0].is_consumed:
+ self.rules[0].no_more_statements()
+ elif self.rules:
+ super(EachOf, self).no_more_statements()
+
+
class Or(AllOf):
def process_statement(self, execute_observed):
@@ -319,24 +345,20 @@ class SQLAsserter(object):
del self.accumulated
def assert_(self, *rules):
- rules = list(rules)
- observed = list(self._final)
+ rule = EachOf(*rules)
- while observed and rules:
- rule = rules[0]
- rule.process_statement(observed[0])
+ observed = list(self._final)
+ while observed:
+ statement = observed.pop(0)
+ rule.process_statement(statement)
if rule.is_consumed:
- rules.pop(0)
+ break
elif rule.errormessage:
assert False, rule.errormessage
-
- if rule.consume_statement:
- observed.pop(0)
-
- if not observed and rules:
- rules[0].no_more_statements()
- elif not rules and observed:
+ if observed:
assert False, "Additional SQL statements remain"
+ elif not rule.is_consumed:
+ rule.no_more_statements()
@contextlib.contextmanager