summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-03-17 15:15:44 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2010-03-17 15:15:44 -0400
commit318f47dc80c58dee8c798afcc8c19a5dbb21eef7 (patch)
treeefd951b139017acb14302b33b2cad9ecab88fb3b /lib/sqlalchemy/sql/compiler.py
parentb81e9741ba26f2740725c9d403d116284af7d7a4 (diff)
parent55367ac4a26dbc3c0e57783c3964cd5c42647a35 (diff)
downloadsqlalchemy-318f47dc80c58dee8c798afcc8c19a5dbb21eef7.tar.gz
- added pyodbc for sybase driver.
- generalized the "freetds" / "unicode statements" behavior of MS-SQL/pyodbc into the base Pyodbc connector, as this seems to apply to Sybase as well. - generalized the python-sybase "use autocommit for DDL" into the pyodbc connector. With pyodbc, the "autocommit" flag on connection is used, as Pyodbc seems to have more database conversation than python-sybase that can't otherwise be suppressed. - Some platforms will now interpret certain literal values as non-bind parameters, rendered literally into the SQL statement. This to support strict SQL-92 rules that are enforced by some platforms including MS-SQL and Sybase. In this model, bind parameters aren't allowed in the columns clause of a SELECT, nor are certain ambiguous expressions like "?=?". When this mode is enabled, the base compiler will render the binds as inline literals, but only across strings and numeric values. Other types such as dates will raise an error, unless the dialect subclass defines a literal rendering function for those. The bind parameter must have an embedded literal value already or an error is raised (i.e. won't work with straight bindparam('x')). Dialects can also expand upon the areas where binds are not accepted, such as within argument lists of functions (which don't work on MS-SQL when native SQL binding is used).
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py141
1 files changed, 106 insertions, 35 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index be3375def..4e9175ae8 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -26,6 +26,7 @@ import re
from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import expression as sql
+import decimal
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -183,6 +184,12 @@ class SQLCompiler(engine.Compiled):
# clauses before the VALUES or WHERE clause (i.e. MSSQL)
returning_precedes_values = False
+ # SQL 92 doesn't allow bind parameters to be used
+ # in the columns clause of a SELECT, nor does it allow
+ # ambiguous expressions like "? = ?". A compiler
+ # subclass can set this flag to False if the target
+ # driver/DB enforces this
+ ansi_bind_rules = False
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -260,9 +267,14 @@ class SQLCompiler(engine.Compiled):
else:
if bindparam.required:
if _group_number:
- raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number))
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d" %
+ (bindparam.key, _group_number))
else:
- raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key)
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key)
elif util.callable(bindparam.value):
pd[name] = bindparam.value()
else:
@@ -290,10 +302,10 @@ class SQLCompiler(engine.Compiled):
"""
return ""
- def visit_grouping(self, grouping, **kwargs):
- return "(" + self.process(grouping.element) + ")"
+ def visit_grouping(self, grouping, asfrom=False, **kwargs):
+ return "(" + self.process(grouping.element, **kwargs) + ")"
- def visit_label(self, label, result_map=None, within_columns_clause=False):
+ def visit_label(self, label, result_map=None, within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
@@ -305,11 +317,15 @@ class SQLCompiler(engine.Compiled):
result_map[labelname.lower()] = \
(label.name, (label, label.element, labelname), label.element.type)
- return self.process(label.element) + \
+ return self.process(label.element,
+ within_columns_clause=within_columns_clause,
+ **kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
- return self.process(label.element)
+ return self.process(label.element,
+ within_columns_clause=within_columns_clause,
+ **kw)
def visit_column(self, column, result_map=None, **kwargs):
name = column.name
@@ -384,27 +400,28 @@ class SQLCompiler(engine.Compiled):
sep = " "
else:
sep = OPERATORS[clauselist.operator]
- return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
+ return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses)
if s is not None)
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
- x += self.process(clause.value) + " "
+ x += self.process(clause.value, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
+ x += "WHEN " + self.process(cond, **kwargs) + \
+ " THEN " + self.process(result, **kwargs) + " "
if clause.else_ is not None:
- x += "ELSE " + self.process(clause.else_) + " "
+ x += "ELSE " + self.process(clause.else_, **kwargs) + " "
x += "END"
return x
def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % \
- (self.process(cast.clause), self.process(cast.typeclause))
+ (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs))
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
- return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))
+ return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
if result_map is not None:
@@ -421,22 +438,23 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
- def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+ def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+ (self.process(c, asfrom=asfrom, parens=False,
+ compound_index=i, **kwargs)
for i, c in enumerate(cs.selects))
)
- group_by = self.process(cs._group_by_clause, asfrom=asfrom)
+ group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs)
if group_by:
text += " GROUP BY " + group_by
- text += self.order_by_clause(cs)
+ text += self.order_by_clause(cs, **kwargs)
text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
self.stack.pop(-1)
@@ -453,32 +471,47 @@ class SQLCompiler(engine.Compiled):
s = s + OPERATORS[unary.modifier]
return s
- def visit_binary(self, binary, **kwargs):
-
+ def visit_binary(self, binary, **kw):
+ # don't allow "? = ?" to render
+ if self.ansi_bind_rules and \
+ isinstance(binary.left, sql._BindParamClause) and \
+ isinstance(binary.right, sql._BindParamClause):
+ kw['literal_binds'] = True
+
return self._operator_dispatch(binary.operator,
binary,
- lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
- **kwargs
+ lambda opstr: self.process(binary.left, **kw) +
+ opstr +
+ self.process(binary.right, **kw),
+ **kw
)
def visit_like_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ return '%s LIKE %s' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_notlike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ return '%s NOT LIKE %s' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ return 'lower(%s) LIKE lower(%s)' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ return 'lower(%s) NOT LIKE lower(%s)' % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
def _operator_dispatch(self, operator, element, fn, **kw):
@@ -491,7 +524,16 @@ class SQLCompiler(engine.Compiled):
else:
return fn(" " + operator + " ")
- def visit_bindparam(self, bindparam, **kwargs):
+ def visit_bindparam(self, bindparam, within_columns_clause=False,
+ literal_binds=False, **kwargs):
+ if literal_binds or \
+ (within_columns_clause and \
+ self.ansi_bind_rules):
+ if bindparam.value is None:
+ raise exc.CompileError("Bind parameter without a "
+ "renderable value not allowed here.")
+ return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
+
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
@@ -510,7 +552,36 @@ class SQLCompiler(engine.Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
-
+
+ def render_literal_bindparam(self, bindparam, **kw):
+ value = bindparam.value
+ processor = bindparam.bind_processor(self.dialect)
+ if processor:
+ value = processor(value)
+ return self.render_literal_value(value, bindparam.type)
+
+ def render_literal_value(self, value, type_):
+ """Render the value of a bind parameter as a quoted literal.
+
+ This is used for statement sections that do not accept bind paramters
+ on the target driver/database.
+
+ This should be implemented by subclasses using the quoting services
+ of the DBAPI.
+
+ """
+ if isinstance(value, basestring):
+ value = value.replace("'", "''")
+ return "'%s'" % value
+ elif value is None:
+ return "NULL"
+ elif isinstance(value, (float, int, long)):
+ return repr(value)
+ elif isinstance(value, decimal.Decimal):
+ return str(value)
+ else:
+ raise NotImplementedError("Don't know how to literal-quote value %r" % value)
+
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
@@ -624,33 +695,33 @@ class SQLCompiler(engine.Compiled):
text = "SELECT " # we're off to a good start !
if select._prefixes:
- text += " ".join(self.process(x) for x in select._prefixes) + " "
+ text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
if froms:
text += " \nFROM "
- text += ', '.join(self.process(f, asfrom=True) for f in froms)
+ text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms)
else:
text += self.default_from()
if select._whereclause is not None:
- t = self.process(select._whereclause)
+ t = self.process(select._whereclause, **kwargs)
if t:
text += " \nWHERE " + t
if select._group_by_clause.clauses:
- group_by = self.process(select._group_by_clause)
+ group_by = self.process(select._group_by_clause, **kwargs)
if group_by:
text += " GROUP BY " + group_by
if select._having is not None:
- t = self.process(select._having)
+ t = self.process(select._having, **kwargs)
if t:
text += " \nHAVING " + t
if select._order_by_clause.clauses:
- text += self.order_by_clause(select)
+ text += self.order_by_clause(select, **kwargs)
if select._limit is not None or select._offset is not None:
text += self.limit_clause(select)
if select.for_update:
@@ -670,8 +741,8 @@ class SQLCompiler(engine.Compiled):
"""
return select._distinct and "DISTINCT " or ""
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
+ def order_by_clause(self, select, **kw):
+ order_by = self.process(select._order_by_clause, **kw)
if order_by:
return " ORDER BY " + order_by
else: