diff options
-rw-r--r-- | CHANGES | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 7 | ||||
-rw-r--r-- | test/orm/query.py | 19 |
3 files changed, 30 insertions, 2 deletions
@@ -73,7 +73,11 @@ CHANGES - query doesn't throw an error if you use distinct() and an order_by() containing UnaryExpressions (or other) together [ticket:848] - + + - fixed error where Query.add_column() would not accept a class-bound + attribute as an argument; Query also raises an error if an invalid + argument was sent to add_column() (at instances() time) [ticket:858] + - The session API has been solidified: - It's an error to session.save() an object which is already diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 753e735d1..300d4a549 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -240,6 +240,10 @@ class Query(object): q = self._clone() + # duck type to get a ClauseElement + if hasattr(column, 'clause_element'): + column = column.clause_element() + # alias non-labeled column elements. if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): column = column.label(None) @@ -682,6 +686,9 @@ class Query(object): res.append(row_adapter(row)[m]) process.append((proc, res)) y(m) + else: + raise exceptions.InvalidRequestError("Invalid column expression '%r'" % m) + result = [] else: result = util.UniqueAppender([]) diff --git a/test/orm/query.py b/test/orm/query.py index a8790d2cf..e736d49e4 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -721,8 +721,24 @@ class InstancesTest(QueryTest): assert l == [(user8, address3)] def test_multi_columns(self): + sess = create_session() + + expected = [(u, u.name) for u in sess.query(User).all()] + + for add_col in (User.name, users.c.name, User.c.name): + assert sess.query(User).add_column(add_col).all() == expected + + try: + sess.query(User).add_column(object()).all() + assert False + except exceptions.InvalidRequestError, e: + assert "Invalid column expression" in str(e) + + + def test_multi_columns_2(self): """test aliased/nonalised joins with the usage of add_column()""" sess = create_session() + (user7, user8, user9, user10) = sess.query(User).all() expected = [(user7, 1), (user8, 3), @@ -740,7 +756,8 @@ class InstancesTest(QueryTest): q = sess.query(User) l = q.add_column("count").from_statement(s).all() assert l == expected - + + def test_two_columns(self): sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() |