diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 16:46:11 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-07-27 16:46:11 +0000 |
commit | 8db97dad9845b8d405412abbf713d2d22538b9cc (patch) | |
tree | d8531cb3644340707b4248e2721968359e232740 | |
parent | 7252e3d879483cc14be5d1b95004843e69e35aab (diff) | |
download | sqlalchemy-8db97dad9845b8d405412abbf713d2d22538b9cc.tar.gz |
fixed glitch in Select visit traversal, fixes #693
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 2 | ||||
-rw-r--r-- | test/orm/alltests.py | 1 | ||||
-rw-r--r-- | test/orm/selectable.py | 49 | ||||
-rw-r--r-- | test/sql/generative.py | 9 |
5 files changed, 58 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 76cc41289..92b186012 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -407,6 +407,9 @@ class Mapper(object): # may be a join or other construct self.tables = sqlutil.TableFinder(self.mapped_table) + if not len(self.tables): + raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) + # determine primary key columns self.pks_by_table = {} diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 01588e92d..ff92f0b43 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -3146,7 +3146,7 @@ class Select(_SelectBaseMixin, FromClause): def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.columns) or []) + \ - list(self._froms) + \ + list(self.locate_all_froms()) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] def _recorrelate_froms(self, froms): diff --git a/test/orm/alltests.py b/test/orm/alltests.py index 4f8f4b6b7..9fcea8859 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -11,6 +11,7 @@ def suite(): 'orm.lazy_relations', 'orm.eager_relations', 'orm.mapper', + 'orm.selectable', 'orm.collection', 'orm.generative', 'orm.lazytest1', diff --git a/test/orm/selectable.py b/test/orm/selectable.py new file mode 100644 index 000000000..920cd9d8f --- /dev/null +++ b/test/orm/selectable.py @@ -0,0 +1,49 @@ +"""all tests involving generic mapping to Select statements""" + +import testbase +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * +from fixtures import * +from query import QueryTest + +class SelectableNoFromsTest(ORMTest): + def define_tables(self, metadata): + global common_table + common_table = Table('common', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + Column('extra', String(45)), + ) + + def test_no_tables(self): + class Subset(object): + pass + selectable = select(["x", "y", "z"]).alias('foo') + try: + mapper(Subset, selectable) + compile_mappers() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Could not find any Table objects in mapped table 'SELECT x, y, z'", str(e) + + def test_basic(self): + class Subset(Base): + pass + + subset_select = select([common_table.c.id, common_table.c.data]).alias('subset') + subset_mapper = mapper(Subset, subset_select) + + sess = create_session(bind=testbase.db) + l = Subset() + l.data = 1 + sess.save(l) + sess.flush() + sess.clear() + + assert [Subset(data=1)] == sess.query(Subset).all() + + # TODO: more tests mapping to selects + +if __name__ == '__main__': + testbase.main()
\ No newline at end of file diff --git a/test/sql/generative.py b/test/sql/generative.py index 357a66fcd..80a18d497 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -166,10 +166,9 @@ class ClauseTest(selecttests.SQLTest): assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3)) def test_select(self): - s = t1.select() - s2 = select([s]) + s2 = select([t1]) s2_assert = str(s2) - s3_assert = str(select([t1.select()], t1.c.col2==7)) + s3_assert = str(select([t1], t1.c.col2==7)) class Vis(ClauseVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col2==7) @@ -183,7 +182,7 @@ class ClauseTest(selecttests.SQLTest): print "------------------" - s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9))) + s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9))) class Vis(ClauseVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col3==9) @@ -194,7 +193,7 @@ class ClauseTest(selecttests.SQLTest): assert str(s3) == s3_assert print "------------------" - s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9))) + s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9))) class Vis(ClauseVisitor): def visit_binary(self, binary): if binary.left is t1.c.col3: |