diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-17 18:35:25 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-17 18:35:25 -0400 |
commit | a2468c8a31c8308cdb5740f2401e9dedd003836e (patch) | |
tree | e4cc3eb17c59f678ea2919ecd0880f1df7854b6e /lib/sqlalchemy/sql/compiler.py | |
parent | 20fa7fe2b85d356e3da08191f01d7528ded42033 (diff) | |
download | sqlalchemy-a2468c8a31c8308cdb5740f2401e9dedd003836e.tar.gz |
- [feature] To complement [ticket:2547], types
can now provide "bind expressions" and
"column expressions" which allow compile-time
injection of SQL expressions into statements
on a per-column or per-bind level. This is
to suit the use case of a type which needs
to augment bind- and result- behavior at the
SQL level, as opposed to in the Python level.
Allows for schemes like transparent encryption/
decryption, usage of Postgis functions, etc.
[ticket:1534]
- update postgis example fully.
- still need to repair the result map propagation
here to be transparent for cases like "labeled column".
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 136 |
1 files changed, 88 insertions, 48 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 300cdb6b4..f975225d6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -172,6 +172,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote + class SQLCompiler(engine.Compiled): """Default implementation of Compiled. @@ -373,7 +374,8 @@ class SQLCompiler(engine.Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" - def visit_label(self, label, result_map=None, + def visit_label(self, label, + add_to_result_map = None, within_label_clause=False, within_columns_clause=False, **kw): # only render labels within the columns clause @@ -385,14 +387,18 @@ class SQLCompiler(engine.Compiled): else: labelname = label.name - if result_map is not None: - result_map[labelname + if add_to_result_map is not None: + self.result_map[ + labelname if self.dialect.case_sensitive - else labelname.lower()] = ( - label.name, - (label, label.element, labelname, ) + - label._alt_names, - label.type) + else labelname.lower() + ] = ( + label.name, + (label, label.element, labelname, ) + + label._alt_names + + add_to_result_map, + label.type, + ) return label.element._compiler_dispatch(self, within_columns_clause=True, @@ -405,7 +411,7 @@ class SQLCompiler(engine.Compiled): within_columns_clause=False, **kw) - def visit_column(self, column, result_map=None, **kwargs): + def visit_column(self, column, add_to_result_map=None, **kwargs): name = orig_name = column.name if name is None: raise exc.CompileError("Cannot compile Column object until " @@ -415,12 +421,16 @@ class SQLCompiler(engine.Compiled): if not is_literal and isinstance(name, sql._truncated_label): name = self._truncated_identifier("colident", name) - if result_map is not None: - result_map[name + if add_to_result_map is not None: + self.result_map[ + name if self.dialect.case_sensitive - else name.lower()] = (orig_name, - (column, name, column.key), - column.type) + else name.lower() + ] = ( + orig_name, + (column, name, column.key) + add_to_result_map, + column.type + ) if is_literal: name = self.escape_literal_column(name) @@ -527,7 +537,7 @@ class SQLCompiler(engine.Compiled): cast.typeclause._compiler_dispatch(self, **kwargs)) def visit_over(self, over, **kwargs): - x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs) + x = "%s OVER (" % over.func._compiler_dispatch(self, **kwargs) if over.partition_by is not None: x += "PARTITION BY %s" % \ over.partition_by._compiler_dispatch(self, **kwargs) @@ -544,12 +554,13 @@ class SQLCompiler(engine.Compiled): return "EXTRACT(%s FROM %s)" % (field, extract.expr._compiler_dispatch(self, **kwargs)) - def visit_function(self, func, result_map=None, **kwargs): - if result_map is not None: - result_map[func.name + def visit_function(self, func, add_to_result_map=None, **kwargs): + if add_to_result_map is not None: + self.result_map[ + func.name if self.dialect.case_sensitive - else func.name.lower()] = \ - (func.name, None, func.type) + else func.name.lower() + ] = (func.name, add_to_result_map, func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -557,14 +568,15 @@ class SQLCompiler(engine.Compiled): else: name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") return ".".join(list(func.packagenames) + [name]) % \ - {'expr':self.function_argspec(func, **kwargs)} + {'expr': self.function_argspec(func, **kwargs)} def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) def visit_sequence(self, sequence): raise NotImplementedError( - "Dialect '%s' does not support sequence increments." % self.dialect.name + "Dialect '%s' does not support sequence increments." % + self.dialect.name ) def function_argspec(self, func, **kwargs): @@ -704,7 +716,14 @@ class SQLCompiler(engine.Compiled): def visit_bindparam(self, bindparam, within_columns_clause=False, - literal_binds=False, **kwargs): + literal_binds=False, + skip_bind_expression=False, + **kwargs): + + if not skip_bind_expression and bindparam.type._has_bind_expression: + bind_expression = bindparam.type.bind_expression(bindparam) + return self.process(bind_expression, + skip_bind_expression=True) if literal_binds or \ (within_columns_clause and \ @@ -912,17 +931,31 @@ class SQLCompiler(engine.Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) - def label_select_column(self, select, column, asfrom): - """label columns present in a select().""" + def _label_select_column(self, select, column, populate_result_map, + asfrom, column_clause_args): + """produce labeled columns present in a select().""" + + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + if populate_result_map: + add_to_result_map = (column, ) + else: + add_to_result_map = None + else: + col_expr = column + if populate_result_map: + add_to_result_map = () + else: + add_to_result_map = None - if isinstance(column, sql.Label): - return column + if isinstance(col_expr, sql.Label): + result_expr = col_expr elif select is not None and \ select.use_labels and \ column._label: - return _CompileLabel( - column, + result_expr = _CompileLabel( + col_expr, column._label, alt_names=(column._key_label, ) ) @@ -933,15 +966,25 @@ class SQLCompiler(engine.Compiled): not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): - return _CompileLabel(column, sql._as_truncated(column.name), - alt_names=(column.key,)) + result_expr = _CompileLabel(col_expr, + sql._as_truncated(column.name), + alt_names=(column.key,)) elif not isinstance(column, (sql.UnaryExpression, sql.TextClause)) \ and (not hasattr(column, 'name') or \ isinstance(column, sql.Function)): - return _CompileLabel(column, column.anon_label) + result_expr = _CompileLabel(col_expr, column.anon_label) + elif col_expr is not column: + result_expr = _CompileLabel(col_expr, column.anon_label) else: - return column + result_expr = col_expr + + return result_expr._compiler_dispatch( + self, within_columns_clause=True, + add_to_result_map=add_to_result_map, + **column_clause_args + ) + def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) @@ -976,24 +1019,21 @@ class SQLCompiler(engine.Compiled): # to outermost if existingfroms: correlate_froms = # correlate_froms.union(existingfroms) - self.stack.append({'from': correlate_froms, 'iswrapper' - : iswrapper}) + self.stack.append({'from': correlate_froms, + 'iswrapper': iswrapper}) - if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map, - 'positional_names':positional_names} - else: - column_clause_args = {'positional_names':positional_names} + populate_result_map = compound_index == 1 and not entry or \ + entry.get('iswrapper', False) + column_clause_args = {'positional_names': positional_names} # the actual list of columns to print in the SELECT column list. inner_columns = [ c for c in [ - self.label_select_column(select, co, asfrom=asfrom).\ - _compiler_dispatch(self, - within_columns_clause=True, - **column_clause_args) - for co in util.unique_list(select.inner_columns) - ] + self._label_select_column(select, column, + populate_result_map, asfrom, + column_clause_args) + for column in util.unique_list(select.inner_columns) + ] if c is not None ] @@ -1059,8 +1099,8 @@ class SQLCompiler(engine.Compiled): text += self.for_update_clause(select) if self.ctes and \ - compound_index==1 and not entry: - text = self._render_cte_clause() + text + compound_index == 1 and not entry: + text = self._render_cte_clause() + text self.stack.pop(-1) |