summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-08-17 18:35:25 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-08-17 18:35:25 -0400
commita2468c8a31c8308cdb5740f2401e9dedd003836e (patch)
treee4cc3eb17c59f678ea2919ecd0880f1df7854b6e /lib/sqlalchemy/sql/compiler.py
parent20fa7fe2b85d356e3da08191f01d7528ded42033 (diff)
downloadsqlalchemy-a2468c8a31c8308cdb5740f2401e9dedd003836e.tar.gz
- [feature] To complement [ticket:2547], types
can now provide "bind expressions" and "column expressions" which allow compile-time injection of SQL expressions into statements on a per-column or per-bind level. This is to suit the use case of a type which needs to augment bind- and result- behavior at the SQL level, as opposed to in the Python level. Allows for schemes like transparent encryption/ decryption, usage of Postgis functions, etc. [ticket:1534] - update postgis example fully. - still need to repair the result map propagation here to be transparent for cases like "labeled column".
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py136
1 files changed, 88 insertions, 48 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 300cdb6b4..f975225d6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -172,6 +172,7 @@ class _CompileLabel(visitors.Visitable):
def quote(self):
return self.element.quote
+
class SQLCompiler(engine.Compiled):
"""Default implementation of Compiled.
@@ -373,7 +374,8 @@ class SQLCompiler(engine.Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
- def visit_label(self, label, result_map=None,
+ def visit_label(self, label,
+ add_to_result_map = None,
within_label_clause=False,
within_columns_clause=False, **kw):
# only render labels within the columns clause
@@ -385,14 +387,18 @@ class SQLCompiler(engine.Compiled):
else:
labelname = label.name
- if result_map is not None:
- result_map[labelname
+ if add_to_result_map is not None:
+ self.result_map[
+ labelname
if self.dialect.case_sensitive
- else labelname.lower()] = (
- label.name,
- (label, label.element, labelname, ) +
- label._alt_names,
- label.type)
+ else labelname.lower()
+ ] = (
+ label.name,
+ (label, label.element, labelname, ) +
+ label._alt_names +
+ add_to_result_map,
+ label.type,
+ )
return label.element._compiler_dispatch(self,
within_columns_clause=True,
@@ -405,7 +411,7 @@ class SQLCompiler(engine.Compiled):
within_columns_clause=False,
**kw)
- def visit_column(self, column, result_map=None, **kwargs):
+ def visit_column(self, column, add_to_result_map=None, **kwargs):
name = orig_name = column.name
if name is None:
raise exc.CompileError("Cannot compile Column object until "
@@ -415,12 +421,16 @@ class SQLCompiler(engine.Compiled):
if not is_literal and isinstance(name, sql._truncated_label):
name = self._truncated_identifier("colident", name)
- if result_map is not None:
- result_map[name
+ if add_to_result_map is not None:
+ self.result_map[
+ name
if self.dialect.case_sensitive
- else name.lower()] = (orig_name,
- (column, name, column.key),
- column.type)
+ else name.lower()
+ ] = (
+ orig_name,
+ (column, name, column.key) + add_to_result_map,
+ column.type
+ )
if is_literal:
name = self.escape_literal_column(name)
@@ -527,7 +537,7 @@ class SQLCompiler(engine.Compiled):
cast.typeclause._compiler_dispatch(self, **kwargs))
def visit_over(self, over, **kwargs):
- x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs)
+ x = "%s OVER (" % over.func._compiler_dispatch(self, **kwargs)
if over.partition_by is not None:
x += "PARTITION BY %s" % \
over.partition_by._compiler_dispatch(self, **kwargs)
@@ -544,12 +554,13 @@ class SQLCompiler(engine.Compiled):
return "EXTRACT(%s FROM %s)" % (field,
extract.expr._compiler_dispatch(self, **kwargs))
- def visit_function(self, func, result_map=None, **kwargs):
- if result_map is not None:
- result_map[func.name
+ def visit_function(self, func, add_to_result_map=None, **kwargs):
+ if add_to_result_map is not None:
+ self.result_map[
+ func.name
if self.dialect.case_sensitive
- else func.name.lower()] = \
- (func.name, None, func.type)
+ else func.name.lower()
+ ] = (func.name, add_to_result_map, func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
if disp:
@@ -557,14 +568,15 @@ class SQLCompiler(engine.Compiled):
else:
name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
return ".".join(list(func.packagenames) + [name]) % \
- {'expr':self.function_argspec(func, **kwargs)}
+ {'expr': self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
def visit_sequence(self, sequence):
raise NotImplementedError(
- "Dialect '%s' does not support sequence increments." % self.dialect.name
+ "Dialect '%s' does not support sequence increments." %
+ self.dialect.name
)
def function_argspec(self, func, **kwargs):
@@ -704,7 +716,14 @@ class SQLCompiler(engine.Compiled):
def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False, **kwargs):
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs):
+
+ if not skip_bind_expression and bindparam.type._has_bind_expression:
+ bind_expression = bindparam.type.bind_expression(bindparam)
+ return self.process(bind_expression,
+ skip_bind_expression=True)
if literal_binds or \
(within_columns_clause and \
@@ -912,17 +931,31 @@ class SQLCompiler(engine.Compiled):
else:
return alias.original._compiler_dispatch(self, **kwargs)
- def label_select_column(self, select, column, asfrom):
- """label columns present in a select()."""
+ def _label_select_column(self, select, column, populate_result_map,
+ asfrom, column_clause_args):
+ """produce labeled columns present in a select()."""
+
+ if column.type._has_column_expression:
+ col_expr = column.type.column_expression(column)
+ if populate_result_map:
+ add_to_result_map = (column, )
+ else:
+ add_to_result_map = None
+ else:
+ col_expr = column
+ if populate_result_map:
+ add_to_result_map = ()
+ else:
+ add_to_result_map = None
- if isinstance(column, sql.Label):
- return column
+ if isinstance(col_expr, sql.Label):
+ result_expr = col_expr
elif select is not None and \
select.use_labels and \
column._label:
- return _CompileLabel(
- column,
+ result_expr = _CompileLabel(
+ col_expr,
column._label,
alt_names=(column._key_label, )
)
@@ -933,15 +966,25 @@ class SQLCompiler(engine.Compiled):
not column.is_literal and \
column.table is not None and \
not isinstance(column.table, sql.Select):
- return _CompileLabel(column, sql._as_truncated(column.name),
- alt_names=(column.key,))
+ result_expr = _CompileLabel(col_expr,
+ sql._as_truncated(column.name),
+ alt_names=(column.key,))
elif not isinstance(column,
(sql.UnaryExpression, sql.TextClause)) \
and (not hasattr(column, 'name') or \
isinstance(column, sql.Function)):
- return _CompileLabel(column, column.anon_label)
+ result_expr = _CompileLabel(col_expr, column.anon_label)
+ elif col_expr is not column:
+ result_expr = _CompileLabel(col_expr, column.anon_label)
else:
- return column
+ result_expr = col_expr
+
+ return result_expr._compiler_dispatch(
+ self, within_columns_clause=True,
+ add_to_result_map=add_to_result_map,
+ **column_clause_args
+ )
+
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -976,24 +1019,21 @@ class SQLCompiler(engine.Compiled):
# to outermost if existingfroms: correlate_froms =
# correlate_froms.union(existingfroms)
- self.stack.append({'from': correlate_froms, 'iswrapper'
- : iswrapper})
+ self.stack.append({'from': correlate_froms,
+ 'iswrapper': iswrapper})
- if compound_index==1 and not entry or entry.get('iswrapper', False):
- column_clause_args = {'result_map':self.result_map,
- 'positional_names':positional_names}
- else:
- column_clause_args = {'positional_names':positional_names}
+ populate_result_map = compound_index == 1 and not entry or \
+ entry.get('iswrapper', False)
+ column_clause_args = {'positional_names': positional_names}
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c for c in [
- self.label_select_column(select, co, asfrom=asfrom).\
- _compiler_dispatch(self,
- within_columns_clause=True,
- **column_clause_args)
- for co in util.unique_list(select.inner_columns)
- ]
+ self._label_select_column(select, column,
+ populate_result_map, asfrom,
+ column_clause_args)
+ for column in util.unique_list(select.inner_columns)
+ ]
if c is not None
]
@@ -1059,8 +1099,8 @@ class SQLCompiler(engine.Compiled):
text += self.for_update_clause(select)
if self.ctes and \
- compound_index==1 and not entry:
- text = self._render_cte_clause() + text
+ compound_index == 1 and not entry:
+ text = self._render_cte_clause() + text
self.stack.pop(-1)