diff options
author | Ants Aasma <ants.aasma@gmail.com> | 2008-05-29 02:11:49 +0000 |
---|---|---|
committer | Ants Aasma <ants.aasma@gmail.com> | 2008-05-29 02:11:49 +0000 |
commit | 77c308367ffec3e8af9b5463b1c3bdd89640e8ac (patch) | |
tree | 26b420d993269f33590b421bd4cd712b1f427fff /lib/sqlalchemy/orm/evaluator.py | |
parent | 5ccfa64294fbf730ad7449b60e8b32a38565aea5 (diff) | |
download | sqlalchemy-77c308367ffec3e8af9b5463b1c3bdd89640e8ac.tar.gz |
Preliminary implementation for the evaluation framework
Diffstat (limited to 'lib/sqlalchemy/orm/evaluator.py')
-rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py new file mode 100644 index 000000000..c4517e494 --- /dev/null +++ b/lib/sqlalchemy/orm/evaluator.py @@ -0,0 +1,96 @@ +from sqlalchemy.sql import operators, functions +from sqlalchemy.sql import expression as sql +from sqlalchemy.util import Set +import operator + +class UnevaluatableError(Exception): + pass + +_straight_ops = Set([getattr(operators, op) for op in [ + 'add', 'mul', 'sub', 'div', 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq' +]]) + + +_notimplemented_ops = Set([getattr(operators, op) for op in [ + 'like_op', 'notlike_op', 'ilike_op', 'notilike_op', 'between_op', 'in_op', 'notin_op', + 'endswith_op', 'concat_op', +]]) + +class EvaluatorCompiler(object): + def process(self, clause): + meth = getattr(self, "visit_%s" % clause.__visit_name__, None) + if not meth: + raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_column(self, clause): + if 'parententity' in clause._annotations: + key = clause._annotations['parententity']._get_col_to_prop(clause).key + else: + key = clause.key + get_corresponding_attr = operator.attrgetter(key) + return lambda obj: get_corresponding_attr(obj) + + def visit_clauselist(self, clause): + evaluators = map(self.process, clause.clauses) + if clause.operator is operators.or_: + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + if clause.operator is operators.and_: + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if not value: + if value is None: + return None + return False + return True + + return evaluate + + def visit_binary(self, clause): + eval_left,eval_right = map(self.process, [clause.left, clause.right]) + operator = clause.operator + if operator is operators.is_: + def evaluate(obj): + return eval_left(obj) == eval_right(obj) + if operator is operators.isnot: + def evaluate(obj): + return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + return operator(eval_left(obj), eval_right(obj)) + return evaluate + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + def evaluate(obj): + value = eval_inner(obj) + if value is None: + return None + return not value + return evaluate + raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + + def visit_bindparam(self, clause): + val = clause.value + return lambda obj: val |