diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 47 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 87 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 181 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 537 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 25 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 218 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 21 |
10 files changed, 807 insertions, 338 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d336360f..89839ea28 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -16,6 +16,7 @@ import re from .traversals import HasCacheKey # noqa from .visitors import ClauseVisitor +from .visitors import InternalTraversal from .. import exc from .. import util @@ -221,6 +222,10 @@ class DialectKWArgs(object): """ + _dialect_kwargs_traverse_internals = [ + ("dialect_options", InternalTraversal.dp_dialect_options) + ] + @classmethod def argument_for(cls, dialect_name, argument_name, default): """Add a new kind of dialect-specific keyword argument for this class. @@ -386,6 +391,39 @@ class DialectKWArgs(object): construct_arg_dictionary[arg_name] = kwargs[k] +class CompileState(object): + """Produces additional object state necessary for a statement to be + compiled. + + the :class:`.CompileState` class is at the base of classes that assemble + state for a particular statement object that is then used by the + compiler. This process is essentially an extension of the process that + the SQLCompiler.visit_XYZ() method takes, however there is an emphasis + on converting raw user intent into more organized structures rather than + producing string output. The top-level :class:`.CompileState` for the + statement being executed is also accessible when the execution context + works with invoking the statement and collecting results. + + The production of :class:`.CompileState` is specific to the compiler, such + as within the :meth:`.SQLCompiler.visit_insert`, + :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also + responsible for associating the :class:`.CompileState` with the + :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement, + i.e. the outermost SQL statement that's actually being executed. + There can be other :class:`.CompileState` objects that are not the + toplevel, such as when a SELECT subquery or CTE-nested + INSERT/UPDATE/DELETE is generated. + + .. versionadded:: 1.4 + + """ + + __slots__ = ("statement",) + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + class Generative(object): """Provide a method-chaining pattern in conjunction with the @_generative decorator.""" @@ -396,6 +434,12 @@ class Generative(object): return s +class HasCompileState(Generative): + """A class that has a :class:`.CompileState` associated with it.""" + + _compile_state_cls = CompileState + + class Executable(Generative): """Mark a ClauseElement as supporting execution. @@ -627,6 +671,9 @@ class ColumnCollection(object): def keys(self): return [k for (k, col) in self._collection] + def __bool__(self): + return bool(self._collection) + def __len__(self): return len(self._collection) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index fc841bb4b..679d9c6e9 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -55,7 +55,10 @@ def expect(role, element, **kw): # elaborate logic up front if possible impl = _impl_lookup[role] - if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)): + if not isinstance( + element, + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), + ): resolved = impl._resolve_for_clause_element(element, **kw) else: resolved = element @@ -194,7 +197,9 @@ class _ColumnCoercions(object): def _implicit_coercions( self, original_element, resolved, argname=None, **kw ): - if resolved._is_select_statement: + if not resolved.is_clause_element: + self._raise_for_expected(original_element, argname, resolved) + elif resolved._is_select_statement: self._warn_for_scalar_subquery_coercion() return resolved.scalar_subquery() elif resolved._is_from_clause and isinstance( @@ -290,14 +295,14 @@ class ExpressionElementImpl( _ColumnCoercions, RoleImpl, roles.ExpressionElementRole ): def _literal_coercion( - self, element, name=None, type_=None, argname=None, **kw + self, element, name=None, type_=None, argname=None, is_crud=False, **kw ): if element is None: return elements.Null() else: try: return elements.BindParameter( - name, element, type_, unique=True + name, element, type_, unique=True, _is_crud=is_crud ) except exc.ArgumentError as err: self._raise_for_expected(element, err=err) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 424282951..3ebcf24b0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -653,6 +653,20 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () + compile_state = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`.Insert`, :class:`.Update`, + :class:`.Delete`, :class:`.Select` will generate this state when compiled + in order to calculate additional information about the object. For the + top level object that is to be executed, the state can be stored here where + it can also have applicability towards result set processing. + + .. versionadded:: 1.4 + + """ + def __init__( self, dialect, @@ -1292,6 +1306,13 @@ class SQLCompiler(Compiled): else: return "0" + def _generate_delimited_list(self, elements, separator, **kw): + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in elements) + if s + ) + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: @@ -1299,13 +1320,7 @@ class SQLCompiler(Compiled): else: sep = OPERATORS[clauselist.operator] - text = sep.join( - s - for s in ( - c._compiler_dispatch(self, **kw) for c in clauselist.clauses - ) - if s - ) + text = self._generate_delimited_list(clauselist.clauses, sep, **kw) if clauselist._tuple_values and self.dialect.tuple_in_values: text = "VALUES " + text return text @@ -2810,8 +2825,18 @@ class SQLCompiler(Compiled): return dialect_hints, table_text def visit_insert(self, insert_stmt, **kw): + + compile_state = insert_stmt._compile_state_cls( + insert_stmt, self, isinsert=True, **kw + ) + insert_stmt = compile_state.statement + toplevel = not self.stack + if toplevel: + self.isinsert = True + self.compile_state = compile_state + self.stack.append( { "correlate_froms": set(), @@ -2820,8 +2845,8 @@ class SQLCompiler(Compiled): } ) - crud_params = crud._setup_crud_params( - self, insert_stmt, crud.ISINSERT, **kw + crud_params = crud._get_crud_params( + self, insert_stmt, compile_state, **kw ) if ( @@ -2835,7 +2860,7 @@ class SQLCompiler(Compiled): "inserts." % self.dialect.name ) - if insert_stmt._has_multi_parameters: + if compile_state._has_multi_parameters: if not self.dialect.supports_multivalues_insert: raise exc.CompileError( "The '%s' dialect with current database " @@ -2888,7 +2913,7 @@ class SQLCompiler(Compiled): text += " %s" % select_text elif not crud_params and supports_default_values: text += " DEFAULT VALUES" - elif insert_stmt._has_multi_parameters: + elif compile_state._has_multi_parameters: text += " VALUES %s" % ( ", ".join( "(%s)" % (", ".join(c[1] for c in crud_param_set)) @@ -2947,9 +2972,16 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_cls( + update_stmt, self, isupdate=True, **kw + ) + update_stmt = compile_state.statement + toplevel = not self.stack + if toplevel: + self.isupdate = True - extra_froms = update_stmt._extra_froms + extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) if is_multitable: @@ -2981,8 +3013,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause( update_stmt, update_stmt.table, render_extra_froms, **kw ) - crud_params = crud._setup_crud_params( - self, update_stmt, crud.ISUPDATE, **kw + crud_params = crud._get_crud_params( + self, update_stmt, compile_state, **kw ) if update_stmt._hints: @@ -3022,8 +3054,10 @@ class SQLCompiler(Compiled): if extra_from_text: text += " " + extra_from_text - if update_stmt._whereclause is not None: - t = self.process(update_stmt._whereclause, **kw) + if update_stmt._where_criteria: + t = self._generate_delimited_list( + update_stmt._where_criteria, OPERATORS[operators.and_], **kw + ) if t: text += " WHERE " + t @@ -3045,10 +3079,6 @@ class SQLCompiler(Compiled): return text - @util.memoized_property - def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self, self.statement) - def delete_extra_from_clause( self, update_stmt, from_table, extra_froms, from_hints, **kw ): @@ -3069,11 +3099,16 @@ class SQLCompiler(Compiled): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, **kw): - toplevel = not self.stack + compile_state = delete_stmt._compile_state_cls( + delete_stmt, self, isdelete=True, **kw + ) + delete_stmt = compile_state.statement - crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) + toplevel = not self.stack + if toplevel: + self.isdelete = True - extra_froms = delete_stmt._extra_froms + extra_froms = compile_state._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) self.stack.append( @@ -3122,8 +3157,10 @@ class SQLCompiler(Compiled): if extra_from_text: text += " " + extra_from_text - if delete_stmt._whereclause is not None: - t = delete_stmt._whereclause._compiler_dispatch(self, **kw) + if delete_stmt._where_criteria: + t = self._generate_delimited_list( + delete_stmt._where_criteria, OPERATORS[operators.and_], **kw + ) if t: text += " WHERE " + t diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e474952ce..2827a5817 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -16,6 +16,7 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ClauseElement from .. import exc from .. import util @@ -33,45 +34,8 @@ values present. """, ) -ISINSERT = util.symbol("ISINSERT") -ISUPDATE = util.symbol("ISUPDATE") -ISDELETE = util.symbol("ISDELETE") - -def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): - restore_isinsert = compiler.isinsert - restore_isupdate = compiler.isupdate - restore_isdelete = compiler.isdelete - - should_restore = ( - (restore_isinsert or restore_isupdate or restore_isdelete) - or len(compiler.stack) > 1 - or "visiting_cte" in kw - ) - - if local_stmt_type is ISINSERT: - compiler.isupdate = False - compiler.isinsert = True - elif local_stmt_type is ISUPDATE: - compiler.isupdate = True - compiler.isinsert = False - elif local_stmt_type is ISDELETE: - if not should_restore: - compiler.isdelete = True - else: - assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" - - try: - if local_stmt_type in (ISINSERT, ISUPDATE): - return _get_crud_params(compiler, stmt, **kw) - finally: - if should_restore: - compiler.isinsert = restore_isinsert - compiler.isupdate = restore_isupdate - compiler.isdelete = restore_isdelete - - -def _get_crud_params(compiler, stmt, **kw): +def _get_crud_params(compiler, stmt, compile_state, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -87,27 +51,29 @@ def _get_crud_params(compiler, stmt, **kw): compiler.update_prefetch = [] compiler.returning = [] + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._key_getters_for_crud_column = getters + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns - if compiler.column_keys is None and stmt.parameters is None: + if compiler.column_keys is None and compile_state._no_parameters: return [ (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] - if stmt._has_multi_parameters: - stmt_parameters = stmt.parameters[0] + if compile_state._has_multi_parameters: + stmt_parameters = compile_state._multi_parameters[0] else: - stmt_parameters = stmt.parameters - - # getters - these are normally just column.key, - # but in the case of mysql multi-table update, the rules for - # .key must conditionally take tablename into account - ( - _column_as_key, - _getattr_col_key, - _col_bind_name, - ) = _key_getters_for_crud_column(compiler, stmt) + stmt_parameters = compile_state._dict_parameters # if we have statement parameters - set defaults in the # compiled params @@ -132,10 +98,15 @@ def _get_crud_params(compiler, stmt, **kw): # special logic that only occurs for multi-table UPDATE # statements - if compiler.isupdate and stmt._extra_froms and stmt_parameters: + if ( + compile_state.isupdate + and compile_state._extra_froms + and stmt_parameters + ): _get_multitable_params( compiler, stmt, + compile_state, stmt_parameters, check_columns, _col_bind_name, @@ -144,10 +115,11 @@ def _get_crud_params(compiler, stmt, **kw): kw, ) - if compiler.isinsert and stmt.select_names: + if compile_state.isinsert and stmt._select_names: _scan_insert_from_select_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -160,6 +132,7 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -181,8 +154,10 @@ def _get_crud_params(compiler, stmt, **kw): % (", ".join("%s" % (c,) for c in check)) ) - if stmt._has_multi_parameters: - values = _extend_values_for_multiparams(compiler, stmt, values, kw) + if compile_state._has_multi_parameters: + values = _extend_values_for_multiparams( + compiler, stmt, compile_state, values, kw + ) return values @@ -201,15 +176,46 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler, stmt): - if compiler.isupdate and stmt._extra_froms: +def _handle_values_anonymous_param(compiler, col, value, name, **kw): + # the insert() and update() constructs as of 1.4 will now produce anonymous + # bindparam() objects in the values() collections up front when given plain + # literal values. This is so that cache key behaviors, which need to + # produce bound parameters in deterministic order without invoking any + # compilation here, can be applied to these constructs when they include + # values() (but not yet multi-values, which are not included in caching + # right now). + # + # in order to produce the desired "crud" style name for these parameters, + # which will also be targetable in engine/default.py through the usual + # conventions, apply our desired name to these unique parameters by + # populating the compiler truncated names cache with the desired name, + # rather than having + # compiler.visit_bindparam()->compiler._truncated_identifier make up a + # name. Saves on call counts also. + if value.unique and isinstance(value.key, elements._truncated_label): + compiler.truncated_names[("bindparam", value.key)] = name + + if value.type._isnull: + # either unique parameter, or other bound parameters that were + # passed in directly + # clone using base ClauseElement to retain unique key + value = ClauseElement._clone(value) + + # set type to that of the column unconditionally + value.type = col.type + + return value._compiler_dispatch(compiler, **kw) + + +def _key_getters_for_crud_column(compiler, stmt, compile_state): + if compile_state.isupdate and compile_state._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(stmt._extra_froms) + _et = set(compile_state._extra_froms) c_key_role = functools.partial( coercions.expect_as_key, roles.DMLColumnRole @@ -246,6 +252,7 @@ def _key_getters_for_crud_column(compiler, stmt): def _scan_insert_from_select_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -260,9 +267,9 @@ def _scan_insert_from_select_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt) + ) = _get_returning_modifiers(compiler, stmt, compile_state) - cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names] + cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] compiler._insert_from_select = stmt.select @@ -294,6 +301,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -308,11 +316,11 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt) + ) = _get_returning_modifiers(compiler, stmt, compile_state) - if stmt._parameter_ordering: + if compile_state._parameter_ordering: parameter_ordering = [ - _column_as_key(key) for key in stmt._parameter_ordering + _column_as_key(key) for key in compile_state._parameter_ordering ] ordered_keys = set(parameter_ordering) cols = [stmt.table.c[key] for key in parameter_ordering] + [ @@ -329,6 +337,7 @@ def _scan_cols( _append_param_parameter( compiler, stmt, + compile_state, c, col_key, parameters, @@ -339,7 +348,7 @@ def _scan_cols( kw, ) - elif compiler.isinsert: + elif compile_state.isinsert: if ( c.primary_key and need_pks @@ -377,7 +386,7 @@ def _scan_cols( ): _warn_pk_with_no_anticipated_value(c) - elif compiler.isupdate: + elif compile_state.isupdate: _append_param_update( compiler, stmt, c, implicit_return_defaults, values, kw ) @@ -386,6 +395,7 @@ def _scan_cols( def _append_param_parameter( compiler, stmt, + compile_state, c, col_key, parameters, @@ -395,7 +405,9 @@ def _append_param_parameter( values, kw, ): + value = parameters.pop(col_key) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -403,15 +415,21 @@ def _append_param_parameter( value, required=value is REQUIRED, name=_col_bind_name(c) - if not stmt._has_multi_parameters + if not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + **kw + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, + c, + value, + name=_col_bind_name(c) + if not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), **kw ) else: - if isinstance(value, elements.BindParameter) and value.type._isnull: - value = value._clone() - value.type = c.type - if c.primary_key and implicit_returning: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) @@ -644,6 +662,7 @@ def _append_param_update( def _get_multitable_params( compiler, stmt, + compile_state, stmt_parameters, check_columns, _col_bind_name, @@ -656,7 +675,7 @@ def _get_multitable_params( for c, param in stmt_parameters.items() ) affected_tables = set() - for t in stmt._extra_froms: + for t in compile_state._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) @@ -669,6 +688,11 @@ def _get_multitable_params( value, required=value is REQUIRED, name=_col_bind_name(c), + **kw # TODO: no test coverage for literal binds here + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, c, value, name=_col_bind_name(c), **kw ) else: compiler.postfetch.append(c) @@ -704,11 +728,11 @@ def _get_multitable_params( compiler.postfetch.append(c) -def _extend_values_for_multiparams(compiler, stmt, values, kw): +def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): values_0 = values values = [values] - for i, row in enumerate(stmt.parameters[1:]): + for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] for (col, param) in values_0: if col in row or col.key in row: @@ -757,12 +781,13 @@ def _get_stmt_parameters_params( values.append((k, v)) -def _get_returning_modifiers(compiler, stmt): +def _get_returning_modifiers(compiler, stmt, compile_state): + need_pks = ( - compiler.isinsert + compile_state.isinsert and not compiler.inline and not stmt._returning - and not stmt._has_multi_parameters + and not compile_state._has_multi_parameters ) implicit_returning = ( @@ -771,9 +796,9 @@ def _get_returning_modifiers(compiler, stmt): and stmt.table.implicit_returning ) - if compiler.isinsert: + if compile_state.isinsert: implicit_return_defaults = implicit_returning and stmt._return_defaults - elif compiler.isupdate: + elif compile_state.isupdate: implicit_return_defaults = ( compiler.dialect.implicit_returning and stmt.table.implicit_returning diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 097c513b4..171a2cc2c 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -8,25 +8,162 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ - +from sqlalchemy.types import NullType from . import coercions from . import roles from .base import _from_objects from .base import _generative +from .base import CompileState from .base import DialectKWArgs from .base import Executable -from .elements import and_ +from .base import HasCompileState from .elements import ClauseElement from .elements import Null from .selectable import HasCTE from .selectable import HasPrefixes +from .visitors import InternalTraversal from .. import exc from .. import util +from ..util import collections_abc + + +class DMLState(CompileState): + _no_parameters = True + _dict_parameters = None + _multi_parameters = None + _parameter_ordering = None + _has_multi_parameters = False + isupdate = False + isdelete = False + isinsert = False + + def __init__( + self, + statement, + compiler, + isinsert=False, + isupdate=False, + isdelete=False, + **kw + ): + self.statement = statement + + if isupdate: + self.isupdate = True + self._preserve_parameter_order = ( + statement._preserve_parameter_order + ) + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._process_multi_values(statement) + self._extra_froms = self._make_extra_froms(statement) + elif isinsert: + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + elif isdelete: + self.isdelete = True + self._extra_froms = self._make_extra_froms(statement) + else: + assert False, "one of isinsert, isupdate, or isdelete must be set" + + def _make_extra_froms(self, statement): + froms = [] + seen = {statement.table} + + for crit in statement._where_criteria: + for item in _from_objects(crit): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + return froms + + def _process_multi_values(self, statement): + if not statement._supports_multi_parameters: + raise exc.InvalidRequestError( + "%s construct does not support " + "multiple parameter sets." % statement.__visit_name__.upper() + ) + + for parameters in statement._multi_values: + multi_parameters = [ + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + for parameter_set in parameters + ] + + if self._no_parameters: + self._no_parameters = False + self._has_multi_parameters = True + self._multi_parameters = multi_parameters + self._dict_parameters = self._multi_parameters[0] + elif not self._has_multi_parameters: + self._cant_mix_formats_error() + else: + self._multi_parameters.extend(multi_parameters) + + def _process_values(self, statement): + if self._no_parameters: + self._has_multi_parameters = False + self._dict_parameters = statement._values + self._no_parameters = False + elif self._has_multi_parameters: + self._cant_mix_formats_error() + + def _process_ordered_values(self, statement): + parameters = statement._ordered_values + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = dict(parameters) + self._parameter_ordering = [key for key, value in parameters] + elif self._has_multi_parameters: + self._cant_mix_formats_error() + else: + raise exc.InvalidRequestError( + "Can only invoke ordered_values() once, and not mixed " + "with any other values() call" + ) + + def _process_select_values(self, statement): + parameters = { + coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() + for name in statement._select_names + } + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = parameters + else: + # this condition normally not reachable as the Insert + # does not allow this construction to occur + assert False, "This statement already has parameters" + + def _cant_mix_formats_error(self): + raise exc.InvalidRequestError( + "Can't mix single and multiple VALUES " + "formats in one INSERT statement; one style appends to a " + "list while the other replaces values, so the intent is " + "ambiguous." + ) class UpdateBase( roles.DMLRole, HasCTE, + HasCompileState, DialectKWArgs, HasPrefixes, Executable, @@ -42,10 +179,10 @@ class UpdateBase( {"autocommit": True} ) _hints = util.immutabledict() - _parameter_ordering = None - _prefixes = () named_with_column = False + _compile_state_cls = DMLState + @classmethod def _constructor_20_deprecations(cls, fn_name, clsname, names): @@ -112,43 +249,6 @@ class UpdateBase( col._make_proxy(fromclause) for col in self._returning ) - def _process_colparams(self, parameters, preserve_parameter_order=False): - def process_single(p): - if isinstance(p, (list, tuple)): - return dict((c.key, pval) for c, pval in zip(self.table.c, p)) - else: - return p - - if ( - preserve_parameter_order or self._preserve_parameter_order - ) and parameters is not None: - if not isinstance(parameters, list) or ( - parameters and not isinstance(parameters[0], tuple) - ): - raise ValueError( - "When preserve_parameter_order is True, " - "values() only accepts a list of 2-tuples" - ) - self._parameter_ordering = [key for key, value in parameters] - - return dict(parameters), False - - if ( - isinstance(parameters, (list, tuple)) - and parameters - and isinstance(parameters[0], (list, tuple, dict)) - ): - - if not self._supports_multi_parameters: - raise exc.InvalidRequestError( - "This construct does not support " - "multiple parameter sets." - ) - - return [process_single(p) for p in parameters], True - else: - return process_single(parameters), False - def params(self, *arg, **kw): """Set the parameters for the statement. @@ -163,6 +263,29 @@ class UpdateBase( " stmt.values(**parameters)." ) + @_generative + def with_dialect_options(self, **opt): + """Add dialect options to this INSERT/UPDATE/DELETE object. + + e.g.:: + + upd = table.update().dialect_options(mysql_limit=10) + + .. versionadded: 1.4 - this method supersedes the dialect options + associated with the constructor. + + + """ + self._validate_dialect_kwargs(opt) + + def _validate_dialect_kwargs_deprecated(self, dialect_kw): + util.warn_deprecated_20( + "Passing dialect keyword arguments directly to the " + "constructor is deprecated and will be removed in SQLAlchemy " + "2.0. Please use the ``with_dialect_options()`` method." + ) + self._validate_dialect_kwargs(dialect_kw) + def bind(self): """Return a 'bind' linked to this :class:`.UpdateBase` or a :class:`.Table` associated with it. @@ -266,9 +389,6 @@ class UpdateBase( self._hints = self._hints.union({(selectable, dialect_name): text}) - def _copy_internals(self, **kw): - raise NotImplementedError() - class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to @@ -277,16 +397,21 @@ class ValuesBase(UpdateBase): __visit_name__ = "values_base" _supports_multi_parameters = False - _has_multi_parameters = False _preserve_parameter_order = False select = None _post_values_clause = None + _values = None + _multi_values = () + _ordered_values = None + _select_names = None + + _returning = () + def __init__(self, table, values, prefixes): self.table = coercions.expect(roles.FromClauseRole, table) - self.parameters, self._has_multi_parameters = self._process_colparams( - values - ) + if values is not None: + self.values.non_generative(self, values) if prefixes: self._setup_prefixes(prefixes) @@ -416,59 +541,96 @@ class ValuesBase(UpdateBase): :func:`~.expression.update` - produce an ``UPDATE`` statement """ - if self.select is not None: + if self._select_names: raise exc.InvalidRequestError( "This construct already inserts from a SELECT" ) - if self._has_multi_parameters and kwargs: - raise exc.InvalidRequestError( - "This construct already has multiple parameter sets." + elif self._ordered_values: + raise exc.ArgumentError( + "This statement already has ordered values present" ) if args: - if len(args) > 1: + # positional case. this is currently expensive. we don't + # yet have positional-only args so we have to check the length. + # then we need to check multiparams vs. single dictionary. + # since the parameter format is needed in order to determine + # a cache key, we need to determine this up front. + arg = args[0] + + if kwargs: + raise exc.ArgumentError( + "Can't pass positional and kwargs to values() " + "simultaneously" + ) + elif len(args) > 1: raise exc.ArgumentError( "Only a single dictionary/tuple or list of " "dictionaries/tuples is accepted positionally." ) - v = args[0] - else: - v = {} - if self.parameters is None: - ( - self.parameters, - self._has_multi_parameters, - ) = self._process_colparams(v) - else: - if self._has_multi_parameters: - self.parameters = list(self.parameters) - p, self._has_multi_parameters = self._process_colparams(v) - if not self._has_multi_parameters: - raise exc.ArgumentError( - "Can't mix single-values and multiple values " - "formats in one statement" - ) + elif not self._preserve_parameter_order and isinstance( + arg, collections_abc.Sequence + ): - self.parameters.extend(p) - else: - self.parameters = self.parameters.copy() - p, self._has_multi_parameters = self._process_colparams(v) - if self._has_multi_parameters: - raise exc.ArgumentError( - "Can't mix single-values and multiple values " - "formats in one statement" - ) - self.parameters.update(p) + if arg and isinstance(arg[0], (list, dict, tuple)): + self._multi_values += (arg,) + return - if kwargs: - if self._has_multi_parameters: + # tuple values + arg = {c.key: value for c, value in zip(self.table.c, arg)} + elif self._preserve_parameter_order and not isinstance( + arg, collections_abc.Sequence + ): + raise ValueError( + "When preserve_parameter_order is True, " + "values() only accepts a list of 2-tuples" + ) + + else: + # kwarg path. this is the most common path for non-multi-params + # so this is fairly quick. + arg = kwargs + if args: raise exc.ArgumentError( - "Can't pass kwargs and multiple parameter sets " - "simultaneously" + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." ) + + # for top level values(), convert literals to anonymous bound + # parameters at statement construction time, so that these values can + # participate in the cache key process like any other ClauseElement. + # crud.py now intercepts bound parameters with unique=True from here + # and ensures they get the "crud"-style name when rendered. + + if self._preserve_parameter_order: + arg = [ + ( + k, + coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ), + ) + for k, v in arg + ] + self._ordered_values = arg + else: + arg = { + k: coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ) + for k, v in arg.items() + } + if self._values: + self._values = self._values.union(arg) else: - self.parameters.update(kwargs) + self._values = util.immutabledict(arg) @_generative def return_defaults(self, *cols): @@ -555,6 +717,25 @@ class Insert(ValuesBase): _supports_multi_parameters = True + select = None + include_insert_from_select_defaults = False + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_inline", InternalTraversal.dp_boolean), + ("_select_names", InternalTraversal.dp_string_list), + ("_values", InternalTraversal.dp_dml_values), + ("_multi_values", InternalTraversal.dp_dml_multi_values), + ("select", InternalTraversal.dp_clauseelement), + ("_post_values_clause", InternalTraversal.dp_clauseelement), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + ) + @ValuesBase._constructor_20_deprecations( "insert", "Insert", @@ -626,18 +807,13 @@ class Insert(ValuesBase): """ super(Insert, self).__init__(table, values, prefixes) self._bind = bind - self.select = self.select_names = None - self.include_insert_from_select_defaults = False self._inline = inline - self._returning = returning - self._validate_dialect_kwargs(dialect_kw) - self._return_defaults = return_defaults + if returning: + self._returning = returning + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) - def get_children(self, **kwargs): - if self.select is not None: - return (self.select,) - else: - return () + self._return_defaults = return_defaults @_generative def inline(self): @@ -702,25 +878,34 @@ class Insert(ValuesBase): :attr:`.ResultProxy.inserted_primary_key` accessor does not apply. """ - if self.parameters: + + if self._values: raise exc.InvalidRequestError( "This construct already inserts value expressions" ) - self.parameters, self._has_multi_parameters = self._process_colparams( - { - coercions.expect(roles.DMLColumnRole, n, as_key=True): Null() - for n in names - } - ) - - self.select_names = names + self._select_names = names self._inline = True self.include_insert_from_select_defaults = include_defaults self.select = coercions.expect(roles.DMLSelectRole, select) -class Update(ValuesBase): +class DMLWhereBase(object): + _where_criteria = () + + @_generative + def where(self, whereclause): + """return a new construct with the given expression added to + its WHERE clause, joined to the existing clause via AND, if any. + + """ + + self._where_criteria += ( + coercions.expect(roles.WhereHavingRole, whereclause), + ) + + +class Update(DMLWhereBase, ValuesBase): """Represent an Update construct. The :class:`.Update` object is created using the :func:`update()` @@ -730,6 +915,20 @@ class Update(ValuesBase): __visit_name__ = "update" + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_list), + ("_inline", InternalTraversal.dp_boolean), + ("_ordered_values", InternalTraversal.dp_dml_ordered_values), + ("_values", InternalTraversal.dp_dml_values), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + ) + @ValuesBase._constructor_20_deprecations( "update", "Update", @@ -874,21 +1073,14 @@ class Update(ValuesBase): self._bind = bind self._returning = returning if whereclause is not None: - self._whereclause = coercions.expect( - roles.WhereHavingRole, whereclause + self._where_criteria += ( + coercions.expect(roles.WhereHavingRole, whereclause), ) - else: - self._whereclause = None self._inline = inline - self._validate_dialect_kwargs(dialect_kw) + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) self._return_defaults = return_defaults - def get_children(self, **kwargs): - if self._whereclause is not None: - return (self._whereclause,) - else: - return () - @_generative def ordered_values(self, *args): """Specify the VALUES clause of this UPDATE statement with an explicit @@ -912,22 +1104,27 @@ class Update(ValuesBase): parameter, which will be removed in SQLAlchemy 2.0. """ - if self.select is not None: - raise exc.InvalidRequestError( - "This construct already inserts from a SELECT" - ) - - if self.parameters is None: - ( - self.parameters, - self._has_multi_parameters, - ) = self._process_colparams( - list(args), preserve_parameter_order=True - ) - else: + if self._values: raise exc.ArgumentError( "This statement already has values present" ) + elif self._ordered_values: + raise exc.ArgumentError( + "This statement already has ordered values present" + ) + arg = [ + ( + k, + coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ), + ) + for k, v in args + ] + self._ordered_values = arg @_generative def inline(self): @@ -945,37 +1142,8 @@ class Update(ValuesBase): """ self._inline = True - @_generative - def where(self, whereclause): - """return a new update() construct with the given expression added to - its WHERE clause, joined to the existing clause via AND, if any. - - """ - if self._whereclause is not None: - self._whereclause = and_( - self._whereclause, - coercions.expect(roles.WhereHavingRole, whereclause), - ) - else: - self._whereclause = coercions.expect( - roles.WhereHavingRole, whereclause - ) - - @property - def _extra_froms(self): - froms = [] - seen = {self.table} - - if self._whereclause is not None: - for item in _from_objects(self._whereclause): - if not seen.intersection(item._cloned_set): - froms.append(item) - seen.update(item._cloned_set) - - return froms - -class Delete(UpdateBase): +class Delete(DMLWhereBase, UpdateBase): """Represent a DELETE construct. The :class:`.Delete` object is created using the :func:`delete()` @@ -985,6 +1153,17 @@ class Delete(UpdateBase): __visit_name__ = "delete" + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_list), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + ) + @ValuesBase._constructor_20_deprecations( "delete", "Delete", @@ -1041,43 +1220,9 @@ class Delete(UpdateBase): self._setup_prefixes(prefixes) if whereclause is not None: - self._whereclause = coercions.expect( - roles.WhereHavingRole, whereclause - ) - else: - self._whereclause = None - - self._validate_dialect_kwargs(dialect_kw) - - def get_children(self, **kwargs): - if self._whereclause is not None: - return (self._whereclause,) - else: - return () - - @_generative - def where(self, whereclause): - """Add the given WHERE clause to a newly returned delete construct.""" - - if self._whereclause is not None: - self._whereclause = and_( - self._whereclause, + self._where_criteria += ( coercions.expect(roles.WhereHavingRole, whereclause), ) - else: - self._whereclause = coercions.expect( - roles.WhereHavingRole, whereclause - ) - - @property - def _extra_froms(self): - froms = [] - seen = {self.table} - if self._whereclause is not None: - for item in _from_objects(self._whereclause): - if not seen.intersection(item._cloned_set): - froms.append(item) - seen.update(item._cloned_set) - - return froms + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index d0babb1be..47739a37d 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -200,6 +200,7 @@ class ClauseElement( _is_from_container = False _is_select_container = False _is_select_statement = False + _is_bind_parameter = False _order_by_label_element = None @@ -1010,6 +1011,7 @@ class BindParameter(roles.InElementRole, ColumnElement): _is_crud = False _expanding_in_types = () + _is_bind_parameter = True def __init__( self, @@ -1025,6 +1027,7 @@ class BindParameter(roles.InElementRole, ColumnElement): literal_execute=False, _compared_to_operator=None, _compared_to_type=None, + _is_crud=False, ): r"""Produce a "bound expression". @@ -1303,6 +1306,8 @@ class BindParameter(roles.InElementRole, ColumnElement): self.required = required self.expanding = expanding self.literal_execute = literal_execute + if _is_crud: + self._is_crud = True if type_ is None: if _compared_to_type is not None: self.type = _compared_to_type.coerce_compared_value( @@ -4264,21 +4269,12 @@ class ColumnClause( else: return other.proxy_set.intersection(self.proxy_set) - def _get_table(self): - return self.__dict__["table"] - - def _set_table(self, table): - self._memoized_property.expire_instance(self) - self.__dict__["table"] = table - def get_children(self, column_tables=False, **kw): if column_tables and self.table is not None: return [self.table] else: return [] - table = property(_get_table, _set_table) - @_memoized_property def _from_objects(self): t = self.table diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 5445a1bce..4c627c4cc 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1413,6 +1413,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): "Column must be constructed with a non-blank name or " "assign a non-blank .name before adding to a Table." ) + + Column._memoized_property.expire_instance(self) + if self.key is None: self.key = self.name @@ -2080,24 +2083,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) -class _NotAColumnExpr(object): - # the coercions system is not used in crud.py for the values passed in - # the insert().values() and update().values() methods, so the usual - # pathways to rejecting a coercion in the unlikely case of adding defaut - # generator objects to insert() or update() constructs aren't available; - # create a quick coercion rejection here that is specific to what crud.py - # calls on value objects. - def _not_a_column_expr(self): - raise exc.InvalidRequestError( - "This %s cannot be used directly " - "as a column expression." % self.__class__.__name__ - ) - - self_group = lambda self: self._not_a_column_expr() # noqa - _from_objects = property(lambda self: self._not_a_column_expr()) - - -class DefaultGenerator(_NotAColumnExpr, SchemaItem): +class DefaultGenerator(SchemaItem): """Base class for column *default* values.""" __visit_name__ = "default_generator" @@ -2505,7 +2491,7 @@ class Sequence(roles.StatementRole, DefaultGenerator): @inspection._self_inspects -class FetchedValue(_NotAColumnExpr, SchemaEventTarget): +class FetchedValue(SchemaEventTarget): """A marker for a transparent database-side default. Use :class:`.FetchedValue` when the database is configured @@ -2528,6 +2514,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget): is_server_default = True reflected = False has_argument = False + is_clause_element = False def __init__(self, for_update=False): self.for_update = for_update diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index b972c13be..965ac6e7f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3145,8 +3145,6 @@ class Select( __visit_name__ = "select" - _prefixes = () - _suffixes = () _hints = util.immutabledict() _statement_hints = () _distinct = False diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 03ff7c439..c29a04ee0 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -200,6 +200,9 @@ class _CacheKey(ExtendedInternalTraversal): attrname, inspect(obj), parent, anon_map, bindparams ) + def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + return tuple(obj) + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): return ( attrname, @@ -336,6 +339,25 @@ class _CacheKey(ExtendedInternalTraversal): def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + def visit_dialect_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + def visit_string_clauseelement_dict( self, attrname, obj, parent, anon_map, bindparams ): @@ -366,9 +388,13 @@ class _CacheKey(ExtendedInternalTraversal): def visit_fromclause_canonical_column_collection( self, attrname, obj, parent, anon_map, bindparams ): + # inlining into the internals of ColumnCollection return ( attrname, - tuple(col._gen_cache_key(anon_map, bindparams) for col in obj), + tuple( + col._gen_cache_key(anon_map, bindparams) + for k, col in obj._collection + ), ) def visit_unknown_structure( @@ -377,6 +403,48 @@ class _CacheKey(ExtendedInternalTraversal): anon_map[NO_CACHE] = True return () + def visit_dml_ordered_values( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + + expr_values = {k for k in obj if hasattr(k, "__clause_element__")} + if expr_values: + # expr values can't be sorted deterministically right now, + # so no cache + anon_map[NO_CACHE] = True + return () + + str_values = expr_values.symmetric_difference(obj) + + return ( + attrname, + tuple( + (k, obj[k]._gen_cache_key(anon_map, bindparams)) + for k in sorted(str_values) + ), + ) + + def visit_dml_multi_values( + self, attrname, obj, parent, anon_map, bindparams + ): + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + _cache_key_traversal_visitor = _CacheKey() @@ -404,6 +472,70 @@ class _CopyInternals(InternalTraversal): (key, clone(value, **kw)) for key, value in element.items() ) + def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw): + # sequence of 2-tuples + return [ + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key, + clone(value, **kw), + ) + for key, value in element + ] + + def visit_dml_values(self, parent, element, clone=_clone, **kw): + # sequence of dictionaries + return [ + { + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key + ): clone(value, **kw) + for key, value in sub_element.items() + } + for sub_element in element + ] + + def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): + # sequence of sequences, each sequence contains a list/dict/tuple + + def copy(elem): + if isinstance(elem, (list, tuple)): + return [ + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key, + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value, + ) + for key, value in elem + ] + elif isinstance(elem, dict): + return { + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key + ): ( + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + ) + for key, value in elem + } + else: + # TODO: use abc classes + assert False + + return [ + [copy(sub_element) for sub_element in sequence] + for sequence in element + ] + _copy_internals = _CopyInternals() @@ -442,6 +574,25 @@ class _GetChildren(InternalTraversal): def visit_clauseelement_unordered_set(self, element, **kw): return tuple(element) + def visit_dml_ordered_values(self, element, **kw): + for k, v in element: + if hasattr(k, "__clause_element__"): + yield k + yield v + + def visit_dml_values(self, element, **kw): + expr_values = {k for k in element if hasattr(k, "__clause_element__")} + str_values = expr_values.symmetric_difference(element) + + for k in sorted(str_values): + yield element[k] + for k in expr_values: + yield k + yield element[k] + + def visit_dml_multi_values(self, element, **kw): + return () + _get_children = _GetChildren() @@ -644,6 +795,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): def visit_string(self, left_parent, left, right_parent, right, **kw): return left == right + def visit_string_list(self, left_parent, left, right_parent, right, **kw): + return left == right + def visit_anon_name(self, left_parent, left, right_parent, right, **kw): return _resolve_name_for_compare( left_parent, left, self.anon_map[0], **kw @@ -663,6 +817,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): return left == right + def visit_dialect_options( + self, left_parent, left, right_parent, right, **kw + ): + return left == right + def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): return left == right @@ -713,6 +872,55 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): ): raise NotImplementedError() + def visit_dml_ordered_values( + self, left_parent, left, right_parent, right, **kw + ): + # sequence of tuple pairs + + for (lk, lv), (rk, rv) in util.zip_longest( + left, right, fillvalue=(None, None) + ): + lkce = hasattr(lk, "__clause_element__") + rkce = hasattr(rk, "__clause_element__") + if lkce != rkce: + return COMPARE_FAILED + elif lkce and not self.compare_inner(lk, rk, **kw): + return COMPARE_FAILED + elif not lkce and lk != rk: + return COMPARE_FAILED + elif not self.compare_inner(lv, rv, **kw): + return COMPARE_FAILED + + def visit_dml_values(self, left_parent, left, right_parent, right, **kw): + if left is None or right is None or len(left) != len(right): + return COMPARE_FAILED + + for lk in left: + lv = left[lk] + + if lk not in right: + return COMPARE_FAILED + rv = right[lk] + + if not self.compare_inner(lv, rv, **kw): + return COMPARE_FAILED + + def visit_dml_multi_values( + self, left_parent, left, right_parent, right, **kw + ): + for lseq, rseq in util.zip_longest(left, right, fillvalue=None): + if lseq is None or rseq is None: + return COMPARE_FAILED + + for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None): + if ( + self.visit_dml_values( + left_parent, ld, right_parent, rd, **kw + ) + is COMPARE_FAILED + ): + return COMPARE_FAILED + def compare_clauselist(self, left, right, **kw): if left.operator is right.operator: if operators.is_associative(left.operator): @@ -731,11 +939,11 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): if left.operator == right.operator: if operators.is_commutative(left.operator): if ( - compare(left.left, right.left, **kw) - and compare(left.right, right.right, **kw) + self.compare_inner(left.left, right.left, **kw) + and self.compare_inner(left.right, right.right, **kw) ) or ( - compare(left.left, right.right, **kw) - and compare(left.right, right.left, **kw) + self.compare_inner(left.left, right.right, **kw) + and self.compare_inner(left.right, right.left, **kw) ): return ["operator", "negate", "left", "right"] else: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index fda48c657..a049d9bb0 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -269,6 +269,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_string_list = symbol("SL") + """Visit a list of strings.""" + dp_anon_name = symbol("AN") """Visit a potentially "anonymized" string value. @@ -313,6 +316,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_dialect_options = symbol("DO") + """visit a dialect options structure.""" + dp_string_clauseelement_dict = symbol("CD") """Visit a dictionary of string keys to :class:`.ClauseElement` objects. @@ -365,6 +371,21 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_dml_ordered_values = symbol("DML_OV") + """visit the values() ordered tuple list of an :class:`.Update` object.""" + + dp_dml_values = symbol("DML_V") + """visit the values() dictionary of a :class:`.ValuesBase + (e.g. Insert or Update) object. + + """ + + dp_dml_multi_values = symbol("DML_MV") + """visit the values() multi-valued list of dictionaries of an + :class:`.Insert` object. + + """ + class ExtendedInternalTraversal(InternalTraversal): """defines additional symbols that are useful in caching applications. |