summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
1 files changed, 52 insertions, 20 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 320c7b782..453ff56d2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -871,12 +871,11 @@ class SQLCompiler(Compiled):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
- add_to_result_map(
- name,
- orig_name,
- (column, name, column.key, column._label) + result_map_targets,
- column.type,
- )
+ targets = (column, name, column.key) + result_map_targets
+ if column._label:
+ targets += (column._label,)
+
+ add_to_result_map(name, orig_name, targets, column.type)
if is_literal:
# note we are not currently accommodating for
@@ -925,7 +924,7 @@ class SQLCompiler(Compiled):
text = text.replace("%", "%%")
return text
- def visit_textclause(self, textclause, **kw):
+ def visit_textclause(self, textclause, add_to_result_map=None, **kw):
def do_bindparam(m):
name = m.group(1)
if name in textclause._bindparams:
@@ -936,6 +935,12 @@ class SQLCompiler(Compiled):
if not self.stack:
self.isplaintext = True
+ if add_to_result_map:
+ # text() object is present in the columns clause of a
+ # select(). Add a no-name entry to the result map so that
+ # row[text()] produces a result
+ add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
# un-escape any \:params
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
@@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled):
return " AS " + alias_name_text
def _add_to_result_map(self, keyname, name, objects, type_):
+ if keyname is None:
+ self._ordered_columns = False
+ self._textual_ordered_columns = True
self._result_columns.append((keyname, name, objects, type_))
def _label_select_column(
@@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled):
column_clause_args,
name=None,
within_columns_clause=True,
+ column_is_repeated=False,
need_column_expressions=False,
):
"""produce labeled columns present in a select()."""
@@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled):
need_column_expressions or populate_result_map
):
col_expr = impl.column_expression(column)
+ else:
+ col_expr = column
- if populate_result_map:
+ if populate_result_map:
+ # pass an "add_to_result_map" callable into the compilation
+ # of embedded columns. this collects information about the
+ # column as it will be fetched in the result and is coordinated
+ # with cursor.description when the query is executed.
+ add_to_result_map = self._add_to_result_map
+
+ # if the SELECT statement told us this column is a repeat,
+ # wrap the callable with one that prevents the addition of the
+ # targets
+ if column_is_repeated:
+ _add_to_result_map = add_to_result_map
def add_to_result_map(keyname, name, objects, type_):
- self._add_to_result_map(
+ _add_to_result_map(keyname, name, (), type_)
+
+ # if we redefined col_expr for type expressions, wrap the
+ # callable with one that adds the original column to the targets
+ elif col_expr is not column:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(
keyname, name, (column,) + objects, type_
)
- else:
- add_to_result_map = None
else:
- col_expr = column
- if populate_result_map:
- add_to_result_map = self._add_to_result_map
- else:
- add_to_result_map = None
+ add_to_result_map = None
if not within_columns_clause:
result_expr = col_expr
@@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled):
)
and (
not hasattr(column, "name")
- or isinstance(column, functions.Function)
+ or isinstance(column, functions.FunctionElement)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
@@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled):
asfrom,
column_clause_args,
name=name,
+ column_is_repeated=repeated,
need_column_expressions=need_column_expressions,
)
- for name, column in select._columns_plus_names
+ for name, column, repeated in select._columns_plus_names
]
if c is not None
]
@@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
- [name for (key, name) in select._columns_plus_names],
[
name
- for (key, name) in select_wraps_for._columns_plus_names
+ for (key, name, repeated) in select._columns_plus_names
+ ],
+ [
+ name
+ for (
+ key,
+ name,
+ repeated,
+ ) in select_wraps_for._columns_plus_names
],
)
)