diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 5 |
2 files changed, 15 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index e1dd96068..ca26c9ca4 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -25,6 +25,9 @@ _notimplemented_ops = set(getattr(operators, op) class EvaluatorCompiler(object): + def __init__(self, target_cls=None): + self.target_cls = target_cls + def process(self, clause): meth = getattr(self, "visit_%s" % clause.__visit_name__, None) if not meth: @@ -46,10 +49,17 @@ class EvaluatorCompiler(object): def visit_column(self, clause): if 'parentmapper' in clause._annotations: - key = clause._annotations['parentmapper'].\ - _columntoproperty[clause].key + parentmapper = clause._annotations['parentmapper'] + if self.target_cls and not issubclass( + self.target_cls, parentmapper.class_): + raise UnevaluatableError( + "Can't evaluate criteria against alternate class %s" % + parentmapper.class_ + ) + key = parentmapper._columntoproperty[clause].key else: key = clause.key + get_corresponding_attr = operator.attrgetter(key) return lambda obj: get_corresponding_attr(obj) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 996cc8802..56778cb05 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -922,8 +922,10 @@ class BulkEvaluate(BulkUD): def _do_pre_synchronize(self): query = self.query + target_cls = query._mapper_zero().class_ + try: - evaluator_compiler = evaluator.EvaluatorCompiler() + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) if query.whereclause is not None: eval_condition = evaluator_compiler.process( query.whereclause) @@ -938,7 +940,6 @@ class BulkEvaluate(BulkUD): "Could not evaluate current criteria in Python. " "Specify 'fetch' or False for the " "synchronize_session parameter.") - target_cls = query._mapper_zero().class_ #TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ |