summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/evaluator.py14
-rw-r--r--lib/sqlalchemy/orm/persistence.py5
2 files changed, 15 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index e1dd96068..ca26c9ca4 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -25,6 +25,9 @@ _notimplemented_ops = set(getattr(operators, op)
class EvaluatorCompiler(object):
+ def __init__(self, target_cls=None):
+ self.target_cls = target_cls
+
def process(self, clause):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth:
@@ -46,10 +49,17 @@ class EvaluatorCompiler(object):
def visit_column(self, clause):
if 'parentmapper' in clause._annotations:
- key = clause._annotations['parentmapper'].\
- _columntoproperty[clause].key
+ parentmapper = clause._annotations['parentmapper']
+ if self.target_cls and not issubclass(
+ self.target_cls, parentmapper.class_):
+ raise UnevaluatableError(
+ "Can't evaluate criteria against alternate class %s" %
+ parentmapper.class_
+ )
+ key = parentmapper._columntoproperty[clause].key
else:
key = clause.key
+
get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 996cc8802..56778cb05 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -922,8 +922,10 @@ class BulkEvaluate(BulkUD):
def _do_pre_synchronize(self):
query = self.query
+ target_cls = query._mapper_zero().class_
+
try:
- evaluator_compiler = evaluator.EvaluatorCompiler()
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
if query.whereclause is not None:
eval_condition = evaluator_compiler.process(
query.whereclause)
@@ -938,7 +940,6 @@ class BulkEvaluate(BulkUD):
"Could not evaluate current criteria in Python. "
"Specify 'fetch' or False for the "
"synchronize_session parameter.")
- target_cls = query._mapper_zero().class_
#TODO: detect when the where clause is a trivial primary key match
self.matched_objects = [