summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py350
1 files changed, 328 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 1d13ffa9a..201324a2a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -94,6 +94,7 @@ if typing.TYPE_CHECKING:
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .elements import Label
from .functions import Function
from .selectable import AliasedReturnsRows
@@ -390,6 +391,19 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
class ExpandedState(NamedTuple):
+ """represents state to use when producing "expanded" and
+ "post compile" bound parameters for a statement.
+
+ "expanded" parameters are parameters that are generated at
+ statement execution time to suit a number of parameters passed, the most
+ prominent example being the individual elements inside of an IN expression.
+
+ "post compile" parameters are parameters where the SQL literal value
+ will be rendered into the SQL statement at execution time, rather than
+ being passed as separate parameters to the driver.
+
+ """
+
statement: str
additional_parameters: _CoreSingleExecuteParams
processors: Mapping[str, _BindProcessorType[Any]]
@@ -397,7 +411,23 @@ class ExpandedState(NamedTuple):
parameter_expansion: Mapping[str, List[str]]
+class _InsertManyValues(NamedTuple):
+ """represents state to use for executing an "insertmanyvalues" statement"""
+
+ is_default_expr: bool
+ single_values_expr: str
+ insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]]
+ num_positional_params_counted: int
+
+
class Linting(IntEnum):
+ """represent preferences for the 'SQL linting' feature.
+
+ this feature currently includes support for flagging cartesian products
+ in SQL statements.
+
+ """
+
NO_LINTING = 0
"Disable all linting."
@@ -419,6 +449,9 @@ NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+ """represents current state for the "cartesian product" detection
+ feature."""
+
def lint(self, start=None):
froms = self.froms
if not froms:
@@ -762,8 +795,6 @@ class SQLCompiler(Compiled):
is_sql = True
- _result_columns: List[ResultColumnsEntry]
-
compound_keywords = COMPOUND_KEYWORDS
isdelete: bool = False
@@ -810,12 +841,6 @@ class SQLCompiler(Compiled):
"""major statements such as SELECT, INSERT, UPDATE, DELETE are
tracked in this stack using an entry format."""
- result_columns: List[ResultColumnsEntry]
- """relates label names in the final SQL to a tuple of local
- column/label name, ColumnElement object (if any) and
- TypeEngine. CursorResult uses this for type processing and
- column targeting"""
-
returning_precedes_values: bool = False
"""set to True classwide to generate RETURNING
clauses before the VALUES or WHERE clause (i.e. MSSQL)
@@ -835,6 +860,12 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
+ _result_columns: List[ResultColumnsEntry]
+ """relates label names in the final SQL to a tuple of local
+ column/label name, ColumnElement object (if any) and
+ TypeEngine. CursorResult uses this for type processing and
+ column targeting"""
+
_textual_ordered_columns: bool = False
"""tell the result object that the column names as rendered are important,
but they are also "ordered" vs. what is in the compiled object here.
@@ -881,14 +912,9 @@ class SQLCompiler(Compiled):
"""
- insert_single_values_expr: Optional[str] = None
- """When an INSERT is compiled with a single set of parameters inside
- a VALUES expression, the string is assigned here, where it can be
- used for insert batching schemes to rewrite the VALUES expression.
+ _insertmanyvalues: Optional[_InsertManyValues] = None
- .. versionadded:: 1.3.8
-
- """
+ _insert_crud_params: Optional[crud._CrudParamSequence] = None
literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
@@ -1072,6 +1098,25 @@ class SQLCompiler(Compiled):
if self._render_postcompile:
self._process_parameters_for_postcompile(_populate_self=True)
+ @property
+ def insert_single_values_expr(self) -> Optional[str]:
+ """When an INSERT is compiled with a single set of parameters inside
+ a VALUES expression, the string is assigned here, where it can be
+ used for insert batching schemes to rewrite the VALUES expression.
+
+ .. versionadded:: 1.3.8
+
+ .. versionchanged:: 2.0 This collection is no longer used by
+ SQLAlchemy's built-in dialects, in favor of the currently
+ internal ``_insertmanyvalues`` collection that is used only by
+ :class:`.SQLCompiler`.
+
+ """
+ if self._insertmanyvalues is None:
+ return None
+ else:
+ return self._insertmanyvalues.single_values_expr
+
@util.ro_memoized_property
def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
"""The effective "returning" columns for INSERT, UPDATE or DELETE.
@@ -1620,10 +1665,13 @@ class SQLCompiler(Compiled):
param_key_getter = self._within_exec_param_key_getter
+ assert self.compile_state is not None
+ statement = self.compile_state.statement
+
if TYPE_CHECKING:
- assert isinstance(self.statement, Insert)
+ assert isinstance(statement, Insert)
- table = self.statement.table
+ table = statement.table
getters = [
(operator.methodcaller("get", param_key_getter(col), None), col)
@@ -1697,11 +1745,14 @@ class SQLCompiler(Compiled):
else:
result = util.preloaded.engine_result
+ assert self.compile_state is not None
+ statement = self.compile_state.statement
+
if TYPE_CHECKING:
- assert isinstance(self.statement, Insert)
+ assert isinstance(statement, Insert)
param_key_getter = self._within_exec_param_key_getter
- table = self.statement.table
+ table = statement.table
returning = self.implicit_returning
assert returning is not None
@@ -4506,7 +4557,202 @@ class SQLCompiler(Compiled):
)
return dialect_hints, table_text
+ def _insert_stmt_should_use_insertmanyvalues(self, statement):
+ return (
+ self.dialect.supports_multivalues_insert
+ and self.dialect.use_insertmanyvalues
+ # note self.implicit_returning or self._result_columns
+ # implies self.dialect.insert_returning capability
+ and (
+ self.dialect.use_insertmanyvalues_wo_returning
+ or self.implicit_returning
+ or self._result_columns
+ )
+ )
+
+ def _deliver_insertmanyvalues_batches(
+ self, statement, parameters, generic_setinputsizes, batch_size
+ ):
+ imv = self._insertmanyvalues
+ assert imv is not None
+
+ executemany_values = f"({imv.single_values_expr})"
+
+ lenparams = len(parameters)
+ if imv.is_default_expr and not self.dialect.supports_default_metavalue:
+ # backend doesn't support
+ # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
+ # at the moment this is basically SQL Server due to
+ # not being able to use DEFAULT for identity column
+ # just yield out that many single statements! still
+ # faster than a whole connection.execute() call ;)
+ #
+ # note we still are taking advantage of the fact that we know
+ # we are using RETURNING. The generalized approach of fetching
+ # cursor.lastrowid etc. still goes through the more heavyweight
+ # "ExecutionContext per statement" system as it isn't usable
+ # as a generic "RETURNING" approach
+ for batchnum, param in enumerate(parameters, 1):
+ yield (
+ statement,
+ param,
+ generic_setinputsizes,
+ batchnum,
+ lenparams,
+ )
+ return
+ else:
+ statement = statement.replace(
+ executemany_values, "__EXECMANY_TOKEN__"
+ )
+
+ # Use optional insertmanyvalues_max_parameters
+ # to further shrink the batch size so that there are no more than
+ # insertmanyvalues_max_parameters params.
+ # Currently used by SQL Server, which limits statements to 2100 bound
+ # parameters (actually 2099).
+ max_params = self.dialect.insertmanyvalues_max_parameters
+ if max_params:
+ total_num_of_params = len(self.bind_names)
+ num_params_per_batch = len(imv.insert_crud_params)
+ num_params_outside_of_batch = (
+ total_num_of_params - num_params_per_batch
+ )
+ batch_size = min(
+ batch_size,
+ (
+ (max_params - num_params_outside_of_batch)
+ // num_params_per_batch
+ ),
+ )
+
+ batches = list(parameters)
+
+ processed_setinputsizes = None
+ batchnum = 1
+ total_batches = lenparams // batch_size + (
+ 1 if lenparams % batch_size else 0
+ )
+
+ insert_crud_params = imv.insert_crud_params
+ assert insert_crud_params is not None
+
+ escaped_bind_names: Mapping[str, str]
+ if not self.positional:
+ if self.escaped_bind_names:
+ escaped_bind_names = self.escaped_bind_names
+ else:
+ escaped_bind_names = {}
+
+ all_keys = set(parameters[0])
+
+ escaped_insert_crud_params: Sequence[Any] = [
+ (escaped_bind_names.get(col.key, col.key), formatted)
+ for col, _, formatted in insert_crud_params
+ ]
+
+ keys_to_replace = all_keys.intersection(
+ key for key, _ in escaped_insert_crud_params
+ )
+ base_parameters = {
+ key: parameters[0][key]
+ for key in all_keys.difference(keys_to_replace)
+ }
+ executemany_values_w_comma = ""
+ else:
+ escaped_insert_crud_params = ()
+ keys_to_replace = set()
+ base_parameters = {}
+ executemany_values_w_comma = f"({imv.single_values_expr}), "
+
+ while batches:
+ batch = batches[0:batch_size]
+ batches[0:batch_size] = []
+
+ if generic_setinputsizes:
+ # if setinputsizes is present, expand this collection to
+ # suit the batch length as well
+ # currently this will be mssql+pyodbc for internal dialects
+ processed_setinputsizes = [
+ (new_key, len_, typ)
+ for new_key, len_, typ in (
+ (f"{key}_{index}", len_, typ)
+ for index in range(len(batch))
+ for key, len_, typ in generic_setinputsizes
+ )
+ ]
+
+ replaced_parameters: Any
+ if self.positional:
+ # the assumption here is that any parameters that are not
+ # in the VALUES clause are expected to be parameterized
+ # expressions in the RETURNING (or maybe ON CONFLICT) clause.
+ # So based on
+ # which sequence comes first in the compiler's INSERT
+ # statement tells us where to expand the parameters.
+
+ # otherwise we probably shouldn't be doing insertmanyvalues
+ # on the statement.
+
+ num_ins_params = imv.num_positional_params_counted
+
+ if num_ins_params == len(batch[0]):
+ extra_params = ()
+ batch_iterator: Iterable[Tuple[Any, ...]] = batch
+ elif self.returning_precedes_values:
+ extra_params = batch[0][:-num_ins_params]
+ batch_iterator = (b[-num_ins_params:] for b in batch)
+ else:
+ extra_params = batch[0][num_ins_params:]
+ batch_iterator = (b[:num_ins_params] for b in batch)
+
+ replaced_statement = statement.replace(
+ "__EXECMANY_TOKEN__",
+ (executemany_values_w_comma * len(batch))[:-2],
+ )
+
+ replaced_parameters = tuple(
+ itertools.chain.from_iterable(batch_iterator)
+ )
+ if self.returning_precedes_values:
+ replaced_parameters = extra_params + replaced_parameters
+ else:
+ replaced_parameters = replaced_parameters + extra_params
+ else:
+ replaced_values_clauses = []
+ replaced_parameters = base_parameters.copy()
+
+ for i, param in enumerate(batch):
+ new_tokens = [
+ formatted.replace(key, f"{key}__{i}")
+ if key in param
+ else formatted
+ for key, formatted in escaped_insert_crud_params
+ ]
+ replaced_values_clauses.append(
+ f"({', '.join(new_tokens)})"
+ )
+
+ replaced_parameters.update(
+ {f"{key}__{i}": param[key] for key in keys_to_replace}
+ )
+
+ replaced_statement = statement.replace(
+ "__EXECMANY_TOKEN__",
+ ", ".join(replaced_values_clauses),
+ )
+
+ yield (
+ replaced_statement,
+ replaced_parameters,
+ processed_setinputsizes,
+ batchnum,
+ total_batches,
+ )
+ batchnum += 1
+
def visit_insert(self, insert_stmt, **kw):
+
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
)
@@ -4529,9 +4775,24 @@ class SQLCompiler(Compiled):
}
)
+ positiontup_before = positiontup_after = 0
+
+ # for positional, insertmanyvalues needs to know how many
+ # bound parameters are in the VALUES sequence; there's no simple
+ # rule because default expressions etc. can have zero or more
+ # params inside them. After multiple attempts to figure this out,
+ # this very simplistic "count before, then count after" works and is
+ # likely the least amount of callcounts, though looks clumsy
+ if self.positiontup:
+ positiontup_before = len(self.positiontup)
+
crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, toplevel, **kw
)
+
+ if self.positiontup:
+ positiontup_after = len(self.positiontup)
+
crud_params_single = crud_params_struct.single_params
if (
@@ -4584,14 +4845,34 @@ class SQLCompiler(Compiled):
)
if self.implicit_returning or insert_stmt._returning:
+
+ # if returning clause is rendered first, capture bound parameters
+ # while visiting and place them prior to the VALUES oriented
+ # bound parameters, when using positional parameter scheme
+ rpv = self.returning_precedes_values
+ flip_pt = rpv and self.positional
+ if flip_pt:
+ pt: Optional[List[str]] = self.positiontup
+ temp_pt: Optional[List[str]]
+ self.positiontup = temp_pt = []
+ else:
+ temp_pt = pt = None
+
returning_clause = self.returning_clause(
insert_stmt,
self.implicit_returning or insert_stmt._returning,
populate_result_map=toplevel,
)
- if self.returning_precedes_values:
+ if flip_pt:
+ if TYPE_CHECKING:
+ assert temp_pt is not None
+ assert pt is not None
+ self.positiontup = temp_pt + pt
+
+ if rpv:
text += " " + returning_clause
+
else:
returning_clause = None
@@ -4614,6 +4895,18 @@ class SQLCompiler(Compiled):
text += " %s" % select_text
elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
+ if toplevel and self._insert_stmt_should_use_insertmanyvalues(
+ insert_stmt
+ ):
+ self._insertmanyvalues = _InsertManyValues(
+ True,
+ self.dialect.default_metavalue_token,
+ cast(
+ "List[Tuple[KeyedColumnElement[Any], str, str]]",
+ crud_params_single,
+ ),
+ (positiontup_after - positiontup_before),
+ )
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
@@ -4623,6 +4916,8 @@ class SQLCompiler(Compiled):
)
)
else:
+ # TODO: why is third element of crud_params_single not str
+ # already?
insert_single_values_expr = ", ".join(
[
value
@@ -4631,9 +4926,20 @@ class SQLCompiler(Compiled):
)
]
)
+
text += " VALUES (%s)" % insert_single_values_expr
- if toplevel:
- self.insert_single_values_expr = insert_single_values_expr
+ if toplevel and self._insert_stmt_should_use_insertmanyvalues(
+ insert_stmt
+ ):
+ self._insertmanyvalues = _InsertManyValues(
+ False,
+ insert_single_values_expr,
+ cast(
+ "List[Tuple[KeyedColumnElement[Any], str, str]]",
+ crud_params_single,
+ ),
+ positiontup_after - positiontup_before,
+ )
if insert_stmt._post_values_clause is not None:
post_values_clause = self.process(