diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 158 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 52 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 |
4 files changed, 188 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 8a9f0b979..a7a856bba 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -294,11 +294,11 @@ class BinaryElementImpl( def _post_coercion(self, resolved, expr, **kw): if ( - isinstance(resolved, elements.BindParameter) + isinstance(resolved, (elements.Grouping, elements.BindParameter)) and resolved.type._isnull + and not expr.type._isnull ): - resolved = resolved._clone() - resolved.type = expr.type + resolved = resolved._with_binary_element_type(expr.type) return resolved @@ -360,6 +360,7 @@ class InElementImpl(RoleImpl, roles.InElementRole): element = element._with_expanding_in_types( [elem.type for elem in expr] ) + return element else: return element diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fa7eeaecf..8df93a60b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -36,6 +36,7 @@ from . import roles from . import schema from . import selectable from . import sqltypes +from .base import NO_ARG from .. import exc from .. import util @@ -463,14 +464,6 @@ class SQLCompiler(Compiled): columns with the table name (i.e. MySQL only) """ - contains_expanding_parameters = False - """True if we've encountered bindparam(..., expanding=True). - - These need to be converted before execution time against the - string statement. - - """ - ansi_bind_rules = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow @@ -507,6 +500,8 @@ class SQLCompiler(Compiled): """ + literal_execute_params = frozenset() + insert_prefetch = update_prefetch = () def __init__( @@ -1267,6 +1262,81 @@ class SQLCompiler(Compiled): % self.dialect.name ) + def _literal_execute_expanding_parameter_literal_binds( + self, parameter, values + ): + if not values: + replacement_expression = self.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) + + elif isinstance(values[0], (tuple, list)): + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.render_literal_value(value, parameter.type) + for value in tuple_element + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) + + return (), replacement_expression + + def _literal_execute_expanding_parameter(self, name, parameter, values): + if parameter.literal_execute: + return self._literal_execute_expanding_parameter_literal_binds( + parameter, values + ) + + if not values: + to_update = [] + replacement_expression = self.visit_empty_set_expr( + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) + + elif isinstance(values[0], (tuple, list)): + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.bindtemplate + % {"name": to_update[i * len(tuple_element) + j][0]} + for j, value in enumerate(tuple_element) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expression = ", ".join( + self.bindtemplate % {"name": key} for key, value in to_update + ) + + return to_update, replacement_expression + def visit_binary( self, binary, override_operator=None, eager_grouping=False, **kw ): @@ -1457,6 +1527,7 @@ class SQLCompiler(Compiled): within_columns_clause=False, literal_binds=False, skip_bind_expression=False, + literal_execute=False, **kwargs ): @@ -1469,18 +1540,28 @@ class SQLCompiler(Compiled): skip_bind_expression=True, within_columns_clause=within_columns_clause, literal_binds=literal_binds, + literal_execute=literal_execute, **kwargs ) - if literal_binds or (within_columns_clause and self.ansi_bind_rules): - if bindparam.value is None and bindparam.callable is None: - raise exc.CompileError( - "Bind parameter '%s' without a " - "renderable value not allowed here." % bindparam.key - ) - return self.render_literal_bindparam( + if not literal_binds: + post_compile = ( + literal_execute + or bindparam.literal_execute + or bindparam.expanding + ) + else: + post_compile = False + + if not literal_execute and ( + literal_binds or (within_columns_clause and self.ansi_bind_rules) + ): + ret = self.render_literal_bindparam( bindparam, within_columns_clause=True, **kwargs ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret name = self._truncate_bindparam(bindparam) @@ -1508,13 +1589,38 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string( - name, expanding=bindparam.expanding, **kwargs + if post_compile: + self.literal_execute_params |= {bindparam} + + ret = self.bindparam_string( + name, + post_compile=post_compile, + expanding=bindparam.expanding, + **kwargs ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret + + def render_literal_bindparam( + self, bindparam, render_literal_value=NO_ARG, **kw + ): + if render_literal_value is not NO_ARG: + value = render_literal_value + else: + if bindparam.value is None and bindparam.callable is None: + raise exc.CompileError( + "Bind parameter '%s' without a " + "renderable value not allowed here." % bindparam.key + ) + value = bindparam.effective_value - def render_literal_bindparam(self, bindparam, **kw): - value = bindparam.effective_value - return self.render_literal_value(value, bindparam.type) + if bindparam.expanding: + leep = self._literal_execute_expanding_parameter_literal_binds + to_update, replacement_expr = leep(bindparam, value) + return replacement_expr + else: + return self.render_literal_value(value, bindparam.type) def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. @@ -1577,16 +1683,20 @@ class SQLCompiler(Compiled): return derived + "_" + str(anonymous_counter) def bindparam_string( - self, name, positional_names=None, expanding=False, **kw + self, + name, + positional_names=None, + post_compile=False, + expanding=False, + **kw ): if self.positional: if positional_names is not None: positional_names.append(name) else: self.positiontup.append(name) - if expanding: - self.contains_expanding_parameters = True - return "([EXPANDING_%s])" % name + if post_compile: + return "[POSTCOMPILE_%s]" % name else: return self.bindtemplate % {"name": name} diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 669519d1a..42e7522ae 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -211,6 +211,15 @@ class ClauseElement(roles.SQLRole, Visitable): return c + def _with_binary_element_type(self, type_): + """in the context of binary expression, convert the type of this + object to the one given. + + applies only to :class:`.ColumnElement` classes. + + """ + return self + def _cache_key(self, **kw): """return an optional cache key. @@ -732,6 +741,14 @@ class ColumnElement( def type(self): return type_api.NULLTYPE + def _with_binary_element_type(self, type_): + cloned = self._clone() + cloned._copy_internals( + clone=lambda element: element._with_binary_element_type(type_) + ) + cloned.type = type_ + return cloned + @util.memoized_property def comparator(self): try: @@ -986,6 +1003,7 @@ class BindParameter(roles.InElementRole, ColumnElement): callable_=None, expanding=False, isoutparam=False, + literal_execute=False, _compared_to_operator=None, _compared_to_type=None, ): @@ -1198,6 +1216,30 @@ class BindParameter(roles.InElementRole, ColumnElement): :func:`.outparam` + :param literal_execute: + if True, the bound parameter will be rendered in the compile phase + with a special "POSTCOMPILE" token, and the SQLAlchemy compiler will + render the final value of the parameter into the SQL statement at + statement execution time, omitting the value from the parameter + dictionary / list passed to DBAPI ``cursor.execute()``. This + produces a similar effect as that of using the ``literal_binds``, + compilation flag, however takes place as the statement is sent to + the DBAPI ``cursor.execute()`` method, rather than when the statement + is compiled. The primary use of this + capability is for rendering LIMIT / OFFSET clauses for database + drivers that can't accommodate for bound parameters in these + contexts, while allowing SQL constructs to be cacheable at the + compilation level. + + .. versionadded:: 1.4 Added "post compile" bound parameters + + .. seealso:: + + :ref:`change_4808`. + + + + """ if isinstance(key, ColumnClause): type_ = key.type @@ -1235,6 +1277,7 @@ class BindParameter(roles.InElementRole, ColumnElement): self.isoutparam = isoutparam self.required = required self.expanding = expanding + self.literal_execute = literal_execute if type_ is None: if _compared_to_type is not None: @@ -1643,14 +1686,17 @@ class TextClause( for bind in binds: try: - existing = new_params[bind.key] + # the regex used for text() currently will not match + # a unique/anonymous key in any case, so use the _orig_key + # so that a text() construct can support unique parameters + existing = new_params[bind._orig_key] except KeyError: raise exc.ArgumentError( "This text() construct doesn't define a " - "bound parameter named %r" % bind.key + "bound parameter named %r" % bind._orig_key ) else: - new_params[existing.key] = bind + new_params[existing._orig_key] = bind for key, value in names_to_values.items(): try: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index fd15d7c79..7829eb4d0 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -471,7 +471,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine): def literal_processor(self, dialect): def process(value): - return str(value) + return str(int(value)) return process |