diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 43 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 8 |
4 files changed, 41 insertions, 12 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 8aa293a6c..dbc0c5dc9 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -42,6 +42,7 @@ from sqlalchemy.sql import ( select, subquery, text, + tuple_, union, union_all, update, diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 0b347ca38..aa18eac17 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -46,6 +46,7 @@ from sqlalchemy.sql.expression import ( subquery, table, text, + tuple_, union, union_all, update, diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index cf5d98d8f..6d74fec16 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -47,7 +47,7 @@ __all__ = [ 'modifier', 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label', 'literal', 'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select', - 'subquery', 'table', 'text', 'union', 'union_all', 'update', ] + 'subquery', 'table', 'text', 'tuple_', 'union', 'union_all', 'update', ] PARSE_AUTOCOMMIT = util._symbol('PARSE_AUTOCOMMIT') @@ -662,6 +662,18 @@ def literal(value, type_=None): """ return _BindParamClause(None, value, type_=type_, unique=True) +def tuple_(*expr): + """Return a SQL tuple. + + Main usage is to produce a composite IN construct:: + + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) + + """ + return _Tuple(*expr) + def label(name, obj): """Return a :class:`_Label` object for the given :class:`ColumnElement`. @@ -955,6 +967,13 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _type_from_args(args): + for a in args: + if not isinstance(a.type, sqltypes.NullType): + return a.type + else: + return sqltypes.NullType + def _no_literals(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() @@ -1500,7 +1519,8 @@ class _CompareMixin(ColumnOperators): if not _is_literal(o): if not isinstance( o, _CompareMixin): raise exc.InvalidRequestError( - "in() function accepts either a list of non-selectable values, or a selectable: %r" % o) + "in() function accepts either a list of non-selectable values, " + "or a selectable: %r" % o) else: o = self._bind_param(o) args.append(o) @@ -2360,6 +2380,22 @@ class BooleanClauseList(ClauseList, ColumnElement): def _select_iterable(self): return (self, ) +class _Tuple(ClauseList, ColumnElement): + + def __init__(self, *clauses, **kw): + super(_Tuple, self).__init__(*clauses, **kw) + self.type = _type_from_args(clauses) + + @property + def _select_iterable(self): + return (self, ) + + def _bind_param(self, obj): + return _Tuple(*[ + _BindParamClause(None, o, type_=self.type, unique=True) + for o in obj + ]).self_group() + class _Case(ColumnElement): __visit_name__ = 'case' @@ -3318,9 +3354,6 @@ class _ScalarSelect(_Grouping): def __init__(self, element): self.element = element cols = list(element.c) - if len(cols) != 1: - raise exc.InvalidRequestError("Scalar select can only be created " - "from a Select object that has exactly one column expression.") self.type = cols[0].type @property diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index c6cb938d4..212f81ada 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,6 +1,6 @@ from sqlalchemy import types as sqltypes from sqlalchemy.sql.expression import ( - ClauseList, Function, _literal_as_binds, text + ClauseList, Function, _literal_as_binds, text, _type_from_args ) from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import VisitableType @@ -102,9 +102,3 @@ class sysdate(AnsiFunction): class user(AnsiFunction): __return_type__ = sqltypes.String -def _type_from_args(args): - for a in args: - if not isinstance(a.type, sqltypes.NullType): - return a.type - else: - return sqltypes.NullType |