diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 |
1 files changed, 52 insertions, 20 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 320c7b782..453ff56d2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -871,12 +871,11 @@ class SQLCompiler(Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map( - name, - orig_name, - (column, name, column.key, column._label) + result_map_targets, - column.type, - ) + targets = (column, name, column.key) + result_map_targets + if column._label: + targets += (column._label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: # note we are not currently accommodating for @@ -925,7 +924,7 @@ class SQLCompiler(Compiled): text = text.replace("%", "%%") return text - def visit_textclause(self, textclause, **kw): + def visit_textclause(self, textclause, add_to_result_map=None, **kw): def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -936,6 +935,12 @@ class SQLCompiler(Compiled): if not self.stack: self.isplaintext = True + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None: + self._ordered_columns = False + self._textual_ordered_columns = True self._result_columns.append((keyname, name, objects, type_)) def _label_select_column( @@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + column_is_repeated=False, need_column_expressions=False, ): """produce labeled columns present in a select().""" @@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled): need_column_expressions or populate_result_map ): col_expr = impl.column_expression(column) + else: + col_expr = column - if populate_result_map: + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( keyname, name, (column,) + objects, type_ ) - else: - add_to_result_map = None else: - col_expr = column - if populate_result_map: - add_to_result_map = self._add_to_result_map - else: - add_to_result_map = None + add_to_result_map = None if not within_columns_clause: result_expr = col_expr @@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled): ) and ( not hasattr(column, "name") - or isinstance(column, functions.Function) + or isinstance(column, functions.FunctionElement) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) @@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column in select._columns_plus_names + for name, column, repeated in select._columns_plus_names ] if c is not None ] @@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled): translate = dict( zip( - [name for (key, name) in select._columns_plus_names], [ name - for (key, name) in select_wraps_for._columns_plus_names + for (key, name, repeated) in select._columns_plus_names + ], + [ + name + for ( + key, + name, + repeated, + ) in select_wraps_for._columns_plus_names ], ) ) |