summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-11-25 03:28:49 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-11-25 03:28:49 +0000
commitcf18eecd704f5eb6fde4e0c362cfdb322e3e559a (patch)
tree9b6e4503802cf0fcae9171b3ac2f85ba0af453c7
parent6f604f911640d92f705fc6611bfaa3e2600c4ee1 (diff)
downloadsqlalchemy-cf18eecd704f5eb6fde4e0c362cfdb322e3e559a.tar.gz
- named_with_column becomes an attribute
- cleanup within compiler visit_select(), column labeling - is_select() removed from dialects, replaced with returns_rows_text(), returns_rows_compiled() - should_autocommit() removed from dialects, replaced with should_autocommit_text() and should_autocommit_compiled() - typemap and column_labels collections removed from Compiler, replaced with single "result_map" collection. - ResultProxy uses more succinct logic in combination with result_map to target columns
-rw-r--r--lib/sqlalchemy/databases/access.py6
-rw-r--r--lib/sqlalchemy/databases/informix.py9
-rw-r--r--lib/sqlalchemy/databases/mssql.py8
-rw-r--r--lib/sqlalchemy/databases/mysql.py14
-rw-r--r--lib/sqlalchemy/databases/postgres.py18
-rw-r--r--lib/sqlalchemy/databases/sqlite.py7
-rw-r--r--lib/sqlalchemy/databases/sybase.py12
-rw-r--r--lib/sqlalchemy/engine/base.py76
-rw-r--r--lib/sqlalchemy/engine/default.py34
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/schema.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py101
-rw-r--r--lib/sqlalchemy/sql/expression.py34
-rw-r--r--test/dialect/postgres.py3
-rw-r--r--test/profiling/compiler.py2
-rw-r--r--test/profiling/zoomark.py12
16 files changed, 163 insertions, 179 deletions
diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py
index d57c9fa9f..354a8c332 100644
--- a/lib/sqlalchemy/databases/access.py
+++ b/lib/sqlalchemy/databases/access.py
@@ -356,11 +356,11 @@ class AccessCompiler(compiler.DefaultCompiler):
"""Access uses "mod" instead of "%" """
return binary.operator == '%' and 'mod' or binary.operator
- def label_select_column(self, select, column):
+ def label_select_column(self, select, column, asfrom):
if isinstance(column, expression._Function):
- return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ return column.label()
else:
- return super(AccessCompiler, self).label_select_column(select, column)
+ return super(AccessCompiler, self).label_select_column(select, column, asfrom)
function_rewrites = {'current_date': 'now',
'current_timestamp': 'now',
diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py
index 247ab2d41..6b01bfc22 100644
--- a/lib/sqlalchemy/databases/informix.py
+++ b/lib/sqlalchemy/databases/informix.py
@@ -409,15 +409,6 @@ class InfoCompiler(compiler.DefaultCompiler):
def limit_clause(self, select):
return ""
- def __visit_label(self, label):
- # TODO: whats this method for ?
- if self.select_stack:
- self.typemap.setdefault(label.name.lower(), label.obj.type)
- if self.strings[label.obj]:
- self.strings[label] = self.strings[label.obj] + " AS " + label.name
- else:
- self.strings[label] = None
-
def visit_function( self , func ):
if func.name.lower() == 'current_date':
return "today"
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 672f8d77c..469355083 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -339,8 +339,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
_ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
re.I | re.UNICODE)
- def is_select(self):
- return self._ms_is_select.match(self.statement) is not None
+ def returns_rows_text(self, statement):
+ return self._ms_is_select.match(statement) is not None
class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
@@ -910,11 +910,11 @@ class MSSQLCompiler(compiler.DefaultCompiler):
else:
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
- def label_select_column(self, select, column):
+ def label_select_column(self, select, column, asfrom):
if isinstance(column, expression._Function):
return column.label(None)
else:
- return super(MSSQLCompiler, self).label_select_column(select, column)
+ return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index 39bfc0bea..03b9a749c 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -1378,9 +1378,6 @@ def descriptor():
class MySQLExecutionContext(default.DefaultExecutionContext):
- _my_is_select = re.compile(r'\s*(?:SELECT|SHOW|DESCRIBE|XA +RECOVER)',
- re.I | re.UNICODE)
-
def post_exec(self):
if self.compiled.isinsert and not self.executemany:
if (not len(self._last_inserted_ids) or
@@ -1388,11 +1385,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
self._last_inserted_ids = ([self.cursor.lastrowid] +
self._last_inserted_ids[1:])
- def is_select(self):
- return SELECT_RE.match(self.statement)
+ def returns_rows_text(self, statement):
+ return SELECT_RE.match(statement)
- def should_autocommit(self):
- return AUTOCOMMIT_RE.match(self.statement)
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_RE.match(statement)
class MySQLDialect(default.DefaultDialect):
@@ -1873,9 +1870,6 @@ class MySQLCompiler(compiler.DefaultCompiler):
if type_ is None:
return self.process(cast.clause)
- if self.stack and self.stack[-1].get('select'):
- # not sure if we want to set the typemap here...
- self.typemap.setdefault("CAST", cast.type)
return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 88ac0e202..1cae31b53 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -233,16 +233,24 @@ RETURNING_QUOTED_RE = re.compile(
class PGExecutionContext(default.DefaultExecutionContext):
- def is_select(self):
- m = SELECT_RE.match(self.statement)
- return m and (not m.group(1) or (RETURNING_RE.search(self.statement)
- and RETURNING_QUOTED_RE.match(self.statement)))
+ def returns_rows_text(self, statement):
+ m = SELECT_RE.match(statement)
+ return m and (not m.group(1) or (RETURNING_RE.search(statement)
+ and RETURNING_QUOTED_RE.match(statement)))
+
+ def returns_rows_compiled(self, compiled):
+ return isinstance(compiled.statement, expression.Selectable) or \
+ (
+ (compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs
+ )
def create_cursor(self):
# executing a default or Sequence standalone creates an execution context without a statement.
# so slightly hacky "if no statement assume we're server side" logic
+ # TODO: dont use regexp if Compiled is used ?
self.__is_server_side = \
- self.dialect.server_side_cursors and (self.statement is None or \
+ self.dialect.server_side_cursors and \
+ (self.statement is None or \
(SELECT_RE.match(self.statement) and not re.search(r'FOR UPDATE(?: NOWAIT)?\s*$', self.statement, re.I))
)
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 19d0855ff..16dd9427c 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -185,8 +185,8 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
- def is_select(self):
- return SELECT_REGEXP.match(self.statement)
+ def returns_rows_text(self, statement):
+ return SELECT_REGEXP.match(statement)
class SQLiteDialect(default.DefaultDialect):
supports_alter = False
@@ -343,9 +343,6 @@ class SQLiteCompiler(compiler.DefaultCompiler):
if self.dialect.supports_cast:
return super(SQLiteCompiler, self).visit_cast(cast)
else:
- if self.stack and self.stack[-1].get('select'):
- # not sure if we want to set the typemap here...
- self.typemap.setdefault("CAST", cast.type)
return self.process(cast.clause)
def limit_clause(self, select):
diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py
index 87045d192..2209594ed 100644
--- a/lib/sqlalchemy/databases/sybase.py
+++ b/lib/sqlalchemy/databases/sybase.py
@@ -778,11 +778,11 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
else:
return super(SybaseSQLCompiler, self).visit_binary(binary)
- def label_select_column(self, select, column):
+ def label_select_column(self, select, column, asfrom):
if isinstance(column, expression._Function):
- return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:])
+ return column.label(None)
else:
- return super(SybaseSQLCompiler, self).label_select_column(select, column)
+ return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
function_rewrites = {'current_date': 'getdate',
}
@@ -795,13 +795,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
cast = expression._Cast(func, SybaseDate_mxodbc)
# infinite recursion
# res = self.visit_cast(cast)
- if self.stack and self.stack[-1].get('select'):
- # not sure if we want to set the typemap here...
- self.typemap.setdefault("CAST", cast.type)
-# res = "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
-# elif func.name.lower() == 'count':
-# res = 'count(*)'
return res
def for_update_clause(self, select):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 21977b689..9e3004325 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -315,6 +315,12 @@ class ExecutionContext(object):
isupdate
True if the statement is an UPDATE.
+ should_autocommit
+ True if the statement is a "committable" statement
+
+ returns_rows
+ True if the statement should return result rows
+
The Dialect should provide an ExecutionContext via the
create_execution_context() method. The `pre_exec` and `post_exec`
methods will be called for compiled statements.
@@ -363,8 +369,13 @@ class ExecutionContext(object):
raise NotImplementedError()
- def should_autocommit(self):
- """Return True if this context's statement should be 'committed' automatically in a non-transactional context"""
+ def should_autocommit_compiled(self, compiled):
+ """return True if the given Compiled object refers to a "committable" statement."""
+
+ raise NotImplementedError()
+
+ def should_autocommit_text(self, statement):
+ """Parse the given textual statement and return True if it refers to a "committable" statement"""
raise NotImplementedError()
@@ -750,7 +761,7 @@ class Connection(Connectable):
# TODO: have the dialect determine if autocommit can be set on
# the connection directly without this extra step
- if not self.in_transaction() and context.should_autocommit():
+ if not self.in_transaction() and context.should_autocommit:
self._commit_impl()
def _autorollback(self):
@@ -1305,7 +1316,7 @@ class ResultProxy(object):
self.cursor = context.cursor
self.connection = context.root_connection
self.__echo = context.engine._should_log_info
- if context.is_select():
+ if context.returns_rows:
self._init_metadata()
self._rowcount = None
else:
@@ -1322,8 +1333,6 @@ class ResultProxy(object):
out_parameters = property(lambda s:s.context.out_parameters)
def _init_metadata(self):
- if hasattr(self, '_ResultProxy__props'):
- return
self.__props = {}
self._key_cache = self._create_key_cache()
self.__keys = []
@@ -1336,20 +1345,24 @@ class ResultProxy(object):
# sqlite possibly prepending table name to colnames so strip
colname = (item[0].split('.')[-1]).decode(self.dialect.encoding)
- if self.context.typemap is not None:
- type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
+ if self.context.result_map:
+ try:
+ (name, obj, type_) = self.context.result_map[colname]
+ except KeyError:
+ (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
else:
- type = typemap.get(item[1], types.NULLTYPE)
+ (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
- rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i)
+ rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
- if rec[0] is None:
- raise exceptions.InvalidRequestError(
- "None for metadata " + colname)
- if self.__props.setdefault(colname.lower(), rec) is not rec:
- self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0)
+ if self.__props.setdefault(name.lower(), rec) is not rec:
+ self.__props[name.lower()] = (type_, self.__ambiguous_processor(colname), 0)
+
self.__keys.append(colname)
self.__props[i] = rec
+ if obj:
+ for o in obj:
+ self.__props[o] = rec
if self.__echo:
self.context.engine.logger.debug("Col " + repr(tuple([x[0] for x in metadata])))
@@ -1362,16 +1375,19 @@ class ResultProxy(object):
"""Given a key, which could be a ColumnElement, string, etc.,
matches it to the appropriate key we got from the result set's
metadata; then cache it locally for quick re-access."""
-
- if isinstance(key, int) and key in props:
+
+ if isinstance(key, basestring):
+ key = key.lower()
+
+ try:
rec = props[key]
- elif isinstance(key, basestring) and key.lower() in props:
- rec = props[key.lower()]
- elif isinstance(key, expression.ColumnElement):
- label = context.column_labels.get(key._label, key.name).lower()
- if label in props:
- rec = props[label]
- if not "rec" in locals():
+ except KeyError:
+ # fallback for targeting a ColumnElement to a textual expression
+ if isinstance(key, expression.ColumnElement):
+ if key._label.lower() in props:
+ return props[key._label.lower()]
+ elif key.name.lower() in props:
+ return props[key.name.lower()]
raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
return rec
@@ -1470,18 +1486,20 @@ class ResultProxy(object):
def _get_col(self, row, key):
try:
- rec = self._key_cache[key]
+ type_, processor, index = self._key_cache[key]
except TypeError:
# the 'slice' use case is very infrequent,
# so we use an exception catch to reduce conditionals in _get_col
if isinstance(key, slice):
indices = key.indices(len(row))
return tuple([self._get_col(row, i) for i in xrange(*indices)])
-
- if rec[1]:
- return rec[1](row[rec[2]])
+ else:
+ raise
+
+ if processor:
+ return processor(row[index])
else:
- return row[rec[2]]
+ return row[index]
def _fetchone_impl(self):
return self.cursor.fetchone()
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index a91d65b81..19ab22c9e 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -146,9 +146,8 @@ class DefaultExecutionContext(base.ExecutionContext):
if value is not None
])
- self.typemap = compiled.typemap
- self.column_labels = compiled.column_labels
-
+ self.result_map = compiled.result_map
+
if not dialect.supports_unicode_statements:
self.statement = unicode(compiled).encode(self.dialect.encoding)
else:
@@ -156,6 +155,12 @@ class DefaultExecutionContext(base.ExecutionContext):
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
+ if isinstance(compiled.statement, expression._TextClause):
+ self.returns_rows = self.returns_rows_text(self.statement)
+ self.should_autocommit = self.should_autocommit_text(self.statement)
+ else:
+ self.returns_rows = self.returns_rows_compiled(compiled)
+ self.should_autocommit = self.should_autocommit_compiled(compiled)
if not parameters:
self.compiled_parameters = [compiled.construct_params()]
@@ -170,7 +175,7 @@ class DefaultExecutionContext(base.ExecutionContext):
elif statement is not None:
# plain text statement.
- self.typemap = self.column_labels = None
+ self.result_map = None
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if not dialect.supports_unicode_statements:
@@ -179,10 +184,12 @@ class DefaultExecutionContext(base.ExecutionContext):
self.statement = statement
self.isinsert = self.isupdate = False
self.cursor = self.create_cursor()
+ self.returns_rows = self.returns_rows_text(statement)
+ self.should_autocommit = self.should_autocommit_text(statement)
else:
# no statement. used for standalone ColumnDefault execution.
self.statement = None
- self.isinsert = self.isupdate = self.executemany = False
+ self.isinsert = self.isupdate = self.executemany = self.returns_rows = self.should_autocommit = False
self.cursor = self.create_cursor()
connection = property(lambda s:s._connection._branch())
@@ -244,10 +251,18 @@ class DefaultExecutionContext(base.ExecutionContext):
parameters.append(param)
return parameters
- def is_select(self):
- """return TRUE if the statement is expected to have result rows."""
+ def returns_rows_compiled(self, compiled):
+ return isinstance(compiled.statement, expression.Selectable)
- return SELECT_REGEXP.match(self.statement)
+ def returns_rows_text(self, statement):
+ return SELECT_REGEXP.match(statement)
+
+ def should_autocommit_compiled(self, compiled):
+ return isinstance(compiled.statement, expression._UpdateBase)
+
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_REGEXP.match(statement)
+
def create_cursor(self):
return self._connection.connection.cursor()
@@ -261,9 +276,6 @@ class DefaultExecutionContext(base.ExecutionContext):
def result(self):
return self.get_result_proxy()
- def should_autocommit(self):
- return AUTOCOMMIT_REGEXP.match(self.statement)
-
def pre_exec(self):
pass
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 121402584..3daf11ed0 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -249,7 +249,7 @@ class Query(object):
# alias non-labeled column elements.
if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
column = column.label(None)
-
+
q._entities = q._entities + [(column, None, id)]
return q
@@ -887,7 +887,7 @@ class Query(object):
context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
elif isinstance(m, sql.ColumnElement):
if clauses is not None:
- m = clauses.adapt_clause(m)
+ m = clauses.aliased_column(m)
context.secondary_columns.append(m)
if self._eager_loaders and self._nestable(**self._select_args()):
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 0e1e5f7a9..817981003 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -456,7 +456,7 @@ class Column(SchemaItem, expression._ColumnClause):
def __str__(self):
if self.table is not None:
- if self.table.named_with_column():
+ if self.table.named_with_column:
return (self.table.description + "." + self.description)
else:
return self.description
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c1f3bc2a0..a31997d1b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -130,13 +130,11 @@ class DefaultCompiler(engine.Compiled):
# a stack. what recursive compiler doesn't have a stack ? :)
self.stack = []
- # a dictionary of result-set column names (strings) to TypeEngine instances,
- # which will be passed to a ResultProxy and used for resultset-level value conversion
- self.typemap = {}
-
- # a dictionary of select columns labels mapped to their "generated" label
- self.column_labels = {}
-
+ # relates label names in the final SQL to
+ # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
+ # ResultProxy uses this for type processing and column targeting
+ self.result_map = {}
+
# a dictionary of ClauseElement subclasses to counters, which are used to
# generate truncated identifier names or "anonymous" identifiers such as
# for aliases
@@ -213,19 +211,15 @@ class DefaultCompiler(engine.Compiled):
def visit_grouping(self, grouping, **kwargs):
return "(" + self.process(grouping.elem) + ")"
- def visit_label(self, label, typemap=None, column_labels=None):
+ def visit_label(self, label, result_map=None):
labelname = self._truncated_identifier("colident", label.name)
- if typemap is not None:
- self.typemap.setdefault(labelname.lower(), label.obj.type)
+ if result_map is not None:
+ result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
- if column_labels is not None:
- if isinstance(label.obj, sql._ColumnClause):
- column_labels[label.obj._label] = labelname
- column_labels[label.name] = labelname
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
- def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
+ def visit_column(self, column, result_map=None, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
# want to truncate a column identifier, if its mapped to the name of a
# physical column. but thats very hard to identify at this point, and
@@ -236,15 +230,13 @@ class DefaultCompiler(engine.Compiled):
else:
name = column.name
- if typemap is not None:
- typemap.setdefault(name.lower(), column.type)
- if column_labels is not None:
- self.column_labels.setdefault(column._label, name.lower())
+ if result_map is not None:
+ result_map[name] = (name, (column, ), column.type)
if column._is_oid:
n = self.dialect.oid_column_name(column)
if n is not None:
- if column.table is None or not column.table.named_with_column():
+ if column.table is None or not column.table.named_with_column:
return n
else:
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
@@ -254,7 +246,7 @@ class DefaultCompiler(engine.Compiled):
return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
else:
return None
- elif column.table is None or not column.table.named_with_column():
+ elif column.table is None or not column.table.named_with_column:
if getattr(column, "is_literal", False):
return name
else:
@@ -277,8 +269,9 @@ class DefaultCompiler(engine.Compiled):
def visit_textclause(self, textclause, **kwargs):
if textclause.typemap is not None:
- self.typemap.update(textclause.typemap)
-
+ for colname, type_ in textclause.typemap.iteritems():
+ self.result_map[colname] = (colname, None, type_)
+
def do_bindparam(m):
name = m.group(1)
if name in textclause.bindparams:
@@ -302,7 +295,7 @@ class DefaultCompiler(engine.Compiled):
sep = ', '
else:
sep = " " + self.operator_string(clauselist.operator) + " "
- return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
+ return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
@@ -310,12 +303,13 @@ class DefaultCompiler(engine.Compiled):
def visit_calculatedclause(self, clause, **kwargs):
return self.process(clause.clause_expr)
- def visit_cast(self, cast, typemap=None, **kwargs):
+ def visit_cast(self, cast, **kwargs):
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
- def visit_function(self, func, typemap=None, **kwargs):
- if typemap is not None:
- typemap.setdefault(func.name, func.type)
+ def visit_function(self, func, result_map=None, **kwargs):
+ if result_map is not None:
+ result_map[func.name] = (func.name, None, func.type)
+
if not self.apply_function_parens(func):
return ".".join(func.packagenames + [func.name])
else:
@@ -325,7 +319,7 @@ class DefaultCompiler(engine.Compiled):
stack_entry = {'select':cs}
if asfrom:
- stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ stack_entry['is_subquery'] = True
elif self.stack and self.stack[-1].get('select'):
stack_entry['is_subquery'] = True
self.stack.append(stack_entry)
@@ -353,7 +347,7 @@ class DefaultCompiler(engine.Compiled):
s = s + " " + self.operator_string(unary.modifier)
return s
- def visit_binary(self, binary, typemap=None, **kwargs):
+ def visit_binary(self, binary, **kwargs):
op = self.operator_string(binary.operator)
if callable(op):
return op(self.process(binary.left), self.process(binary.right))
@@ -438,22 +432,17 @@ class DefaultCompiler(engine.Compiled):
else:
return self.process(alias.original, **kwargs)
- def label_select_column(self, select, column):
- """convert a column from a select's "columns" clause.
+ def label_select_column(self, select, column, asfrom):
+ """label columns present in a select()."""
- given a select() and a column element from its inner_columns collection, return a
- Label object if this column should be labeled in the columns clause. Otherwise,
- return None and the column will be used as-is.
-
- The calling method will traverse the returned label to acquire its string
- representation.
- """
-
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname. so if column is in a "selected from"
- # subquery, label it synoymously with its column name
+ if isinstance(column, sql._Label):
+ return column
+
+ if select.use_labels and column._label:
+ return column.label(column._label)
+
if \
- (self.stack and self.stack[-1].get('is_selected_from')) and \
+ asfrom and \
isinstance(column, sql._ColumnClause) and \
not column.is_literal and \
column.table is not None and \
@@ -462,20 +451,20 @@ class DefaultCompiler(engine.Compiled):
elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
return column.label(None)
else:
- return None
+ return column
def visit_select(self, select, asfrom=False, parens=True, **kwargs):
stack_entry = {'select':select}
if asfrom:
- stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ stack_entry['is_subquery'] = True
column_clause_args = {}
elif self.stack and 'select' in self.stack[-1]:
stack_entry['is_subquery'] = True
column_clause_args = {}
else:
- column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
+ column_clause_args = {'result_map':self.result_map}
if self.stack and 'from' in self.stack[-1]:
existingfroms = self.stack[-1]['from']
@@ -487,8 +476,7 @@ class DefaultCompiler(engine.Compiled):
correlate_froms = util.Set()
for f in froms:
correlate_froms.add(f)
- for f2 in f._get_from_objects():
- correlate_froms.add(f2)
+ correlate_froms.update(f._get_from_objects())
# TODO: might want to propigate existing froms for select(select(select))
# where innermost select should correlate to outermost
@@ -501,19 +489,8 @@ class DefaultCompiler(engine.Compiled):
inner_columns = util.OrderedSet()
for co in select.inner_columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- inner_columns.add(self.process(l, **column_clause_args))
- else:
- inner_columns.add(self.process(co, **column_clause_args))
- else:
- l = self.label_select_column(select, co)
- if l is not None:
- inner_columns.add(self.process(l, **column_clause_args))
- else:
- inner_columns.add(self.process(co, **column_clause_args))
+ l = self.label_select_column(select, co, asfrom=asfrom)
+ inner_columns.add(self.process(l, **column_clause_args))
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b3200a7eb..039145006 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1522,6 +1522,7 @@ class FromClause(Selectable):
"""Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
__visit_name__ = 'fromclause'
+ named_with_column=False
def __init__(self):
self.oid_column = None
@@ -1562,13 +1563,6 @@ class FromClause(Selectable):
return Alias(self, name)
- def named_with_column(self):
- """True if the name of this FromClause may be prepended to a
- column in a generated SQL statement.
- """
-
- return False
-
def is_derived_from(self, fromclause):
"""Return True if this FromClause is 'derived' from the given FromClause.
@@ -2379,6 +2373,8 @@ class Alias(FromClause):
``FromClause`` subclasses.
"""
+ named_with_column = True
+
def __init__(self, selectable, alias=None):
baseselectable = selectable
while isinstance(baseselectable, Alias):
@@ -2386,7 +2382,7 @@ class Alias(FromClause):
self.original = baseselectable
self.selectable = selectable
if alias is None:
- if self.original.named_with_column():
+ if self.original.named_with_column:
alias = getattr(self.original, 'name', None)
alias = '{ANON %d %s}' % (id(self), alias or 'anon')
self.name = alias
@@ -2408,9 +2404,6 @@ class Alias(FromClause):
def _table_iterator(self):
return self.original._table_iterator()
- def named_with_column(self):
- return True
-
def _exportable_columns(self):
#return self.selectable._exportable_columns()
return self.selectable.columns
@@ -2602,7 +2595,7 @@ class _ColumnClause(ColumnElement):
if self.is_literal:
return None
if self.__label is None:
- if self.table is not None and self.table.named_with_column():
+ if self.table is not None and self.table.named_with_column:
self.__label = self.table.name + "_" + self.name
counter = 1
while self.__label in self.table.c:
@@ -2652,6 +2645,8 @@ class TableClause(FromClause):
functionality.
"""
+ named_with_column = True
+
def __init__(self, name, *columns):
super(TableClause, self).__init__()
self.name = self.fullname = name
@@ -2666,9 +2661,6 @@ class TableClause(FromClause):
# TableClause is immutable
return self
- def named_with_column(self):
- return True
-
def append_column(self, c):
self._columns[c.name] = c
c.table = self
@@ -3041,16 +3033,14 @@ class Select(_SelectBaseMixin, FromClause):
froms = froms.difference(hide_froms)
if len(froms) > 1:
- corr = self.__correlate
+ if self.__correlate:
+ froms = froms.difference(self.__correlate)
if self._should_correlate and existing_froms is not None:
- corr.update(existing_froms)
+ froms = froms.difference(existing_froms)
- f = froms.difference(corr)
- if not f:
+ if not froms:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
- return f
- else:
- return froms
+ return froms
froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py
index 82f41f80a..4affabb6c 100644
--- a/test/dialect/postgres.py
+++ b/test/dialect/postgres.py
@@ -101,6 +101,9 @@ class ReturningTest(AssertMixin):
result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
+
+ result4 = testbase.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
+ self.assertEqual([dict(row) for row in result4], [{'persons': 10}])
finally:
table.drop()
diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py
index 544e674f3..6fa4f9659 100644
--- a/test/profiling/compiler.py
+++ b/test/profiling/compiler.py
@@ -24,7 +24,7 @@ class CompileTest(AssertMixin):
t1.update().compile()
# TODO: this is alittle high
- @profiling.profiled('ctest_select', call_range=(130, 150), always=True)
+ @profiling.profiled('ctest_select', call_range=(110, 130), always=True)
def test_select(self):
s = select([t1], t1.c.c2==t2.c.c1)
s.compile()
diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py
index d18502c72..48f0432cb 100644
--- a/test/profiling/zoomark.py
+++ b/test/profiling/zoomark.py
@@ -50,7 +50,7 @@ class ZooMarkTest(testing.AssertMixin):
metadata.create_all()
@testing.supported('postgres')
- @profiling.profiled('populate', call_range=(2800, 3700), always=True)
+ @profiling.profiled('populate', call_range=(2700, 3700), always=True)
def test_1a_populate(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
@@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin):
tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
@testing.supported('postgres')
- @profiling.profiled('properties', call_range=(2900, 3330), always=True)
+ @profiling.profiled('properties', call_range=(2300, 3030), always=True)
def test_3_properties(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
@@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin):
ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
@testing.supported('postgres')
- @profiling.profiled('expressions', call_range=(10350, 12200), always=True)
+ @profiling.profiled('expressions', call_range=(9200, 12050), always=True)
def test_4_expressions(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
@@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin):
assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
@testing.supported('postgres')
- @profiling.profiled('aggregates', call_range=(960, 1170), always=True)
+ @profiling.profiled('aggregates', call_range=(800, 1170), always=True)
def test_5_aggregates(self):
Animal = metadata.tables['Animal']
Zoo = metadata.tables['Zoo']
@@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin):
legs.sort()
@testing.supported('postgres')
- @profiling.profiled('editing', call_range=(1150, 1280), always=True)
+ @profiling.profiled('editing', call_range=(1050, 1180), always=True)
def test_6_editing(self):
Zoo = metadata.tables['Zoo']
@@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin):
assert SDZ['Founded'] == datetime.date(1935, 9, 13)
@testing.supported('postgres')
- @profiling.profiled('multiview', call_range=(2300, 2500), always=True)
+ @profiling.profiled('multiview', call_range=(1900, 2300), always=True)
def test_7_multiview(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']