summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES4
-rw-r--r--lib/sqlalchemy/ansisql.py4
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py4
-rw-r--r--lib/sqlalchemy/schema.py18
-rw-r--r--lib/sqlalchemy/sql.py88
-rw-r--r--lib/sqlalchemy/types.py5
-rw-r--r--test/ext/activemapper.py1
-rw-r--r--test/sql/query.py94
8 files changed, 173 insertions, 45 deletions
diff --git a/CHANGES b/CHANGES
index 98404cdd8..b7504b967 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,7 @@
+0.3.2
+- added keywords for EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL
+[ticket:247]
+
0.3.1
- Engine/Pool:
- some new Pool utility classes, updated docs
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index a0ce64905..2e0fe6e34 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -302,7 +302,7 @@ class ANSICompiler(sql.Compiled):
self.select_stack.append(select)
for c in select._raw_columns:
# TODO: make this polymorphic?
- if isinstance(c, sql.Select) and c._scalar:
+ if isinstance(c, sql.Select) and c.is_scalar:
c.accept_visitor(self)
inner_columns[self.get_str(c)] = c
continue
@@ -319,7 +319,7 @@ class ANSICompiler(sql.Compiled):
inner_columns[co._label] = l
# TODO: figure this out, a ColumnClause with a select as a parent
# is different from any other kind of parent
- elif select.issubquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
+ elif select.is_subquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
# SQLite doesnt like selecting from a subquery where the column
# names look like table.colname, so add a label synonomous with
# the column name
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index d83ecfb59..d3081bc23 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -324,9 +324,7 @@ def _selectable_name(selectable):
if isinstance(selectable, sql.Alias):
return _selectable_name(selectable.selectable)
elif isinstance(selectable, sql.Select):
- # sometimes a Select has itself in _froms
- nonrecursive_froms = [s for s in selectable._froms if s is not selectable]
- return ''.join([_selectable_name(s) for s in nonrecursive_froms])
+ return ''.join([_selectable_name(s) for s in selectable.froms])
elif isinstance(selectable, schema.Table):
return selectable.name.capitalize()
else:
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index fb6894f0b..d9a7684e7 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -241,7 +241,7 @@ class Table(SchemaItem, sql.TableClause):
[repr(self.name)] + [repr(self.metadata)] +
[repr(x) for x in self.columns] +
["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]
- , ',\n')
+ , ',')
def __str__(self):
return _get_table_key(self.name, self.schema)
@@ -401,10 +401,22 @@ class Column(SchemaItem, sql._ColumnClause):
fk._set_parent(self)
def __repr__(self):
- return "Column(%s)" % string.join(
+ kwarg = []
+ if self.key != self.name:
+ kwarg.append('key')
+ if self._primary_key:
+ kwarg.append('primary_key')
+ if not self.nullable:
+ kwarg.append('nullable')
+ if self.onupdate:
+ kwarg.append('onupdate')
+ if self.default:
+ kwarg.append('default')
+ return "Column(%s)" % string.join(
[repr(self.name)] + [repr(self.type)] +
[repr(x) for x in self.foreign_keys if x is not None] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'default', 'onupdate']]
+ [repr(x) for x in self.constraints] +
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
, ',')
def _get_parent(self):
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index ce33810a5..b5faf37fe 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -9,7 +9,8 @@ from sqlalchemy import util, exceptions
from sqlalchemy import types as sqltypes
import string, re, random, sets
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable']
+
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'except_', 'except_all', 'intersect', 'intersect_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable']
def desc(column):
"""return a descending ORDER BY clause element, e.g.:
@@ -181,6 +182,18 @@ def union(*selects, **params):
def union_all(*selects, **params):
return _compound_select('UNION ALL', *selects, **params)
+def except_(*selects, **params):
+ return _compound_select('EXCEPT', *selects, **params)
+
+def except_all(*selects, **params):
+ return _compound_select('EXCEPT ALL', *selects, **params)
+
+def intersect(*selects, **params):
+ return _compound_select('INTERSECT', *selects, **params)
+
+def intersect_all(*selects, **params):
+ return _compound_select('INTERSECT ALL', *selects, **params)
+
def alias(*args, **params):
return Alias(*args, **params)
@@ -1357,7 +1370,7 @@ class _SelectBaseMixin(object):
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
def _get_from_objects(self):
- if self.is_where or self._scalar:
+ if self.is_where or self.is_scalar:
return []
else:
return [self]
@@ -1366,19 +1379,27 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
_SelectBaseMixin.__init__(self)
self.keyword = keyword
- self.selects = selects
self.use_labels = kwargs.pop('use_labels', False)
self.parens = kwargs.pop('parens', False)
self.correlate = kwargs.pop('correlate', False)
self.for_update = kwargs.pop('for_update', False)
self.nowait = kwargs.pop('nowait', False)
- self.limit = kwargs.get('limit', None)
- self.offset = kwargs.get('offset', None)
- for s in self.selects:
+ self.limit = kwargs.pop('limit', None)
+ self.offset = kwargs.pop('offset', None)
+ self.is_compound = True
+ self.is_where = False
+ self.is_scalar = False
+
+ self.selects = selects
+
+ for s in selects:
s.group_by(None)
s.order_by(None)
- self.group_by(*kwargs.get('group_by', [None]))
- self.order_by(*kwargs.get('order_by', [None]))
+
+ self.group_by(*kwargs.pop('group_by', [None]))
+ self.order_by(*kwargs.pop('order_by', [None]))
+ if len(kwargs):
+ raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
self._col_map = {}
name = property(lambda s:s.keyword + " statement")
@@ -1420,9 +1441,9 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
class Select(_SelectBaseMixin, FromClause):
"""represents a SELECT statement, with appendable clauses, as well as
the ability to execute itself and return a result set."""
- def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, nowait=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
+ def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
_SelectBaseMixin.__init__(self)
- self._froms = util.OrderedDict()
+ self.__froms = util.OrderedDict()
self.use_labels = use_labels
self.whereclause = None
self.having = None
@@ -1430,31 +1451,29 @@ class Select(_SelectBaseMixin, FromClause):
self.limit = limit
self.offset = offset
self.for_update = for_update
- self.nowait = nowait
+ self.is_compound = False
# indicates that this select statement should not expand its columns
# into the column clause of an enclosing select, and should instead
# act like a single scalar column
- self._scalar = scalar
+ self.is_scalar = scalar
# indicates if this select statement, as a subquery, should correlate
# its FROM clause to that of an enclosing select statement
self.correlate = correlate
# indicates if this select statement is a subquery inside another query
- self.issubquery = False
+ self.is_subquery = False
# indicates if this select statement is a subquery as a criterion
# inside of a WHERE clause
self.is_where = False
self.distinct = distinct
- self._text = None
self._raw_columns = []
self._correlated = None
- self._correlator = Select._CorrelatedVisitor(self, False)
- self._wherecorrelator = Select._CorrelatedVisitor(self, True)
-
+ self.__correlator = Select._CorrelatedVisitor(self, False)
+ self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
self.group_by(*(group_by or [None]))
self.order_by(*(order_by or [None]))
@@ -1471,10 +1490,6 @@ class Select(_SelectBaseMixin, FromClause):
for f in from_obj:
self.append_from(f)
- def _foo(self):
- raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables"
- name = property(lambda s:s._foo()) #"SELECT statement")
-
class _CorrelatedVisitor(ClauseVisitor):
"""visits a clause, locates any Select clauses, and tells them that they should
correlate their FROM list to that of their parent."""
@@ -1491,12 +1506,12 @@ class Select(_SelectBaseMixin, FromClause):
if select is self.select:
return
select.is_where = self.is_where
- select.issubquery = True
+ select.is_subquery = True
select.parens = True
if not select.correlate:
return
if getattr(select, '_correlated', None) is None:
- select._correlated = self.select._froms
+ select._correlated = self.select._Select__froms
def append_column(self, column):
if _is_literal(column):
@@ -1506,12 +1521,13 @@ class Select(_SelectBaseMixin, FromClause):
# if the column is a Select statement itself,
# accept visitor
- column.accept_visitor(self._correlator)
+ column.accept_visitor(self.__correlator)
# visit the FROM objects of the column looking for more Selects
for f in column._get_from_objects():
- f.accept_visitor(self._correlator)
- column._process_from_dict(self._froms, False)
+ f.accept_visitor(self.__correlator)
+ column._process_from_dict(self.__froms, False)
+
def _exportable_columns(self):
return self._raw_columns
def _proxy_column(self, column):
@@ -1526,23 +1542,23 @@ class Select(_SelectBaseMixin, FromClause):
def _append_condition(self, attribute, condition):
if type(condition) == str:
condition = _TextClause(condition)
- condition.accept_visitor(self._wherecorrelator)
- condition._process_from_dict(self._froms, False)
+ condition.accept_visitor(self.__wherecorrelator)
+ condition._process_from_dict(self.__froms, False)
if getattr(self, attribute) is not None:
setattr(self, attribute, and_(getattr(self, attribute), condition))
else:
setattr(self, attribute, condition)
def clear_from(self, from_obj):
- self._froms[from_obj] = FromClause()
+ self.__froms[from_obj] = FromClause()
def append_from(self, fromclause):
if type(fromclause) == str:
fromclause = _TextClause(fromclause)
- fromclause.accept_visitor(self._correlator)
- fromclause._process_from_dict(self._froms, True)
+ fromclause.accept_visitor(self.__correlator)
+ fromclause._process_from_dict(self.__froms, True)
def _locate_oid_column(self):
- for f in self._froms.values():
+ for f in self.__froms.values():
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
# which includes textual columns.
@@ -1553,8 +1569,8 @@ class Select(_SelectBaseMixin, FromClause):
else:
return None
def _get_froms(self):
- return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
- froms = property(lambda s: s._get_froms())
+ return [f for f in self.__froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
+ froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""")
def accept_visitor(self, visitor):
# TODO: add contextual visit_ methods
@@ -1581,7 +1597,7 @@ class Select(_SelectBaseMixin, FromClause):
if self._engine is not None:
return self._engine
- for f in self._froms.values():
+ for f in self.__froms.values():
if f is self:
continue
e = f.engine
@@ -1657,7 +1673,7 @@ class _Update(_UpdateBase):
visitor.visit_update(self)
class _Delete(_UpdateBase):
- def __init__(self, table, whereclause, **params):
+ def __init__(self, table, whereclause):
self.table = table
self.whereclause = whereclause
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index d7e9d8ce6..08f2c9843 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -12,6 +12,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
]
from sqlalchemy import util, exceptions
+import inspect
try:
import cPickle as pickle
except:
@@ -37,7 +38,9 @@ class AbstractType(object):
this can be useful for calling setinputsizes(), for example."""
return None
-
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
+
class TypeEngine(AbstractType):
def __init__(self, *args, **params):
pass
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
index e6ce06390..f87cbb46e 100644
--- a/test/ext/activemapper.py
+++ b/test/ext/activemapper.py
@@ -10,6 +10,7 @@ import sqlalchemy.ext.activemapper as activemapper
class testcase(testbase.PersistTest):
def setUpAll(self):
+ sqlalchemy.clear_mappers()
global Person, Preferences, Address
class Person(ActiveMapper):
diff --git a/test/sql/query.py b/test/sql/query.py
index 96ad6ec8b..d88b2bf83 100644
--- a/test/sql/query.py
+++ b/test/sql/query.py
@@ -261,6 +261,100 @@ class QueryTest(PersistTest):
r.close()
finally:
shadowed.drop()
+
+class CompoundTest(PersistTest):
+ """test compound statements like UNION, INTERSECT, particularly their ability to nest on
+ different databases."""
+ def setUpAll(self):
+ global metadata, t1, t2, t3
+ metadata = BoundMetaData(testbase.db)
+ t1 = Table('t1', metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30)),
+ Column('col3', String(40)),
+ Column('col4', String(30))
+ )
+ t2 = Table('t2', metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30)),
+ Column('col3', String(40)),
+ Column('col4', String(30)))
+ t3 = Table('t3', metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30)),
+ Column('col3', String(40)),
+ Column('col4', String(30)))
+ metadata.create_all()
+
+ t1.insert().execute([
+ dict(col2="t1col2r1", col3="aaa", col4="aaa"),
+ dict(col2="t1col2r2", col3="bbb", col4="bbb"),
+ dict(col2="t1col2r3", col3="ccc", col4="ccc"),
+ ])
+ t2.insert().execute([
+ dict(col2="t2col2r1", col3="aaa", col4="bbb"),
+ dict(col2="t2col2r2", col3="bbb", col4="ccc"),
+ dict(col2="t2col2r3", col3="ccc", col4="aaa"),
+ ])
+ t3.insert().execute([
+ dict(col2="t3col2r1", col3="aaa", col4="ccc"),
+ dict(col2="t3col2r2", col3="bbb", col4="aaa"),
+ dict(col2="t3col2r3", col3="ccc", col4="bbb"),
+ ])
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def test_union(self):
+ (s1, s2) = (
+ select([t1.c.col3, t1.c.col4], t1.c.col2.in_("t1col2r1", "t1col2r2")),
+ select([t2.c.col3, t2.c.col4], t2.c.col2.in_("t2col2r2", "t2col2r3"))
+ )
+ u = union(s1, s2)
+ assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+ assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
+ @testbase.unsupported('mysql')
+ def test_intersect(self):
+ i = intersect(
+ select([t2.c.col3, t2.c.col4]),
+ select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3)
+ )
+ assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+ assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
+ @testbase.unsupported('mysql')
+ def test_except_style1(self):
+ e = except_(union(
+ select([t1.c.col3, t1.c.col4]),
+ select([t2.c.col3, t2.c.col4]),
+ select([t3.c.col3, t3.c.col4]),
+ parens=True), select([t2.c.col3, t2.c.col4]))
+ assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+ @testbase.unsupported('mysql')
+ def test_except_style2(self):
+ e = except_(union(
+ select([t1.c.col3, t1.c.col4]),
+ select([t2.c.col3, t2.c.col4]),
+ select([t3.c.col3, t3.c.col4]),
+ ).alias('foo').select(), select([t2.c.col3, t2.c.col4]))
+ assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+ assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+ @testbase.unsupported('mysql')
+ def test_composite(self):
+ u = intersect(
+ select([t2.c.col3, t2.c.col4]),
+ union(
+ select([t1.c.col3, t1.c.col4]),
+ select([t2.c.col3, t2.c.col4]),
+ select([t3.c.col3, t3.c.col4]),
+ ).alias('foo').select()
+ )
+ assert u.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+ assert u.alias('foo').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
if __name__ == "__main__":
testbase.main()