diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 52 |
1 files changed, 33 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d94d3fef6..ea7c1b734 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -8,15 +8,10 @@ import string, re from sqlalchemy import schema, engine, util, exceptions -from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import operators, visitors, functions from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression as sql -ANSI_FUNCS = util.Set([ - 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', - 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', - 'SESSION_USER', 'USER']) - RESERVED_WORDS = util.Set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', 'as', 'asc', 'asymmetric', 'authorization', 'between', @@ -87,6 +82,18 @@ OPERATORS = { operators.isnot : 'IS NOT' } +FUNCTIONS = { + functions.coalesce : 'coalesce%(expr)s', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', + functions.current_timestamp: 'CURRENT_TIMESTAMP', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', + functions.localtimestamp: 'LOCALTIMESTAMP', + functions.session_user :'SESSION_USER', + functions.user: 'USER' +} + class DefaultCompiler(engine.Compiled): """Default implementation of Compiled. @@ -97,6 +104,7 @@ class DefaultCompiler(engine.Compiled): __traverse_options__ = {'column_collections':False, 'entry':True} operators = OPERATORS + functions = FUNCTIONS def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -215,7 +223,7 @@ class DefaultCompiler(engine.Compiled): labelname = self._truncated_identifier("colident", label.name) if result_map is not None: - result_map[labelname] = (label.name, (label, label.obj), label.obj.type) + result_map[labelname.lower()] = (label.name, (label, label.obj), label.obj.type) return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) @@ -231,7 +239,7 @@ class DefaultCompiler(engine.Compiled): name = column.name if result_map is not None: - result_map[name] = (name, (column, ), column.type) + result_map[name.lower()] = (name, (column, ), column.type) if column._is_oid: n = self.dialect.oid_column_name(column) @@ -270,7 +278,7 @@ class DefaultCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname] = (colname, None, type_) + self.result_map[colname.lower()] = (colname, None, type_) def do_bindparam(m): name = m.group(1) @@ -297,9 +305,6 @@ class DefaultCompiler(engine.Compiled): sep = " " + self.operator_string(clauselist.operator) + " " return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None]) - def apply_function_parens(self, func): - return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, clause, **kwargs): return self.process(clause.clause_expr) @@ -308,12 +313,20 @@ class DefaultCompiler(engine.Compiled): def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: - result_map[func.name] = (func.name, None, func.type) + result_map[func.name.lower()] = (func.name, None, func.type) + + name = self.function_string(func) - if not self.apply_function_parens(func): - return ".".join(func.packagenames + [func.name]) + if callable(name): + return name(*[self.process(x) for x in func.clause_expr]) else: - return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) + return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + + def function_argspec(self, func): + return self.process(func.clause_expr) + + def function_string(self, func): + return self.functions.get(func.__class__, func.name + "%(expr)s") def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): stack_entry = {'select':cs} @@ -358,10 +371,11 @@ class DefaultCompiler(engine.Compiled): def operator_string(self, operator): return self.operators.get(operator, str(operator)) - + def visit_bindparam(self, bindparam, **kwargs): - # apply truncation to the ultimate generated name - + # TODO: remove this whole "unique" thing, just use regular + # anonymous params to implement. params used for inserts/updates + # etc. should no longer be "unique". if bindparam.unique: count = 1 key = bindparam.key |