diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-22 18:41:46 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-22 18:41:46 -0400 |
commit | 8f5a31441aed9d223e67d211472445e574fc521f (patch) | |
tree | 5f95780d99cf6c77ffd569d34d709514fa98263c | |
parent | 477dd0f774f1c2f2f3873924ac0606bf499e0061 (diff) | |
download | sqlalchemy-8f5a31441aed9d223e67d211472445e574fc521f.tar.gz |
- [bug] Fixed cextension bug whereby the
"ambiguous column error" would fail to
function properly if the given index were
a Column object and not a string.
Note there are still some column-targeting
issues here which are fixed in 0.8.
[ticket:2553]
- find more cases where column targeting is being inaccurate, add
more information to result_map to better differentiate "ambiguous"
results from "present" or "not present". In particular, result_map
is sensitive to dupes, even though no error is raised; the conflicting
columns are added to the "obj" member of the tuple so that the two
are both directly accessible in the result proxy
- handwringing over the damn "name fallback" thing in results. can't
really make it perfect yet
- fix up oracle returning clause. not sure why its guarding against
labels, remove that for now and see what the bot says.
-rw-r--r-- | CHANGES | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/cextension/resultproxy.c | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/result.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 64 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 3 | ||||
-rw-r--r-- | test/dialect/test_oracle.py | 8 | ||||
-rw-r--r-- | test/sql/test_query.py | 66 |
9 files changed, 140 insertions, 46 deletions
@@ -671,6 +671,14 @@ are also present in 0.8. the absense of which was preventing the new GAE dialect from being loaded. [ticket:2529] + - [bug] Fixed cextension bug whereby the + "ambiguous column error" would fail to + function properly if the given index were + a Column object and not a string. + Note there are still some column-targeting + issues here which are fixed in 0.8. + [ticket:2553] + - [bug] Fixed the repr() of Enum to include the "name" and "native_enum" flags. Helps Alembic autogenerate. diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index 8c89baa25..76c785f85 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -244,7 +244,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) PyObject *processors, *values; PyObject *processor, *value, *processed_value; PyObject *row, *record, *result, *indexobject; - PyObject *exc_module, *exception; + PyObject *exc_module, *exception, *cstr_obj; char *cstr_key; long index; int key_fallback = 0; @@ -301,9 +301,16 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key) if (exception == NULL) return NULL; - cstr_key = PyString_AsString(key); - if (cstr_key == NULL) + // wow. this seems quite excessive. + cstr_obj = PyObject_Str(key); + if (cstr_obj == NULL) return NULL; + cstr_key = PyString_AsString(cstr_obj); + if (cstr_key == NULL) { + Py_DECREF(cstr_obj); + return NULL; + } + Py_DECREF(cstr_obj); PyErr_Format(exception, "Ambiguous column name '%.200s' in result set! " diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 0dd610788..bccd585d0 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -859,15 +859,15 @@ class MSSQLCompiler(compiler.SQLCompiler): t, column) if add_to_result_map is not None: - self.result_map[column.name - if self.dialect.case_sensitive - else column.name.lower()] = \ - (column.name, (column, ) + add_to_result_map, - column.type) + add_to_result_map( + column.name, + column.name, + (column, ), + column.type + ) return super(MSSQLCompiler, self).\ - visit_column(converted, - result_map=None, **kwargs) + visit_column(converted, **kwargs) return super(MSSQLCompiler, self).visit_column( column, add_to_result_map=add_to_result_map, **kwargs) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 88bc15bcc..eb1d75caa 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -508,12 +508,14 @@ class OracleCompiler(compiler.SQLCompiler): columnlist = list(expression._select_iterables(returning_cols)) - # within_columns_clause =False so that labels (foo AS bar) don't render - columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) for c in columnlist] + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in columnlist + ] binds = [create_out_param(c, i) for i, c in enumerate(columnlist)] - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 91d161348..a9bb248b3 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -232,7 +232,7 @@ class ResultMetaData(object): # unambiguous. primary_keymap[name if self.case_sensitive - else name.lower()] = (processor, obj, None) + else name.lower()] = rec = (processor, obj, None) self.keys.append(colname) if obj: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fd9718f1f..ee96c8c81 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -375,7 +375,7 @@ class SQLCompiler(engine.Compiled): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" def visit_label(self, label, - add_to_result_map = None, + add_to_result_map=None, within_label_clause=False, within_columns_clause=False, **kw): # only render labels within the columns clause @@ -388,17 +388,12 @@ class SQLCompiler(engine.Compiled): labelname = label.name if add_to_result_map is not None: - self.result_map[ - labelname - if self.dialect.case_sensitive - else labelname.lower() - ] = ( - label.name, - (label, label.element, labelname, ) + - label._alt_names + - add_to_result_map, - label.type, - ) + add_to_result_map( + labelname, + label.name, + (label, label.element, labelname, ) + label._alt_names, + label.type + ) return label.element._compiler_dispatch(self, within_columns_clause=True, @@ -423,15 +418,12 @@ class SQLCompiler(engine.Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - self.result_map[ - name - if self.dialect.case_sensitive - else name.lower() - ] = ( - orig_name, - (column, name, column.key) + add_to_result_map, - column.type - ) + add_to_result_map( + name, + orig_name, + (column, name, column.key), + column.type + ) if is_literal: name = self.escape_literal_column(name) @@ -557,11 +549,9 @@ class SQLCompiler(engine.Compiled): def visit_function(self, func, add_to_result_map=None, **kwargs): if add_to_result_map is not None: - self.result_map[ - func.name - if self.dialect.case_sensitive - else func.name.lower() - ] = (func.name, add_to_result_map, func.type) + add_to_result_map( + func.name, func.name, (), func.type + ) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: @@ -932,6 +922,20 @@ class SQLCompiler(engine.Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) + def _add_to_result_map(self, keyname, name, objects, type_): + if not self.dialect.case_sensitive: + keyname = keyname.lower() + + if keyname in self.result_map: + # conflicting keyname, just double up the list + # of objects. this will cause an "ambiguous name" + # error if an attempt is made by the result set to + # access. + e_name, e_obj, e_type = self.result_map[keyname] + self.result_map[keyname] = e_name, e_obj + objects, e_type + else: + self.result_map[keyname] = name, objects, type_ + def _label_select_column(self, select, column, populate_result_map, asfrom, column_clause_args): """produce labeled columns present in a select().""" @@ -939,13 +943,16 @@ class SQLCompiler(engine.Compiled): if column.type._has_column_expression: col_expr = column.type.column_expression(column) if populate_result_map: - add_to_result_map = (column, ) + add_to_result_map = lambda keyname, name, objects, type_: \ + self._add_to_result_map( + keyname, name, + objects + (column,), type_) else: add_to_result_map = None else: col_expr = column if populate_result_map: - add_to_result_map = () + add_to_result_map = self._add_to_result_map else: add_to_result_map = None @@ -993,7 +1000,6 @@ class SQLCompiler(engine.Compiled): **column_clause_args ) - def format_from_hint_text(self, sqltext, table, hint, iscrud): hinttext = self.get_from_hint_text(table, hint) if hinttext: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index ce03c9c52..cbc3b47ad 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -4325,8 +4325,9 @@ class ColumnClause(Immutable, ColumnElement): self.is_literal = is_literal def _compare_name_for_result(self, other): + # TODO: this still isn't 100% correct if self.table is not None and hasattr(other, 'proxy_set'): - return other.proxy_set.intersection(self.proxy_set) + return self.proxy_set.intersection(other.proxy_set) else: return super(ColumnClause, self).\ _compare_name_for_result(other) diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 2a5ab0627..b70358ffd 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -445,6 +445,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'addresses.user_id = :user_id_1 ORDER BY ' 'addresses.id, address_types.id') + def test_returning_insert(self): + t1 = table('t1', column('c1'), column('c2'), column('c3')) + self.assert_compile( + t1.insert().values(c1=1).returning(t1.c.c2, t1.c.c3), + "INSERT INTO t1 (c1) VALUES (:c1) RETURNING " + "t1.c2, t1.c3 INTO :ret_0, :ret_1" + ) + def test_compound(self): t1 = table('t1', column('c1'), column('c2'), column('c3')) t2 = table('t2', column('c1'), column('c2'), column('c3')) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index f54162bda..16a2f254d 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -954,13 +954,36 @@ class QueryTest(fixtures.TestBase): def test_ambiguous_column(self): users.insert().execute(user_id=1, user_name='john') - r = users.outerjoin(addresses).select().execute().first() + result = users.outerjoin(addresses).select().execute() + r = result.first() + assert_raises_message( exc.InvalidRequestError, "Ambiguous column name", lambda: r['user_id'] ) + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: r[users.c.user_id] + ) + + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: r[addresses.c.user_id] + ) + + # try to trick it - fake_table isn't in the result! + # we get the correct error + fake_table = Table('fake', MetaData(), Column('user_id', Integer)) + assert_raises_message( + exc.InvalidRequestError, + "Could not locate column in row for column 'fake.user_id'", + lambda: r[fake_table.c.user_id] + ) + r = util.pickle.loads(util.pickle.dumps(r)) assert_raises_message( exc.InvalidRequestError, @@ -978,6 +1001,41 @@ class QueryTest(fixtures.TestBase): lambda: r['user_id'] ) + def test_ambiguous_column_by_col(self): + users.insert().execute(user_id=1, user_name='john') + ua = users.alias() + u2 = users.alias() + result = select([users.c.user_id, ua.c.user_id]).execute() + row = result.first() + + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: row[users.c.user_id] + ) + + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: row[ua.c.user_id] + ) + + # Unfortunately, this fails - + # we'd like + # "Could not locate column in row" + # to be raised here, but the check for + # "common column" in _compare_name_for_result() + # has other requirements to be more liberal. + # Ultimately the + # expression system would need a way to determine + # if given two columns in a "proxy" relationship, if they + # refer to a different parent table + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: row[u2.c.user_id] + ) + @testing.requires.subqueries def test_column_label_targeting(self): users.insert().execute(user_id=7, user_name='ed') @@ -1365,10 +1423,14 @@ class KeyTargetingTest(fixtures.TablesTest): keyed3 = self.tables.keyed3 row = testing.db.execute(select([keyed1, keyed3])).first() - assert 'b' not in row eq_(row.q, "c1") assert_raises_message( exc.InvalidRequestError, + "Ambiguous column name 'b'", + getattr, row, "b" + ) + assert_raises_message( + exc.InvalidRequestError, "Ambiguous column name 'a'", getattr, row, "a" ) |