summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py106
1 files changed, 77 insertions, 29 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index be3375def..a3008d085 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -183,6 +183,11 @@ class SQLCompiler(engine.Compiled):
# clauses before the VALUES or WHERE clause (i.e. MSSQL)
returning_precedes_values = False
+ # SQL 92 doesn't allow bind parameters to be used
+ # in the columns clause of a SELECT. A compiler
+ # subclass can set this flag to False if the target
+ # driver/DB enforces this
+ binds_in_columns_clause = True
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -260,9 +265,14 @@ class SQLCompiler(engine.Compiled):
else:
if bindparam.required:
if _group_number:
- raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number))
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d" %
+ (bindparam.key, _group_number))
else:
- raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key)
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key)
elif util.callable(bindparam.value):
pd[name] = bindparam.value()
else:
@@ -290,8 +300,8 @@ class SQLCompiler(engine.Compiled):
"""
return ""
- def visit_grouping(self, grouping, **kwargs):
- return "(" + self.process(grouping.element) + ")"
+ def visit_grouping(self, grouping, asfrom=False, **kwargs):
+ return "(" + self.process(grouping.element, **kwargs) + ")"
def visit_label(self, label, result_map=None, within_columns_clause=False):
# only render labels within the columns clause
@@ -384,27 +394,28 @@ class SQLCompiler(engine.Compiled):
sep = " "
else:
sep = OPERATORS[clauselist.operator]
- return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
+ return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses)
if s is not None)
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
- x += self.process(clause.value) + " "
+ x += self.process(clause.value, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
+ x += "WHEN " + self.process(cond, **kwargs) + \
+ " THEN " + self.process(result, **kwargs) + " "
if clause.else_ is not None:
- x += "ELSE " + self.process(clause.else_) + " "
+ x += "ELSE " + self.process(clause.else_, **kwargs) + " "
x += "END"
return x
def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % \
- (self.process(cast.clause), self.process(cast.typeclause))
+ (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs))
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
- return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))
+ return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
if result_map is not None:
@@ -421,22 +432,23 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
- def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+ def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+ (self.process(c, asfrom=asfrom, parens=False,
+ compound_index=i, **kwargs)
for i, c in enumerate(cs.selects))
)
- group_by = self.process(cs._group_by_clause, asfrom=asfrom)
+ group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs)
if group_by:
text += " GROUP BY " + group_by
- text += self.order_by_clause(cs)
+ text += self.order_by_clause(cs, **kwargs)
text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
self.stack.pop(-1)
@@ -457,28 +469,38 @@ class SQLCompiler(engine.Compiled):
return self._operator_dispatch(binary.operator,
binary,
- lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+ lambda opstr: self.process(binary.left, **kwargs) +
+ opstr +
+ self.process(binary.right, **kwargs),
**kwargs
)
def visit_like_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ return '%s LIKE %s' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_notlike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ return '%s NOT LIKE %s' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ return 'lower(%s) LIKE lower(%s)' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ return 'lower(%s) NOT LIKE lower(%s)' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def _operator_dispatch(self, operator, element, fn, **kw):
@@ -491,7 +513,14 @@ class SQLCompiler(engine.Compiled):
else:
return fn(" " + operator + " ")
- def visit_bindparam(self, bindparam, **kwargs):
+ def visit_bindparam(self, bindparam, within_columns_clause=False,
+ literal_binds=False, **kwargs):
+ if literal_binds or \
+ (within_columns_clause and \
+ not self.binds_in_columns_clause) and \
+ bindparam.value is not None:
+ return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
+
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
@@ -510,7 +539,26 @@ class SQLCompiler(engine.Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
-
+
+ def render_literal_bindparam(self, bindparam, **kw):
+ value = bindparam.value
+ processor = bindparam.bind_processor(self.dialect)
+ if processor:
+ value = processor(value)
+ return self.render_literal_value(value)
+
+ def render_literal_value(self, value):
+ """Render the value of a bind parameter as a quoted literal.
+
+ This is used for statement sections that do not accept bind paramters
+ on the target driver/database.
+
+ This should be implemented by subclasses using the quoting services
+ of the DBAPI.
+
+ """
+ raise NotImplementedError()
+
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
@@ -624,33 +672,33 @@ class SQLCompiler(engine.Compiled):
text = "SELECT " # we're off to a good start !
if select._prefixes:
- text += " ".join(self.process(x) for x in select._prefixes) + " "
+ text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
if froms:
text += " \nFROM "
- text += ', '.join(self.process(f, asfrom=True) for f in froms)
+ text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms)
else:
text += self.default_from()
if select._whereclause is not None:
- t = self.process(select._whereclause)
+ t = self.process(select._whereclause, **kwargs)
if t:
text += " \nWHERE " + t
if select._group_by_clause.clauses:
- group_by = self.process(select._group_by_clause)
+ group_by = self.process(select._group_by_clause, **kwargs)
if group_by:
text += " GROUP BY " + group_by
if select._having is not None:
- t = self.process(select._having)
+ t = self.process(select._having, **kwargs)
if t:
text += " \nHAVING " + t
if select._order_by_clause.clauses:
- text += self.order_by_clause(select)
+ text += self.order_by_clause(select, **kwargs)
if select._limit is not None or select._offset is not None:
text += self.limit_clause(select)
if select.for_update:
@@ -670,8 +718,8 @@ class SQLCompiler(engine.Compiled):
"""
return select._distinct and "DISTINCT " or ""
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
+ def order_by_clause(self, select, **kw):
+ order_by = self.process(select._order_by_clause, **kw)
if order_by:
return " ORDER BY " + order_by
else: