diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 66 |
1 files changed, 44 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8a8c773f8..e0cdbe24c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -808,6 +808,7 @@ class SQLCompiler(engine.Compiled): return text def visit_alias(self, alias, asfrom=False, ashint=False, + iscrud=False, fromhints=None, **kwargs): if asfrom or ashint: if isinstance(alias.name, sql._truncated_label): @@ -824,9 +825,8 @@ class SQLCompiler(engine.Compiled): self.preparer.format_alias(alias, alias_name) if fromhints and alias in fromhints: - hinttext = self.get_from_hint_text(alias, fromhints[alias]) - if hinttext: - ret += " " + hinttext + ret = self.format_from_hint_text(ret, alias, + fromhints[alias], iscrud) return ret else: @@ -863,6 +863,12 @@ class SQLCompiler(engine.Compiled): else: return column + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext + def get_select_hint_text(self, byfroms): return None @@ -1025,7 +1031,7 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + self.process(sql.literal(select._offset)) return text - def visit_table(self, table, asfrom=False, ashint=False, + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): @@ -1036,9 +1042,8 @@ class SQLCompiler(engine.Compiled): else: ret = self.preparer.quote(table.name, table.quote) if fromhints and table in fromhints: - hinttext = self.get_from_hint_text(table, fromhints[table]) - if hinttext: - ret += " " + hinttext + ret = self.format_from_hint_text(ret, table, + fromhints[table], iscrud) return ret else: return "" @@ -1073,7 +1078,8 @@ class SQLCompiler(engine.Compiled): if prefixes: text += " " + " ".join(prefixes) - text += " INTO " + preparer.format_table(insert_stmt.table) + text += " INTO " + table_text = preparer.format_table(insert_stmt.table) if insert_stmt._hints: dialect_hints = dict([ @@ -1083,11 +1089,15 @@ class SQLCompiler(engine.Compiled): if dialect in ('*', self.dialect.name) ]) if insert_stmt.table in dialect_hints: - text += " " + self.get_crud_hint_text( + table_text = self.format_from_hint_text( + table_text, insert_stmt.table, - dialect_hints[insert_stmt.table] + dialect_hints[insert_stmt.table], + True ) + text += table_text + if colparams or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) @@ -1123,7 +1133,8 @@ class SQLCompiler(engine.Compiled): MySQL overrides this. """ - return self.preparer.format_table(from_table) + return from_table._compiler_dispatch(self, asfrom=True, + iscrud=True, **kw) def update_from_clause(self, update_stmt, from_table, extra_froms, @@ -1149,10 +1160,9 @@ class SQLCompiler(engine.Compiled): colparams = self._get_colparams(update_stmt, extra_froms) - text = "UPDATE " + self.update_tables_clause( - update_stmt, - update_stmt.table, - extra_froms, **kw) + text = "UPDATE " + table_text = self.update_tables_clause(update_stmt, update_stmt.table, + extra_froms, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1162,13 +1172,17 @@ class SQLCompiler(engine.Compiled): if dialect in ('*', self.dialect.name) ]) if update_stmt.table in dialect_hints: - text += " " + self.get_crud_hint_text( + table_text = self.format_from_hint_text( + table_text, update_stmt.table, - dialect_hints[update_stmt.table] + dialect_hints[update_stmt.table], + True ) else: dialect_hints = None + text += table_text + text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: text += ', '.join( @@ -1430,7 +1444,9 @@ class SQLCompiler(engine.Compiled): self.stack.append({'from': set([delete_stmt.table])}) self.isdelete = True - text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + text = "DELETE FROM " + table_text = delete_stmt.table._compiler_dispatch(self, + asfrom=True, iscrud=True) if delete_stmt._hints: dialect_hints = dict([ @@ -1440,13 +1456,18 @@ class SQLCompiler(engine.Compiled): if dialect in ('*', self.dialect.name) ]) if delete_stmt.table in dialect_hints: - text += " " + self.get_crud_hint_text( - delete_stmt.table, - dialect_hints[delete_stmt.table] + table_text = self.format_from_hint_text( + table_text, + delete_stmt.table, + dialect_hints[delete_stmt.table], + True ) + else: dialect_hints = None + text += table_text + if delete_stmt._returning: self.returning = delete_stmt._returning if self.returning_precedes_values: @@ -1454,7 +1475,8 @@ class SQLCompiler(engine.Compiled): delete_stmt, delete_stmt._returning) if delete_stmt._whereclause is not None: - text += " WHERE " + self.process(delete_stmt._whereclause) + text += " WHERE " + text += delete_stmt._whereclause._compiler_dispatch(self) if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( |