summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py52
1 files changed, 33 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d94d3fef6..ea7c1b734 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -8,15 +8,10 @@
import string, re
from sqlalchemy import schema, engine, util, exceptions
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators, visitors, functions
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression as sql
-ANSI_FUNCS = util.Set([
- 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
- 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
- 'SESSION_USER', 'USER'])
-
RESERVED_WORDS = util.Set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
'as', 'asc', 'asymmetric', 'authorization', 'between',
@@ -87,6 +82,18 @@ OPERATORS = {
operators.isnot : 'IS NOT'
}
+FUNCTIONS = {
+ functions.coalesce : 'coalesce%(expr)s',
+ functions.current_date: 'CURRENT_DATE',
+ functions.current_time: 'CURRENT_TIME',
+ functions.current_timestamp: 'CURRENT_TIMESTAMP',
+ functions.current_user: 'CURRENT_USER',
+ functions.localtime: 'LOCALTIME',
+ functions.localtimestamp: 'LOCALTIMESTAMP',
+ functions.session_user :'SESSION_USER',
+ functions.user: 'USER'
+}
+
class DefaultCompiler(engine.Compiled):
"""Default implementation of Compiled.
@@ -97,6 +104,7 @@ class DefaultCompiler(engine.Compiled):
__traverse_options__ = {'column_collections':False, 'entry':True}
operators = OPERATORS
+ functions = FUNCTIONS
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -215,7 +223,7 @@ class DefaultCompiler(engine.Compiled):
labelname = self._truncated_identifier("colident", label.name)
if result_map is not None:
- result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
+ result_map[labelname.lower()] = (label.name, (label, label.obj), label.obj.type)
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
@@ -231,7 +239,7 @@ class DefaultCompiler(engine.Compiled):
name = column.name
if result_map is not None:
- result_map[name] = (name, (column, ), column.type)
+ result_map[name.lower()] = (name, (column, ), column.type)
if column._is_oid:
n = self.dialect.oid_column_name(column)
@@ -270,7 +278,7 @@ class DefaultCompiler(engine.Compiled):
def visit_textclause(self, textclause, **kwargs):
if textclause.typemap is not None:
for colname, type_ in textclause.typemap.iteritems():
- self.result_map[colname] = (colname, None, type_)
+ self.result_map[colname.lower()] = (colname, None, type_)
def do_bindparam(m):
name = m.group(1)
@@ -297,9 +305,6 @@ class DefaultCompiler(engine.Compiled):
sep = " " + self.operator_string(clauselist.operator) + " "
return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
- def apply_function_parens(self, func):
- return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
-
def visit_calculatedclause(self, clause, **kwargs):
return self.process(clause.clause_expr)
@@ -308,12 +313,20 @@ class DefaultCompiler(engine.Compiled):
def visit_function(self, func, result_map=None, **kwargs):
if result_map is not None:
- result_map[func.name] = (func.name, None, func.type)
+ result_map[func.name.lower()] = (func.name, None, func.type)
+
+ name = self.function_string(func)
- if not self.apply_function_parens(func):
- return ".".join(func.packagenames + [func.name])
+ if callable(name):
+ return name(*[self.process(x) for x in func.clause_expr])
else:
- return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
+ return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+
+ def function_argspec(self, func):
+ return self.process(func.clause_expr)
+
+ def function_string(self, func):
+ return self.functions.get(func.__class__, func.name + "%(expr)s")
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
stack_entry = {'select':cs}
@@ -358,10 +371,11 @@ class DefaultCompiler(engine.Compiled):
def operator_string(self, operator):
return self.operators.get(operator, str(operator))
-
+
def visit_bindparam(self, bindparam, **kwargs):
- # apply truncation to the ultimate generated name
-
+ # TODO: remove this whole "unique" thing, just use regular
+ # anonymous params to implement. params used for inserts/updates
+ # etc. should no longer be "unique".
if bindparam.unique:
count = 1
key = bindparam.key