summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/base.py47
-rw-r--r--lib/sqlalchemy/sql/coercions.py13
-rw-r--r--lib/sqlalchemy/sql/compiler.py87
-rw-r--r--lib/sqlalchemy/sql/crud.py181
-rw-r--r--lib/sqlalchemy/sql/dml.py537
-rw-r--r--lib/sqlalchemy/sql/elements.py14
-rw-r--r--lib/sqlalchemy/sql/schema.py25
-rw-r--r--lib/sqlalchemy/sql/selectable.py2
-rw-r--r--lib/sqlalchemy/sql/traversals.py218
-rw-r--r--lib/sqlalchemy/sql/visitors.py21
10 files changed, 807 insertions, 338 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 2d336360f..89839ea28 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -16,6 +16,7 @@ import re
from .traversals import HasCacheKey # noqa
from .visitors import ClauseVisitor
+from .visitors import InternalTraversal
from .. import exc
from .. import util
@@ -221,6 +222,10 @@ class DialectKWArgs(object):
"""
+ _dialect_kwargs_traverse_internals = [
+ ("dialect_options", InternalTraversal.dp_dialect_options)
+ ]
+
@classmethod
def argument_for(cls, dialect_name, argument_name, default):
"""Add a new kind of dialect-specific keyword argument for this class.
@@ -386,6 +391,39 @@ class DialectKWArgs(object):
construct_arg_dictionary[arg_name] = kwargs[k]
+class CompileState(object):
+ """Produces additional object state necessary for a statement to be
+ compiled.
+
+ the :class:`.CompileState` class is at the base of classes that assemble
+ state for a particular statement object that is then used by the
+ compiler. This process is essentially an extension of the process that
+ the SQLCompiler.visit_XYZ() method takes, however there is an emphasis
+ on converting raw user intent into more organized structures rather than
+ producing string output. The top-level :class:`.CompileState` for the
+ statement being executed is also accessible when the execution context
+ works with invoking the statement and collecting results.
+
+ The production of :class:`.CompileState` is specific to the compiler, such
+ as within the :meth:`.SQLCompiler.visit_insert`,
+ :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also
+ responsible for associating the :class:`.CompileState` with the
+ :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement,
+ i.e. the outermost SQL statement that's actually being executed.
+ There can be other :class:`.CompileState` objects that are not the
+ toplevel, such as when a SELECT subquery or CTE-nested
+ INSERT/UPDATE/DELETE is generated.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("statement",)
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+
class Generative(object):
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator."""
@@ -396,6 +434,12 @@ class Generative(object):
return s
+class HasCompileState(Generative):
+ """A class that has a :class:`.CompileState` associated with it."""
+
+ _compile_state_cls = CompileState
+
+
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
@@ -627,6 +671,9 @@ class ColumnCollection(object):
def keys(self):
return [k for (k, col) in self._collection]
+ def __bool__(self):
+ return bool(self._collection)
+
def __len__(self):
return len(self._collection)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index fc841bb4b..679d9c6e9 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -55,7 +55,10 @@ def expect(role, element, **kw):
# elaborate logic up front if possible
impl = _impl_lookup[role]
- if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)):
+ if not isinstance(
+ element,
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ ):
resolved = impl._resolve_for_clause_element(element, **kw)
else:
resolved = element
@@ -194,7 +197,9 @@ class _ColumnCoercions(object):
def _implicit_coercions(
self, original_element, resolved, argname=None, **kw
):
- if resolved._is_select_statement:
+ if not resolved.is_clause_element:
+ self._raise_for_expected(original_element, argname, resolved)
+ elif resolved._is_select_statement:
self._warn_for_scalar_subquery_coercion()
return resolved.scalar_subquery()
elif resolved._is_from_clause and isinstance(
@@ -290,14 +295,14 @@ class ExpressionElementImpl(
_ColumnCoercions, RoleImpl, roles.ExpressionElementRole
):
def _literal_coercion(
- self, element, name=None, type_=None, argname=None, **kw
+ self, element, name=None, type_=None, argname=None, is_crud=False, **kw
):
if element is None:
return elements.Null()
else:
try:
return elements.BindParameter(
- name, element, type_, unique=True
+ name, element, type_, unique=True, _is_crud=is_crud
)
except exc.ArgumentError as err:
self._raise_for_expected(element, err=err)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 424282951..3ebcf24b0 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -653,6 +653,20 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ compile_state = None
+ """Optional :class:`.CompileState` object that maintains additional
+ state used by the compiler.
+
+ Major executable objects such as :class:`.Insert`, :class:`.Update`,
+ :class:`.Delete`, :class:`.Select` will generate this state when compiled
+ in order to calculate additional information about the object. For the
+ top level object that is to be executed, the state can be stored here where
+ it can also have applicability towards result set processing.
+
+ .. versionadded:: 1.4
+
+ """
+
def __init__(
self,
dialect,
@@ -1292,6 +1306,13 @@ class SQLCompiler(Compiled):
else:
return "0"
+ def _generate_delimited_list(self, elements, separator, **kw):
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in elements)
+ if s
+ )
+
def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
@@ -1299,13 +1320,7 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
- text = sep.join(
- s
- for s in (
- c._compiler_dispatch(self, **kw) for c in clauselist.clauses
- )
- if s
- )
+ text = self._generate_delimited_list(clauselist.clauses, sep, **kw)
if clauselist._tuple_values and self.dialect.tuple_in_values:
text = "VALUES " + text
return text
@@ -2810,8 +2825,18 @@ class SQLCompiler(Compiled):
return dialect_hints, table_text
def visit_insert(self, insert_stmt, **kw):
+
+ compile_state = insert_stmt._compile_state_cls(
+ insert_stmt, self, isinsert=True, **kw
+ )
+ insert_stmt = compile_state.statement
+
toplevel = not self.stack
+ if toplevel:
+ self.isinsert = True
+ self.compile_state = compile_state
+
self.stack.append(
{
"correlate_froms": set(),
@@ -2820,8 +2845,8 @@ class SQLCompiler(Compiled):
}
)
- crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw
+ crud_params = crud._get_crud_params(
+ self, insert_stmt, compile_state, **kw
)
if (
@@ -2835,7 +2860,7 @@ class SQLCompiler(Compiled):
"inserts." % self.dialect.name
)
- if insert_stmt._has_multi_parameters:
+ if compile_state._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
@@ -2888,7 +2913,7 @@ class SQLCompiler(Compiled):
text += " %s" % select_text
elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
- elif insert_stmt._has_multi_parameters:
+ elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)" % (", ".join(c[1] for c in crud_param_set))
@@ -2947,9 +2972,16 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+ compile_state = update_stmt._compile_state_cls(
+ update_stmt, self, isupdate=True, **kw
+ )
+ update_stmt = compile_state.statement
+
toplevel = not self.stack
+ if toplevel:
+ self.isupdate = True
- extra_froms = update_stmt._extra_froms
+ extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
if is_multitable:
@@ -2981,8 +3013,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
- crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw
+ crud_params = crud._get_crud_params(
+ self, update_stmt, compile_state, **kw
)
if update_stmt._hints:
@@ -3022,8 +3054,10 @@ class SQLCompiler(Compiled):
if extra_from_text:
text += " " + extra_from_text
- if update_stmt._whereclause is not None:
- t = self.process(update_stmt._whereclause, **kw)
+ if update_stmt._where_criteria:
+ t = self._generate_delimited_list(
+ update_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ )
if t:
text += " WHERE " + t
@@ -3045,10 +3079,6 @@ class SQLCompiler(Compiled):
return text
- @util.memoized_property
- def _key_getters_for_crud_column(self):
- return crud._key_getters_for_crud_column(self, self.statement)
-
def delete_extra_from_clause(
self, update_stmt, from_table, extra_froms, from_hints, **kw
):
@@ -3069,11 +3099,16 @@ class SQLCompiler(Compiled):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, **kw):
- toplevel = not self.stack
+ compile_state = delete_stmt._compile_state_cls(
+ delete_stmt, self, isdelete=True, **kw
+ )
+ delete_stmt = compile_state.statement
- crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
+ toplevel = not self.stack
+ if toplevel:
+ self.isdelete = True
- extra_froms = delete_stmt._extra_froms
+ extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
self.stack.append(
@@ -3122,8 +3157,10 @@ class SQLCompiler(Compiled):
if extra_from_text:
text += " " + extra_from_text
- if delete_stmt._whereclause is not None:
- t = delete_stmt._whereclause._compiler_dispatch(self, **kw)
+ if delete_stmt._where_criteria:
+ t = self._generate_delimited_list(
+ delete_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ )
if t:
text += " WHERE " + t
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e474952ce..2827a5817 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -16,6 +16,7 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ClauseElement
from .. import exc
from .. import util
@@ -33,45 +34,8 @@ values present.
""",
)
-ISINSERT = util.symbol("ISINSERT")
-ISUPDATE = util.symbol("ISUPDATE")
-ISDELETE = util.symbol("ISDELETE")
-
-def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
- restore_isinsert = compiler.isinsert
- restore_isupdate = compiler.isupdate
- restore_isdelete = compiler.isdelete
-
- should_restore = (
- (restore_isinsert or restore_isupdate or restore_isdelete)
- or len(compiler.stack) > 1
- or "visiting_cte" in kw
- )
-
- if local_stmt_type is ISINSERT:
- compiler.isupdate = False
- compiler.isinsert = True
- elif local_stmt_type is ISUPDATE:
- compiler.isupdate = True
- compiler.isinsert = False
- elif local_stmt_type is ISDELETE:
- if not should_restore:
- compiler.isdelete = True
- else:
- assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
-
- try:
- if local_stmt_type in (ISINSERT, ISUPDATE):
- return _get_crud_params(compiler, stmt, **kw)
- finally:
- if should_restore:
- compiler.isinsert = restore_isinsert
- compiler.isupdate = restore_isupdate
- compiler.isdelete = restore_isdelete
-
-
-def _get_crud_params(compiler, stmt, **kw):
+def _get_crud_params(compiler, stmt, compile_state, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -87,27 +51,29 @@ def _get_crud_params(compiler, stmt, **kw):
compiler.update_prefetch = []
compiler.returning = []
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ (
+ _column_as_key,
+ _getattr_col_key,
+ _col_bind_name,
+ ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+
+ compiler._key_getters_for_crud_column = getters
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
- if compiler.column_keys is None and stmt.parameters is None:
+ if compiler.column_keys is None and compile_state._no_parameters:
return [
(c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
+ if compile_state._has_multi_parameters:
+ stmt_parameters = compile_state._multi_parameters[0]
else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- (
- _column_as_key,
- _getattr_col_key,
- _col_bind_name,
- ) = _key_getters_for_crud_column(compiler, stmt)
+ stmt_parameters = compile_state._dict_parameters
# if we have statement parameters - set defaults in the
# compiled params
@@ -132,10 +98,15 @@ def _get_crud_params(compiler, stmt, **kw):
# special logic that only occurs for multi-table UPDATE
# statements
- if compiler.isupdate and stmt._extra_froms and stmt_parameters:
+ if (
+ compile_state.isupdate
+ and compile_state._extra_froms
+ and stmt_parameters
+ ):
_get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -144,10 +115,11 @@ def _get_crud_params(compiler, stmt, **kw):
kw,
)
- if compiler.isinsert and stmt.select_names:
+ if compile_state.isinsert and stmt._select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -160,6 +132,7 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -181,8 +154,10 @@ def _get_crud_params(compiler, stmt, **kw):
% (", ".join("%s" % (c,) for c in check))
)
- if stmt._has_multi_parameters:
- values = _extend_values_for_multiparams(compiler, stmt, values, kw)
+ if compile_state._has_multi_parameters:
+ values = _extend_values_for_multiparams(
+ compiler, stmt, compile_state, values, kw
+ )
return values
@@ -201,15 +176,46 @@ def _create_bind_param(
return bindparam
-def _key_getters_for_crud_column(compiler, stmt):
- if compiler.isupdate and stmt._extra_froms:
+def _handle_values_anonymous_param(compiler, col, value, name, **kw):
+ # the insert() and update() constructs as of 1.4 will now produce anonymous
+ # bindparam() objects in the values() collections up front when given plain
+ # literal values. This is so that cache key behaviors, which need to
+ # produce bound parameters in deterministic order without invoking any
+ # compilation here, can be applied to these constructs when they include
+ # values() (but not yet multi-values, which are not included in caching
+ # right now).
+ #
+ # in order to produce the desired "crud" style name for these parameters,
+ # which will also be targetable in engine/default.py through the usual
+ # conventions, apply our desired name to these unique parameters by
+ # populating the compiler truncated names cache with the desired name,
+ # rather than having
+ # compiler.visit_bindparam()->compiler._truncated_identifier make up a
+ # name. Saves on call counts also.
+ if value.unique and isinstance(value.key, elements._truncated_label):
+ compiler.truncated_names[("bindparam", value.key)] = name
+
+ if value.type._isnull:
+ # either unique parameter, or other bound parameters that were
+ # passed in directly
+ # clone using base ClauseElement to retain unique key
+ value = ClauseElement._clone(value)
+
+ # set type to that of the column unconditionally
+ value.type = col.type
+
+ return value._compiler_dispatch(compiler, **kw)
+
+
+def _key_getters_for_crud_column(compiler, stmt, compile_state):
+ if compile_state.isupdate and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
- _et = set(stmt._extra_froms)
+ _et = set(compile_state._extra_froms)
c_key_role = functools.partial(
coercions.expect_as_key, roles.DMLColumnRole
@@ -246,6 +252,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -260,9 +267,9 @@ def _scan_insert_from_select_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
compiler._insert_from_select = stmt.select
@@ -294,6 +301,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -308,11 +316,11 @@ def _scan_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- if stmt._parameter_ordering:
+ if compile_state._parameter_ordering:
parameter_ordering = [
- _column_as_key(key) for key in stmt._parameter_ordering
+ _column_as_key(key) for key in compile_state._parameter_ordering
]
ordered_keys = set(parameter_ordering)
cols = [stmt.table.c[key] for key in parameter_ordering] + [
@@ -329,6 +337,7 @@ def _scan_cols(
_append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -339,7 +348,7 @@ def _scan_cols(
kw,
)
- elif compiler.isinsert:
+ elif compile_state.isinsert:
if (
c.primary_key
and need_pks
@@ -377,7 +386,7 @@ def _scan_cols(
):
_warn_pk_with_no_anticipated_value(c)
- elif compiler.isupdate:
+ elif compile_state.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
)
@@ -386,6 +395,7 @@ def _scan_cols(
def _append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -395,7 +405,9 @@ def _append_param_parameter(
values,
kw,
):
+
value = parameters.pop(col_key)
+
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -403,15 +415,21 @@ def _append_param_parameter(
value,
required=value is REQUIRED,
name=_col_bind_name(c)
- if not stmt._has_multi_parameters
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler,
+ c,
+ value,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and value.type._isnull:
- value = value._clone()
- value.type = c.type
-
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -644,6 +662,7 @@ def _append_param_update(
def _get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -656,7 +675,7 @@ def _get_multitable_params(
for c, param in stmt_parameters.items()
)
affected_tables = set()
- for t in stmt._extra_froms:
+ for t in compile_state._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
@@ -669,6 +688,11 @@ def _get_multitable_params(
value,
required=value is REQUIRED,
name=_col_bind_name(c),
+ **kw # TODO: no test coverage for literal binds here
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler, c, value, name=_col_bind_name(c), **kw
)
else:
compiler.postfetch.append(c)
@@ -704,11 +728,11 @@ def _get_multitable_params(
compiler.postfetch.append(c)
-def _extend_values_for_multiparams(compiler, stmt, values, kw):
+def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
values_0 = values
values = [values]
- for i, row in enumerate(stmt.parameters[1:]):
+ for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
for (col, param) in values_0:
if col in row or col.key in row:
@@ -757,12 +781,13 @@ def _get_stmt_parameters_params(
values.append((k, v))
-def _get_returning_modifiers(compiler, stmt):
+def _get_returning_modifiers(compiler, stmt, compile_state):
+
need_pks = (
- compiler.isinsert
+ compile_state.isinsert
and not compiler.inline
and not stmt._returning
- and not stmt._has_multi_parameters
+ and not compile_state._has_multi_parameters
)
implicit_returning = (
@@ -771,9 +796,9 @@ def _get_returning_modifiers(compiler, stmt):
and stmt.table.implicit_returning
)
- if compiler.isinsert:
+ if compile_state.isinsert:
implicit_return_defaults = implicit_returning and stmt._return_defaults
- elif compiler.isupdate:
+ elif compile_state.isupdate:
implicit_return_defaults = (
compiler.dialect.implicit_returning
and stmt.table.implicit_returning
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 097c513b4..171a2cc2c 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -8,25 +8,162 @@
Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
-
+from sqlalchemy.types import NullType
from . import coercions
from . import roles
from .base import _from_objects
from .base import _generative
+from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
-from .elements import and_
+from .base import HasCompileState
from .elements import ClauseElement
from .elements import Null
from .selectable import HasCTE
from .selectable import HasPrefixes
+from .visitors import InternalTraversal
from .. import exc
from .. import util
+from ..util import collections_abc
+
+
+class DMLState(CompileState):
+ _no_parameters = True
+ _dict_parameters = None
+ _multi_parameters = None
+ _parameter_ordering = None
+ _has_multi_parameters = False
+ isupdate = False
+ isdelete = False
+ isinsert = False
+
+ def __init__(
+ self,
+ statement,
+ compiler,
+ isinsert=False,
+ isupdate=False,
+ isdelete=False,
+ **kw
+ ):
+ self.statement = statement
+
+ if isupdate:
+ self.isupdate = True
+ self._preserve_parameter_order = (
+ statement._preserve_parameter_order
+ )
+ if statement._ordered_values is not None:
+ self._process_ordered_values(statement)
+ elif statement._values is not None:
+ self._process_values(statement)
+ elif statement._multi_values:
+ self._process_multi_values(statement)
+ self._extra_froms = self._make_extra_froms(statement)
+ elif isinsert:
+ self.isinsert = True
+ if statement._select_names:
+ self._process_select_values(statement)
+ if statement._values is not None:
+ self._process_values(statement)
+ if statement._multi_values:
+ self._process_multi_values(statement)
+ elif isdelete:
+ self.isdelete = True
+ self._extra_froms = self._make_extra_froms(statement)
+ else:
+ assert False, "one of isinsert, isupdate, or isdelete must be set"
+
+ def _make_extra_froms(self, statement):
+ froms = []
+ seen = {statement.table}
+
+ for crit in statement._where_criteria:
+ for item in _from_objects(crit):
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ return froms
+
+ def _process_multi_values(self, statement):
+ if not statement._supports_multi_parameters:
+ raise exc.InvalidRequestError(
+ "%s construct does not support "
+ "multiple parameter sets." % statement.__visit_name__.upper()
+ )
+
+ for parameters in statement._multi_values:
+ multi_parameters = [
+ {
+ c.key: value
+ for c, value in zip(statement.table.c, parameter_set)
+ }
+ if isinstance(parameter_set, collections_abc.Sequence)
+ else parameter_set
+ for parameter_set in parameters
+ ]
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._has_multi_parameters = True
+ self._multi_parameters = multi_parameters
+ self._dict_parameters = self._multi_parameters[0]
+ elif not self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ self._multi_parameters.extend(multi_parameters)
+
+ def _process_values(self, statement):
+ if self._no_parameters:
+ self._has_multi_parameters = False
+ self._dict_parameters = statement._values
+ self._no_parameters = False
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+
+ def _process_ordered_values(self, statement):
+ parameters = statement._ordered_values
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = dict(parameters)
+ self._parameter_ordering = [key for key, value in parameters]
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ raise exc.InvalidRequestError(
+ "Can only invoke ordered_values() once, and not mixed "
+ "with any other values() call"
+ )
+
+ def _process_select_values(self, statement):
+ parameters = {
+ coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
+ for name in statement._select_names
+ }
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = parameters
+ else:
+ # this condition normally not reachable as the Insert
+ # does not allow this construction to occur
+ assert False, "This statement already has parameters"
+
+ def _cant_mix_formats_error(self):
+ raise exc.InvalidRequestError(
+ "Can't mix single and multiple VALUES "
+ "formats in one INSERT statement; one style appends to a "
+ "list while the other replaces values, so the intent is "
+ "ambiguous."
+ )
class UpdateBase(
roles.DMLRole,
HasCTE,
+ HasCompileState,
DialectKWArgs,
HasPrefixes,
Executable,
@@ -42,10 +179,10 @@ class UpdateBase(
{"autocommit": True}
)
_hints = util.immutabledict()
- _parameter_ordering = None
- _prefixes = ()
named_with_column = False
+ _compile_state_cls = DMLState
+
@classmethod
def _constructor_20_deprecations(cls, fn_name, clsname, names):
@@ -112,43 +249,6 @@ class UpdateBase(
col._make_proxy(fromclause) for col in self._returning
)
- def _process_colparams(self, parameters, preserve_parameter_order=False):
- def process_single(p):
- if isinstance(p, (list, tuple)):
- return dict((c.key, pval) for c, pval in zip(self.table.c, p))
- else:
- return p
-
- if (
- preserve_parameter_order or self._preserve_parameter_order
- ) and parameters is not None:
- if not isinstance(parameters, list) or (
- parameters and not isinstance(parameters[0], tuple)
- ):
- raise ValueError(
- "When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples"
- )
- self._parameter_ordering = [key for key, value in parameters]
-
- return dict(parameters), False
-
- if (
- isinstance(parameters, (list, tuple))
- and parameters
- and isinstance(parameters[0], (list, tuple, dict))
- ):
-
- if not self._supports_multi_parameters:
- raise exc.InvalidRequestError(
- "This construct does not support "
- "multiple parameter sets."
- )
-
- return [process_single(p) for p in parameters], True
- else:
- return process_single(parameters), False
-
def params(self, *arg, **kw):
"""Set the parameters for the statement.
@@ -163,6 +263,29 @@ class UpdateBase(
" stmt.values(**parameters)."
)
+ @_generative
+ def with_dialect_options(self, **opt):
+ """Add dialect options to this INSERT/UPDATE/DELETE object.
+
+ e.g.::
+
+ upd = table.update().dialect_options(mysql_limit=10)
+
+ .. versionadded: 1.4 - this method supersedes the dialect options
+ associated with the constructor.
+
+
+ """
+ self._validate_dialect_kwargs(opt)
+
+ def _validate_dialect_kwargs_deprecated(self, dialect_kw):
+ util.warn_deprecated_20(
+ "Passing dialect keyword arguments directly to the "
+ "constructor is deprecated and will be removed in SQLAlchemy "
+ "2.0. Please use the ``with_dialect_options()`` method."
+ )
+ self._validate_dialect_kwargs(dialect_kw)
+
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
or a :class:`.Table` associated with it.
@@ -266,9 +389,6 @@ class UpdateBase(
self._hints = self._hints.union({(selectable, dialect_name): text})
- def _copy_internals(self, **kw):
- raise NotImplementedError()
-
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
@@ -277,16 +397,21 @@ class ValuesBase(UpdateBase):
__visit_name__ = "values_base"
_supports_multi_parameters = False
- _has_multi_parameters = False
_preserve_parameter_order = False
select = None
_post_values_clause = None
+ _values = None
+ _multi_values = ()
+ _ordered_values = None
+ _select_names = None
+
+ _returning = ()
+
def __init__(self, table, values, prefixes):
self.table = coercions.expect(roles.FromClauseRole, table)
- self.parameters, self._has_multi_parameters = self._process_colparams(
- values
- )
+ if values is not None:
+ self.values.non_generative(self, values)
if prefixes:
self._setup_prefixes(prefixes)
@@ -416,59 +541,96 @@ class ValuesBase(UpdateBase):
:func:`~.expression.update` - produce an ``UPDATE`` statement
"""
- if self.select is not None:
+ if self._select_names:
raise exc.InvalidRequestError(
"This construct already inserts from a SELECT"
)
- if self._has_multi_parameters and kwargs:
- raise exc.InvalidRequestError(
- "This construct already has multiple parameter sets."
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
)
if args:
- if len(args) > 1:
+ # positional case. this is currently expensive. we don't
+ # yet have positional-only args so we have to check the length.
+ # then we need to check multiparams vs. single dictionary.
+ # since the parameter format is needed in order to determine
+ # a cache key, we need to determine this up front.
+ arg = args[0]
+
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't pass positional and kwargs to values() "
+ "simultaneously"
+ )
+ elif len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
"dictionaries/tuples is accepted positionally."
)
- v = args[0]
- else:
- v = {}
- if self.parameters is None:
- (
- self.parameters,
- self._has_multi_parameters,
- ) = self._process_colparams(v)
- else:
- if self._has_multi_parameters:
- self.parameters = list(self.parameters)
- p, self._has_multi_parameters = self._process_colparams(v)
- if not self._has_multi_parameters:
- raise exc.ArgumentError(
- "Can't mix single-values and multiple values "
- "formats in one statement"
- )
+ elif not self._preserve_parameter_order and isinstance(
+ arg, collections_abc.Sequence
+ ):
- self.parameters.extend(p)
- else:
- self.parameters = self.parameters.copy()
- p, self._has_multi_parameters = self._process_colparams(v)
- if self._has_multi_parameters:
- raise exc.ArgumentError(
- "Can't mix single-values and multiple values "
- "formats in one statement"
- )
- self.parameters.update(p)
+ if arg and isinstance(arg[0], (list, dict, tuple)):
+ self._multi_values += (arg,)
+ return
- if kwargs:
- if self._has_multi_parameters:
+ # tuple values
+ arg = {c.key: value for c, value in zip(self.table.c, arg)}
+ elif self._preserve_parameter_order and not isinstance(
+ arg, collections_abc.Sequence
+ ):
+ raise ValueError(
+ "When preserve_parameter_order is True, "
+ "values() only accepts a list of 2-tuples"
+ )
+
+ else:
+ # kwarg path. this is the most common path for non-multi-params
+ # so this is fairly quick.
+ arg = kwargs
+ if args:
raise exc.ArgumentError(
- "Can't pass kwargs and multiple parameter sets "
- "simultaneously"
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
)
+
+ # for top level values(), convert literals to anonymous bound
+ # parameters at statement construction time, so that these values can
+ # participate in the cache key process like any other ClauseElement.
+ # crud.py now intercepts bound parameters with unique=True from here
+ # and ensures they get the "crud"-style name when rendered.
+
+ if self._preserve_parameter_order:
+ arg = [
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in arg
+ ]
+ self._ordered_values = arg
+ else:
+ arg = {
+ k: coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ )
+ for k, v in arg.items()
+ }
+ if self._values:
+ self._values = self._values.union(arg)
else:
- self.parameters.update(kwargs)
+ self._values = util.immutabledict(arg)
@_generative
def return_defaults(self, *cols):
@@ -555,6 +717,25 @@ class Insert(ValuesBase):
_supports_multi_parameters = True
+ select = None
+ include_insert_from_select_defaults = False
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_select_names", InternalTraversal.dp_string_list),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_multi_values", InternalTraversal.dp_dml_multi_values),
+ ("select", InternalTraversal.dp_clauseelement),
+ ("_post_values_clause", InternalTraversal.dp_clauseelement),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"insert",
"Insert",
@@ -626,18 +807,13 @@ class Insert(ValuesBase):
"""
super(Insert, self).__init__(table, values, prefixes)
self._bind = bind
- self.select = self.select_names = None
- self.include_insert_from_select_defaults = False
self._inline = inline
- self._returning = returning
- self._validate_dialect_kwargs(dialect_kw)
- self._return_defaults = return_defaults
+ if returning:
+ self._returning = returning
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
- def get_children(self, **kwargs):
- if self.select is not None:
- return (self.select,)
- else:
- return ()
+ self._return_defaults = return_defaults
@_generative
def inline(self):
@@ -702,25 +878,34 @@ class Insert(ValuesBase):
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
"""
- if self.parameters:
+
+ if self._values:
raise exc.InvalidRequestError(
"This construct already inserts value expressions"
)
- self.parameters, self._has_multi_parameters = self._process_colparams(
- {
- coercions.expect(roles.DMLColumnRole, n, as_key=True): Null()
- for n in names
- }
- )
-
- self.select_names = names
+ self._select_names = names
self._inline = True
self.include_insert_from_select_defaults = include_defaults
self.select = coercions.expect(roles.DMLSelectRole, select)
-class Update(ValuesBase):
+class DMLWhereBase(object):
+ _where_criteria = ()
+
+ @_generative
+ def where(self, whereclause):
+ """return a new construct with the given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+ """
+
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+
+
+class Update(DMLWhereBase, ValuesBase):
"""Represent an Update construct.
The :class:`.Update` object is created using the :func:`update()`
@@ -730,6 +915,20 @@ class Update(ValuesBase):
__visit_name__ = "update"
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"update",
"Update",
@@ -874,21 +1073,14 @@ class Update(ValuesBase):
self._bind = bind
self._returning = returning
if whereclause is not None:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
)
- else:
- self._whereclause = None
self._inline = inline
- self._validate_dialect_kwargs(dialect_kw)
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
self._return_defaults = return_defaults
- def get_children(self, **kwargs):
- if self._whereclause is not None:
- return (self._whereclause,)
- else:
- return ()
-
@_generative
def ordered_values(self, *args):
"""Specify the VALUES clause of this UPDATE statement with an explicit
@@ -912,22 +1104,27 @@ class Update(ValuesBase):
parameter, which will be removed in SQLAlchemy 2.0.
"""
- if self.select is not None:
- raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT"
- )
-
- if self.parameters is None:
- (
- self.parameters,
- self._has_multi_parameters,
- ) = self._process_colparams(
- list(args), preserve_parameter_order=True
- )
- else:
+ if self._values:
raise exc.ArgumentError(
"This statement already has values present"
)
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
+ )
+ arg = [
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in args
+ ]
+ self._ordered_values = arg
@_generative
def inline(self):
@@ -945,37 +1142,8 @@ class Update(ValuesBase):
"""
self._inline = True
- @_generative
- def where(self, whereclause):
- """return a new update() construct with the given expression added to
- its WHERE clause, joined to the existing clause via AND, if any.
-
- """
- if self._whereclause is not None:
- self._whereclause = and_(
- self._whereclause,
- coercions.expect(roles.WhereHavingRole, whereclause),
- )
- else:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
-
- @property
- def _extra_froms(self):
- froms = []
- seen = {self.table}
-
- if self._whereclause is not None:
- for item in _from_objects(self._whereclause):
- if not seen.intersection(item._cloned_set):
- froms.append(item)
- seen.update(item._cloned_set)
-
- return froms
-
-class Delete(UpdateBase):
+class Delete(DMLWhereBase, UpdateBase):
"""Represent a DELETE construct.
The :class:`.Delete` object is created using the :func:`delete()`
@@ -985,6 +1153,17 @@ class Delete(UpdateBase):
__visit_name__ = "delete"
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ )
+
@ValuesBase._constructor_20_deprecations(
"delete",
"Delete",
@@ -1041,43 +1220,9 @@ class Delete(UpdateBase):
self._setup_prefixes(prefixes)
if whereclause is not None:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
- else:
- self._whereclause = None
-
- self._validate_dialect_kwargs(dialect_kw)
-
- def get_children(self, **kwargs):
- if self._whereclause is not None:
- return (self._whereclause,)
- else:
- return ()
-
- @_generative
- def where(self, whereclause):
- """Add the given WHERE clause to a newly returned delete construct."""
-
- if self._whereclause is not None:
- self._whereclause = and_(
- self._whereclause,
+ self._where_criteria += (
coercions.expect(roles.WhereHavingRole, whereclause),
)
- else:
- self._whereclause = coercions.expect(
- roles.WhereHavingRole, whereclause
- )
-
- @property
- def _extra_froms(self):
- froms = []
- seen = {self.table}
- if self._whereclause is not None:
- for item in _from_objects(self._whereclause):
- if not seen.intersection(item._cloned_set):
- froms.append(item)
- seen.update(item._cloned_set)
-
- return froms
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index d0babb1be..47739a37d 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -200,6 +200,7 @@ class ClauseElement(
_is_from_container = False
_is_select_container = False
_is_select_statement = False
+ _is_bind_parameter = False
_order_by_label_element = None
@@ -1010,6 +1011,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
_is_crud = False
_expanding_in_types = ()
+ _is_bind_parameter = True
def __init__(
self,
@@ -1025,6 +1027,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
literal_execute=False,
_compared_to_operator=None,
_compared_to_type=None,
+ _is_crud=False,
):
r"""Produce a "bound expression".
@@ -1303,6 +1306,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
self.required = required
self.expanding = expanding
self.literal_execute = literal_execute
+ if _is_crud:
+ self._is_crud = True
if type_ is None:
if _compared_to_type is not None:
self.type = _compared_to_type.coerce_compared_value(
@@ -4264,21 +4269,12 @@ class ColumnClause(
else:
return other.proxy_set.intersection(self.proxy_set)
- def _get_table(self):
- return self.__dict__["table"]
-
- def _set_table(self, table):
- self._memoized_property.expire_instance(self)
- self.__dict__["table"] = table
-
def get_children(self, column_tables=False, **kw):
if column_tables and self.table is not None:
return [self.table]
else:
return []
- table = property(_get_table, _set_table)
-
@_memoized_property
def _from_objects(self):
t = self.table
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 5445a1bce..4c627c4cc 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -1413,6 +1413,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"Column must be constructed with a non-blank name or "
"assign a non-blank .name before adding to a Table."
)
+
+ Column._memoized_property.expire_instance(self)
+
if self.key is None:
self.key = self.name
@@ -2080,24 +2083,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._set_target_column(_column)
-class _NotAColumnExpr(object):
- # the coercions system is not used in crud.py for the values passed in
- # the insert().values() and update().values() methods, so the usual
- # pathways to rejecting a coercion in the unlikely case of adding defaut
- # generator objects to insert() or update() constructs aren't available;
- # create a quick coercion rejection here that is specific to what crud.py
- # calls on value objects.
- def _not_a_column_expr(self):
- raise exc.InvalidRequestError(
- "This %s cannot be used directly "
- "as a column expression." % self.__class__.__name__
- )
-
- self_group = lambda self: self._not_a_column_expr() # noqa
- _from_objects = property(lambda self: self._not_a_column_expr())
-
-
-class DefaultGenerator(_NotAColumnExpr, SchemaItem):
+class DefaultGenerator(SchemaItem):
"""Base class for column *default* values."""
__visit_name__ = "default_generator"
@@ -2505,7 +2491,7 @@ class Sequence(roles.StatementRole, DefaultGenerator):
@inspection._self_inspects
-class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
+class FetchedValue(SchemaEventTarget):
"""A marker for a transparent database-side default.
Use :class:`.FetchedValue` when the database is configured
@@ -2528,6 +2514,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
is_server_default = True
reflected = False
has_argument = False
+ is_clause_element = False
def __init__(self, for_update=False):
self.for_update = for_update
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index b972c13be..965ac6e7f 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -3145,8 +3145,6 @@ class Select(
__visit_name__ = "select"
- _prefixes = ()
- _suffixes = ()
_hints = util.immutabledict()
_statement_hints = ()
_distinct = False
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 03ff7c439..c29a04ee0 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -200,6 +200,9 @@ class _CacheKey(ExtendedInternalTraversal):
attrname, inspect(obj), parent, anon_map, bindparams
)
+ def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ return tuple(obj)
+
def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
return (
attrname,
@@ -336,6 +339,25 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
+ def visit_dialect_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ dialect_name,
+ tuple(
+ [
+ (key, obj[dialect_name][key])
+ for key in sorted(obj[dialect_name])
+ ]
+ ),
+ )
+ for dialect_name in sorted(obj)
+ ),
+ )
+
def visit_string_clauseelement_dict(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -366,9 +388,13 @@ class _CacheKey(ExtendedInternalTraversal):
def visit_fromclause_canonical_column_collection(
self, attrname, obj, parent, anon_map, bindparams
):
+ # inlining into the internals of ColumnCollection
return (
attrname,
- tuple(col._gen_cache_key(anon_map, bindparams) for col in obj),
+ tuple(
+ col._gen_cache_key(anon_map, bindparams)
+ for k, col in obj._collection
+ ),
)
def visit_unknown_structure(
@@ -377,6 +403,48 @@ class _CacheKey(ExtendedInternalTraversal):
anon_map[NO_CACHE] = True
return ()
+ def visit_dml_ordered_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key._gen_cache_key(anon_map, bindparams)
+ if hasattr(key, "__clause_element__")
+ else key,
+ value._gen_cache_key(anon_map, bindparams),
+ )
+ for key, value in obj
+ ),
+ )
+
+ def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
+
+ str_values = expr_values.symmetric_difference(obj)
+
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
+
+ def visit_dml_multi_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # multivalues are simply not cacheable right now
+ anon_map[NO_CACHE] = True
+ return ()
+
_cache_key_traversal_visitor = _CacheKey()
@@ -404,6 +472,70 @@ class _CopyInternals(InternalTraversal):
(key, clone(value, **kw)) for key, value in element.items()
)
+ def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw):
+ # sequence of 2-tuples
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw),
+ )
+ for key, value in element
+ ]
+
+ def visit_dml_values(self, parent, element, clone=_clone, **kw):
+ # sequence of dictionaries
+ return [
+ {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): clone(value, **kw)
+ for key, value in sub_element.items()
+ }
+ for sub_element in element
+ ]
+
+ def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
+ # sequence of sequences, each sequence contains a list/dict/tuple
+
+ def copy(elem):
+ if isinstance(elem, (list, tuple)):
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value,
+ )
+ for key, value in elem
+ ]
+ elif isinstance(elem, dict):
+ return {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): (
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ for key, value in elem
+ }
+ else:
+ # TODO: use abc classes
+ assert False
+
+ return [
+ [copy(sub_element) for sub_element in sequence]
+ for sequence in element
+ ]
+
_copy_internals = _CopyInternals()
@@ -442,6 +574,25 @@ class _GetChildren(InternalTraversal):
def visit_clauseelement_unordered_set(self, element, **kw):
return tuple(element)
+ def visit_dml_ordered_values(self, element, **kw):
+ for k, v in element:
+ if hasattr(k, "__clause_element__"):
+ yield k
+ yield v
+
+ def visit_dml_values(self, element, **kw):
+ expr_values = {k for k in element if hasattr(k, "__clause_element__")}
+ str_values = expr_values.symmetric_difference(element)
+
+ for k in sorted(str_values):
+ yield element[k]
+ for k in expr_values:
+ yield k
+ yield element[k]
+
+ def visit_dml_multi_values(self, element, **kw):
+ return ()
+
_get_children = _GetChildren()
@@ -644,6 +795,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
def visit_string(self, left_parent, left, right_parent, right, **kw):
return left == right
+ def visit_string_list(self, left_parent, left, right_parent, right, **kw):
+ return left == right
+
def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
@@ -663,6 +817,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
return left == right
+ def visit_dialect_options(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
return left == right
@@ -713,6 +872,55 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
):
raise NotImplementedError()
+ def visit_dml_ordered_values(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ # sequence of tuple pairs
+
+ for (lk, lv), (rk, rv) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ lkce = hasattr(lk, "__clause_element__")
+ rkce = hasattr(rk, "__clause_element__")
+ if lkce != rkce:
+ return COMPARE_FAILED
+ elif lkce and not self.compare_inner(lk, rk, **kw):
+ return COMPARE_FAILED
+ elif not lkce and lk != rk:
+ return COMPARE_FAILED
+ elif not self.compare_inner(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
+ if left is None or right is None or len(left) != len(right):
+ return COMPARE_FAILED
+
+ for lk in left:
+ lv = left[lk]
+
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
+
+ if not self.compare_inner(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_multi_values(
+ self, left_parent, left, right_parent, right, **kw
+ ):
+ for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
+ if lseq is None or rseq is None:
+ return COMPARE_FAILED
+
+ for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
+ if (
+ self.visit_dml_values(
+ left_parent, ld, right_parent, rd, **kw
+ )
+ is COMPARE_FAILED
+ ):
+ return COMPARE_FAILED
+
def compare_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
@@ -731,11 +939,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
if left.operator == right.operator:
if operators.is_commutative(left.operator):
if (
- compare(left.left, right.left, **kw)
- and compare(left.right, right.right, **kw)
+ self.compare_inner(left.left, right.left, **kw)
+ and self.compare_inner(left.right, right.right, **kw)
) or (
- compare(left.left, right.right, **kw)
- and compare(left.right, right.left, **kw)
+ self.compare_inner(left.left, right.right, **kw)
+ and self.compare_inner(left.right, right.left, **kw)
):
return ["operator", "negate", "left", "right"]
else:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index fda48c657..a049d9bb0 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -269,6 +269,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_string_list = symbol("SL")
+ """Visit a list of strings."""
+
dp_anon_name = symbol("AN")
"""Visit a potentially "anonymized" string value.
@@ -313,6 +316,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_dialect_options = symbol("DO")
+ """visit a dialect options structure."""
+
dp_string_clauseelement_dict = symbol("CD")
"""Visit a dictionary of string keys to :class:`.ClauseElement`
objects.
@@ -365,6 +371,21 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_dml_ordered_values = symbol("DML_OV")
+ """visit the values() ordered tuple list of an :class:`.Update` object."""
+
+ dp_dml_values = symbol("DML_V")
+ """visit the values() dictionary of a :class:`.ValuesBase
+ (e.g. Insert or Update) object.
+
+ """
+
+ dp_dml_multi_values = symbol("DML_MV")
+ """visit the values() multi-valued list of dictionaries of an
+ :class:`.Insert` object.
+
+ """
+
class ExtendedInternalTraversal(InternalTraversal):
"""defines additional symbols that are useful in caching applications.