summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py101
-rw-r--r--lib/sqlalchemy/sql/expression.py34
2 files changed, 51 insertions, 84 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c1f3bc2a0..a31997d1b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -130,13 +130,11 @@ class DefaultCompiler(engine.Compiled):
# a stack. what recursive compiler doesn't have a stack ? :)
self.stack = []
- # a dictionary of result-set column names (strings) to TypeEngine instances,
- # which will be passed to a ResultProxy and used for resultset-level value conversion
- self.typemap = {}
-
- # a dictionary of select columns labels mapped to their "generated" label
- self.column_labels = {}
-
+ # relates label names in the final SQL to
+ # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
+ # ResultProxy uses this for type processing and column targeting
+ self.result_map = {}
+
# a dictionary of ClauseElement subclasses to counters, which are used to
# generate truncated identifier names or "anonymous" identifiers such as
# for aliases
@@ -213,19 +211,15 @@ class DefaultCompiler(engine.Compiled):
def visit_grouping(self, grouping, **kwargs):
return "(" + self.process(grouping.elem) + ")"
- def visit_label(self, label, typemap=None, column_labels=None):
+ def visit_label(self, label, result_map=None):
labelname = self._truncated_identifier("colident", label.name)
- if typemap is not None:
- self.typemap.setdefault(labelname.lower(), label.obj.type)
+ if result_map is not None:
+ result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
- if column_labels is not None:
- if isinstance(label.obj, sql._ColumnClause):
- column_labels[label.obj._label] = labelname
- column_labels[label.name] = labelname
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
+ def visit_column(self, column, result_map=None, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
@@ -236,15 +230,13 @@ class DefaultCompiler(engine.Compiled):
else:
name = column.name
- if typemap is not None:
- typemap.setdefault(name.lower(), column.type)
- if column_labels is not None:
- self.column_labels.setdefault(column._label, name.lower())
+ if result_map is not None:
+ result_map[name] = (name, (column, ), column.type)
if column._is_oid:
n = self.dialect.oid_column_name(column)
if n is not None:
- if column.table is None or not column.table.named_with_column():
+ if column.table is None or not column.table.named_with_column:
return n
else:
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
@@ -254,7 +246,7 @@ class DefaultCompiler(engine.Compiled):
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
else:
return None
- elif column.table is None or not column.table.named_with_column():
+ elif column.table is None or not column.table.named_with_column:
if getattr(column, "is_literal", False):
return name
else:
@@ -277,8 +269,9 @@ class DefaultCompiler(engine.Compiled):
def visit_textclause(self, textclause, **kwargs):
if textclause.typemap is not None:
- self.typemap.update(textclause.typemap)
-
+ for colname, type_ in textclause.typemap.iteritems():
+ self.result_map[colname] = (colname, None, type_)
+
def do_bindparam(m):
name = m.group(1)
if name in textclause.bindparams:
@@ -302,7 +295,7 @@ class DefaultCompiler(engine.Compiled):
sep = ', '
else:
sep = " " + self.operator_string(clauselist.operator) + " "
- return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
+ return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
@@ -310,12 +303,13 @@ class DefaultCompiler(engine.Compiled):
def visit_calculatedclause(self, clause, **kwargs):
return self.process(clause.clause_expr)
- def visit_cast(self, cast, typemap=None, **kwargs):
+ def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func, typemap=None, **kwargs):
- if typemap is not None:
- typemap.setdefault(func.name, func.type)
+ def visit_function(self, func, result_map=None, **kwargs):
+ if result_map is not None:
+ result_map[func.name] = (func.name, None, func.type)
+
if not self.apply_function_parens(func):
return ".".join(func.packagenames + [func.name])
else:
@@ -325,7 +319,7 @@ class DefaultCompiler(engine.Compiled):
stack_entry = {'select':cs}
if asfrom:
- stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ stack_entry['is_subquery'] = True
elif self.stack and self.stack[-1].get('select'):
stack_entry['is_subquery'] = True
self.stack.append(stack_entry)
@@ -353,7 +347,7 @@ class DefaultCompiler(engine.Compiled):
s = s + " " + self.operator_string(unary.modifier)
return s
- def visit_binary(self, binary, typemap=None, **kwargs):
+ def visit_binary(self, binary, **kwargs):
op = self.operator_string(binary.operator)
if callable(op):
return op(self.process(binary.left), self.process(binary.right))
@@ -438,22 +432,17 @@ class DefaultCompiler(engine.Compiled):
else:
return self.process(alias.original, **kwargs)
- def label_select_column(self, select, column):
- """convert a column from a select's "columns" clause.
+ def label_select_column(self, select, column, asfrom):
+ """label columns present in a select()."""
- given a select() and a column element from its inner_columns collection, return a
- Label object if this column should be labeled in the columns clause. Otherwise,
- return None and the column will be used as-is.
-
- The calling method will traverse the returned label to acquire its string
- representation.
- """
-
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname. so if column is in a "selected from"
- # subquery, label it synoymously with its column name
+ if isinstance(column, sql._Label):
+ return column
+
+ if select.use_labels and column._label:
+ return column.label(column._label)
+
if \
- (self.stack and self.stack[-1].get('is_selected_from')) and \
+ asfrom and \
isinstance(column, sql._ColumnClause) and \
not column.is_literal and \
column.table is not None and \
@@ -462,20 +451,20 @@ class DefaultCompiler(engine.Compiled):
elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
return column.label(None)
else:
- return None
+ return column
def visit_select(self, select, asfrom=False, parens=True, **kwargs):
stack_entry = {'select':select}
if asfrom:
- stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ stack_entry['is_subquery'] = True
column_clause_args = {}
elif self.stack and 'select' in self.stack[-1]:
stack_entry['is_subquery'] = True
column_clause_args = {}
else:
- column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
+ column_clause_args = {'result_map':self.result_map}
if self.stack and 'from' in self.stack[-1]:
existingfroms = self.stack[-1]['from']
@@ -487,8 +476,7 @@ class DefaultCompiler(engine.Compiled):
correlate_froms = util.Set()
for f in froms:
correlate_froms.add(f)
- for f2 in f._get_from_objects():
- correlate_froms.add(f2)
+ correlate_froms.update(f._get_from_objects())
# TODO: might want to propigate existing froms for select(select(select))
# where innermost select should correlate to outermost
@@ -501,19 +489,8 @@ class DefaultCompiler(engine.Compiled):
inner_columns = util.OrderedSet()
for co in select.inner_columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- inner_columns.add(self.process(l, **column_clause_args))
- else:
- inner_columns.add(self.process(co, **column_clause_args))
- else:
- l = self.label_select_column(select, co)
- if l is not None:
- inner_columns.add(self.process(l, **column_clause_args))
- else:
- inner_columns.add(self.process(co, **column_clause_args))
+ l = self.label_select_column(select, co, asfrom=asfrom)
+ inner_columns.add(self.process(l, **column_clause_args))
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b3200a7eb..039145006 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1522,6 +1522,7 @@ class FromClause(Selectable):
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
__visit_name__ = 'fromclause'
+ named_with_column=False
def __init__(self):
self.oid_column = None
@@ -1562,13 +1563,6 @@ class FromClause(Selectable):
return Alias(self, name)
- def named_with_column(self):
- """True if the name of this FromClause may be prepended to a
- column in a generated SQL statement.
- """
-
- return False
-
def is_derived_from(self, fromclause):
"""Return True if this FromClause is 'derived' from the given FromClause.
@@ -2379,6 +2373,8 @@ class Alias(FromClause):
``FromClause`` subclasses.
"""
+ named_with_column = True
+
def __init__(self, selectable, alias=None):
baseselectable = selectable
while isinstance(baseselectable, Alias):
@@ -2386,7 +2382,7 @@ class Alias(FromClause):
self.original = baseselectable
self.selectable = selectable
if alias is None:
- if self.original.named_with_column():
+ if self.original.named_with_column:
alias = getattr(self.original, 'name', None)
alias = '{ANON %d %s}' % (id(self), alias or 'anon')
self.name = alias
@@ -2408,9 +2404,6 @@ class Alias(FromClause):
def _table_iterator(self):
return self.original._table_iterator()
- def named_with_column(self):
- return True
-
def _exportable_columns(self):
#return self.selectable._exportable_columns()
return self.selectable.columns
@@ -2602,7 +2595,7 @@ class _ColumnClause(ColumnElement):
if self.is_literal:
return None
if self.__label is None:
- if self.table is not None and self.table.named_with_column():
+ if self.table is not None and self.table.named_with_column:
self.__label = self.table.name + "_" + self.name
counter = 1
while self.__label in self.table.c:
@@ -2652,6 +2645,8 @@ class TableClause(FromClause):
functionality.
"""
+ named_with_column = True
+
def __init__(self, name, *columns):
super(TableClause, self).__init__()
self.name = self.fullname = name
@@ -2666,9 +2661,6 @@ class TableClause(FromClause):
# TableClause is immutable
return self
- def named_with_column(self):
- return True
-
def append_column(self, c):
self._columns[c.name] = c
c.table = self
@@ -3041,16 +3033,14 @@ class Select(_SelectBaseMixin, FromClause):
froms = froms.difference(hide_froms)
if len(froms) > 1:
- corr = self.__correlate
+ if self.__correlate:
+ froms = froms.difference(self.__correlate)
if self._should_correlate and existing_froms is not None:
- corr.update(existing_froms)
+ froms = froms.difference(existing_froms)
- f = froms.difference(corr)
- if not f:
+ if not froms:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
- return f
- else:
- return froms
+ return froms
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")