summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-10-03 17:36:27 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-10-07 23:06:06 -0400
commit65aee6cce57fd1cca3a95814feff3ed99a5a51ee (patch)
tree0352d74938902a9242dfb97ca5215d9191a2ad16 /lib/sqlalchemy/sql/compiler.py
parentebd9788c986c56b8b845fa83609a6eb2c0cef083 (diff)
downloadsqlalchemy-65aee6cce57fd1cca3a95814feff3ed99a5a51ee.tar.gz
Add result map targeting for custom compiled, text objects
In order for text(), custom compiled objects, etc. to be usable by Query(), they are all targeted by object key in the result map. As we no longer want Query to implicitly label these, as well as that text() has no label feature, support adding entries to the result map that have no name, key, or type, only the object itself, and then ensure that the compiler sets up for positional targeting when this condition is detected. Allows for more flexible ORM query usage with custom expressions and text() while having less special logic in query itself. Fixes: #4887 Change-Id: Ie073da127d292d43cb132a2b31bc90af88bfe2fd
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
1 files changed, 52 insertions, 20 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 320c7b782..453ff56d2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -871,12 +871,11 @@ class SQLCompiler(Compiled):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
- add_to_result_map(
- name,
- orig_name,
- (column, name, column.key, column._label) + result_map_targets,
- column.type,
- )
+ targets = (column, name, column.key) + result_map_targets
+ if column._label:
+ targets += (column._label,)
+
+ add_to_result_map(name, orig_name, targets, column.type)
if is_literal:
# note we are not currently accommodating for
@@ -925,7 +924,7 @@ class SQLCompiler(Compiled):
text = text.replace("%", "%%")
return text
- def visit_textclause(self, textclause, **kw):
+ def visit_textclause(self, textclause, add_to_result_map=None, **kw):
def do_bindparam(m):
name = m.group(1)
if name in textclause._bindparams:
@@ -936,6 +935,12 @@ class SQLCompiler(Compiled):
if not self.stack:
self.isplaintext = True
+ if add_to_result_map:
+ # text() object is present in the columns clause of a
+ # select(). Add a no-name entry to the result map so that
+ # row[text()] produces a result
+ add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
# un-escape any \:params
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
@@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled):
return " AS " + alias_name_text
def _add_to_result_map(self, keyname, name, objects, type_):
+ if keyname is None:
+ self._ordered_columns = False
+ self._textual_ordered_columns = True
self._result_columns.append((keyname, name, objects, type_))
def _label_select_column(
@@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled):
column_clause_args,
name=None,
within_columns_clause=True,
+ column_is_repeated=False,
need_column_expressions=False,
):
"""produce labeled columns present in a select()."""
@@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled):
need_column_expressions or populate_result_map
):
col_expr = impl.column_expression(column)
+ else:
+ col_expr = column
- if populate_result_map:
+ if populate_result_map:
+ # pass an "add_to_result_map" callable into the compilation
+ # of embedded columns. this collects information about the
+ # column as it will be fetched in the result and is coordinated
+ # with cursor.description when the query is executed.
+ add_to_result_map = self._add_to_result_map
+
+ # if the SELECT statement told us this column is a repeat,
+ # wrap the callable with one that prevents the addition of the
+ # targets
+ if column_is_repeated:
+ _add_to_result_map = add_to_result_map
def add_to_result_map(keyname, name, objects, type_):
- self._add_to_result_map(
+ _add_to_result_map(keyname, name, (), type_)
+
+ # if we redefined col_expr for type expressions, wrap the
+ # callable with one that adds the original column to the targets
+ elif col_expr is not column:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(
keyname, name, (column,) + objects, type_
)
- else:
- add_to_result_map = None
else:
- col_expr = column
- if populate_result_map:
- add_to_result_map = self._add_to_result_map
- else:
- add_to_result_map = None
+ add_to_result_map = None
if not within_columns_clause:
result_expr = col_expr
@@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled):
)
and (
not hasattr(column, "name")
- or isinstance(column, functions.Function)
+ or isinstance(column, functions.FunctionElement)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
@@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled):
asfrom,
column_clause_args,
name=name,
+ column_is_repeated=repeated,
need_column_expressions=need_column_expressions,
)
- for name, column in select._columns_plus_names
+ for name, column, repeated in select._columns_plus_names
]
if c is not None
]
@@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
- [name for (key, name) in select._columns_plus_names],
[
name
- for (key, name) in select_wraps_for._columns_plus_names
+ for (key, name, repeated) in select._columns_plus_names
+ ],
+ [
+ name
+ for (
+ key,
+ name,
+ repeated,
+ ) in select_wraps_for._columns_plus_names
],
)
)