summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-03-24 21:30:52 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-03-24 21:30:52 +0000
commit221aff778e1eb3c3aa8f8a1f72629177442694bc (patch)
tree5f668dd290cce756f4bfc1e60a0ec3c59b0951c8 /lib/sqlalchemy/sql/compiler.py
parent1c1c925fe3a77581f4879f6b6fe0bb6b6158cc3d (diff)
parent6f02d5edd88fe2475629438b0730181a2b00c5fe (diff)
downloadsqlalchemy-221aff778e1eb3c3aa8f8a1f72629177442694bc.tar.gz
Merge "pep484 - SQL internals" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py59
1 files changed, 38 insertions, 21 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d3e91a8d5..176e3637e 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -71,6 +71,7 @@ from .schema import Column
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import prefix_anon_map
+from .visitors import Visitable
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -614,10 +615,10 @@ class Compiled:
raise NotImplementedError()
- def process(self, obj, **kwargs):
+ def process(self, obj: Visitable, **kwargs: Any) -> str:
return obj._compiler_dispatch(self, **kwargs)
- def __str__(self):
+ def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
return self.string or ""
@@ -723,7 +724,7 @@ class SQLCompiler(Compiled):
"""list of columns for which onupdate default values should be evaluated
before an UPDATE takes place"""
- returning: Optional[List[Column[Any]]]
+ returning: Optional[List[ColumnClause[Any]]]
"""list of columns that will be delivered to cursor.description or
dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
@@ -1485,15 +1486,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
- _key_getters_for_crud_column: Tuple[
- Callable[[Union[str, Column[Any]]], str],
- Callable[[Column[Any]], str],
- Callable[[Column[Any]], str],
- ]
+ # assigned by crud.py for insert/update statements
+ _get_bind_name_for_col: _BindNameForColProtocol
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
- getter = self._key_getters_for_crud_column[2]
+ getter = self._get_bind_name_for_col
if self.escaped_bind_names:
def _get(obj):
@@ -4100,7 +4098,9 @@ class SQLCompiler(Compiled):
def for_update_clause(self, select, **kw):
return " FOR UPDATE"
- def returning_clause(self, stmt, returning_cols):
+ def returning_clause(
+ self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
+ ) -> str:
raise exc.CompileError(
"RETURNING is not supported by this "
"dialect's statement compiler."
@@ -4245,12 +4245,13 @@ class SQLCompiler(Compiled):
}
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, **kw
)
+ crud_params_single = crud_params_struct.single_params
if (
- not crud_params
+ not crud_params_single
and not self.dialect.supports_default_values
and not self.dialect.supports_default_metavalue
and not self.dialect.supports_empty_insert
@@ -4268,9 +4269,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." % self.dialect.name
)
- crud_params_single = crud_params[0]
+ crud_params_single = crud_params_struct.single_params
else:
- crud_params_single = crud_params
+ crud_params_single = crud_params_struct.single_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -4295,7 +4296,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [expr for c, expr, value in crud_params_single]
+ [expr for _, expr, _ in crud_params_single]
)
if self.returning or insert_stmt._returning:
@@ -4325,19 +4326,24 @@ class SQLCompiler(Compiled):
)
else:
text += " %s" % select_text
- elif not crud_params and supports_default_values:
+ elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)"
- % (", ".join(value for c, expr, value in crud_param_set))
- for crud_param_set in crud_params
+ % (", ".join(value for _, _, value in crud_param_set))
+ for crud_param_set in crud_params_struct.all_multi_params
)
)
else:
insert_single_values_expr = ", ".join(
- [value for c, expr, value in crud_params]
+ [
+ value
+ for _, _, value in cast(
+ "List[Tuple[Any, Any, str]]", crud_params_single
+ )
+ ]
)
text += " VALUES (%s)" % insert_single_values_expr
if toplevel and insert_stmt._post_values_clause is None:
@@ -4445,9 +4451,10 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, **kw
)
+ crud_params = crud_params_struct.single_params
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -4462,7 +4469,12 @@ class SQLCompiler(Compiled):
text += table_text
text += " SET "
- text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+ text += ", ".join(
+ expr + "=" + value
+ for _, expr, value in cast(
+ "List[Tuple[Any, str, str]]", crud_params
+ )
+ )
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
@@ -5448,6 +5460,11 @@ class _SchemaForObjectCallable(Protocol):
...
+class _BindNameForColProtocol(Protocol):
+ def __call__(self, col: ColumnClause[Any]) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""