summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 16:46:11 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 16:46:11 +0000
commit8db97dad9845b8d405412abbf713d2d22538b9cc (patch)
treed8531cb3644340707b4248e2721968359e232740
parent7252e3d879483cc14be5d1b95004843e69e35aab (diff)
downloadsqlalchemy-8db97dad9845b8d405412abbf713d2d22538b9cc.tar.gz
fixed glitch in Select visit traversal, fixes #693
-rw-r--r--lib/sqlalchemy/orm/mapper.py3
-rw-r--r--lib/sqlalchemy/sql.py2
-rw-r--r--test/orm/alltests.py1
-rw-r--r--test/orm/selectable.py49
-rw-r--r--test/sql/generative.py9
5 files changed, 58 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 76cc41289..92b186012 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -407,6 +407,9 @@ class Mapper(object):
# may be a join or other construct
self.tables = sqlutil.TableFinder(self.mapped_table)
+ if not len(self.tables):
+ raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
+
# determine primary key columns
self.pks_by_table = {}
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 01588e92d..ff92f0b43 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -3146,7 +3146,7 @@ class Select(_SelectBaseMixin, FromClause):
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.columns) or []) + \
- list(self._froms) + \
+ list(self.locate_all_froms()) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
def _recorrelate_froms(self, froms):
diff --git a/test/orm/alltests.py b/test/orm/alltests.py
index 4f8f4b6b7..9fcea8859 100644
--- a/test/orm/alltests.py
+++ b/test/orm/alltests.py
@@ -11,6 +11,7 @@ def suite():
'orm.lazy_relations',
'orm.eager_relations',
'orm.mapper',
+ 'orm.selectable',
'orm.collection',
'orm.generative',
'orm.lazytest1',
diff --git a/test/orm/selectable.py b/test/orm/selectable.py
new file mode 100644
index 000000000..920cd9d8f
--- /dev/null
+++ b/test/orm/selectable.py
@@ -0,0 +1,49 @@
+"""all tests involving generic mapping to Select statements"""
+
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from fixtures import *
+from query import QueryTest
+
+class SelectableNoFromsTest(ORMTest):
+ def define_tables(self, metadata):
+ global common_table
+ common_table = Table('common', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer),
+ Column('extra', String(45)),
+ )
+
+ def test_no_tables(self):
+ class Subset(object):
+ pass
+ selectable = select(["x", "y", "z"]).alias('foo')
+ try:
+ mapper(Subset, selectable)
+ compile_mappers()
+ assert False
+ except exceptions.InvalidRequestError, e:
+ assert str(e) == "Could not find any Table objects in mapped table 'SELECT x, y, z'", str(e)
+
+ def test_basic(self):
+ class Subset(Base):
+ pass
+
+ subset_select = select([common_table.c.id, common_table.c.data]).alias('subset')
+ subset_mapper = mapper(Subset, subset_select)
+
+ sess = create_session(bind=testbase.db)
+ l = Subset()
+ l.data = 1
+ sess.save(l)
+ sess.flush()
+ sess.clear()
+
+ assert [Subset(data=1)] == sess.query(Subset).all()
+
+ # TODO: more tests mapping to selects
+
+if __name__ == '__main__':
+ testbase.main() \ No newline at end of file
diff --git a/test/sql/generative.py b/test/sql/generative.py
index 357a66fcd..80a18d497 100644
--- a/test/sql/generative.py
+++ b/test/sql/generative.py
@@ -166,10 +166,9 @@ class ClauseTest(selecttests.SQLTest):
assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
def test_select(self):
- s = t1.select()
- s2 = select([s])
+ s2 = select([t1])
s2_assert = str(s2)
- s3_assert = str(select([t1.select()], t1.c.col2==7))
+ s3_assert = str(select([t1], t1.c.col2==7))
class Vis(ClauseVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col2==7)
@@ -183,7 +182,7 @@ class ClauseTest(selecttests.SQLTest):
print "------------------"
- s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+ s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9)))
class Vis(ClauseVisitor):
def visit_select(self, select):
select.append_whereclause(t1.c.col3==9)
@@ -194,7 +193,7 @@ class ClauseTest(selecttests.SQLTest):
assert str(s3) == s3_assert
print "------------------"
- s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+ s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9)))
class Vis(ClauseVisitor):
def visit_binary(self, binary):
if binary.left is t1.c.col3: