diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/base.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 580 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 31 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 27 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 146 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 52 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 678 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/operators.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/roles.py | 157 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/schema.py | 45 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 304 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 3 |
17 files changed, 1330 insertions, 798 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index fb5639ef3..00cafd8ff 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -94,6 +94,22 @@ def __go(lcls): from .elements import ClauseList # noqa from .selectable import AnnotatedFromClause # noqa + from . import base + from . import coercions + from . import elements + from . import selectable + from . import schema + from . import sqltypes + from . import type_api + + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.schema = schema + coercions.selectable = selectable + coercions.sqltypes = sqltypes + _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index c5e5fd8a1..9df0c932f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -17,6 +17,9 @@ from .visitors import ClauseVisitor from .. import exc from .. import util +coercions = None # type: types.ModuleType +elements = None # type: types.ModuleType +type_api = None # type: types.ModuleType PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") NO_ARG = util.symbol("NO_ARG") @@ -589,8 +592,7 @@ class ColumnCollection(util.OrderedProperties): __hash__ = None - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in getattr(other, "_all_columns", other): for local in self._all_columns: @@ -636,8 +638,7 @@ class ColumnSet(util.ordered_column_set): def __add__(self, other): return list(self) + list(other) - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in other: for local in self: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py new file mode 100644 index 000000000..7c7222f9f --- /dev/null +++ b/lib/sqlalchemy/sql/coercions.py @@ -0,0 +1,580 @@ +# sql/coercions.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import numbers +import re + +from . import operators +from . import roles +from . import visitors +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util import collections_abc + +elements = None # type: types.ModuleType +schema = None # type: types.ModuleType +selectable = None # type: types.ModuleType +sqltypes = None # type: types.ModuleType + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + return not isinstance( + element, (Visitable, schema.SchemaEventTarget) + ) and not hasattr(element, "__clause_element__") + + +def _document_text_coercion(paramname, meth_rst, param_rst): + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def expect(role, element, **kw): + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)): + resolved = impl._resolve_for_clause_element(element, **kw) + else: + resolved = element + + if issubclass(resolved.__class__, impl._role_class): + if impl._post_coercion: + resolved = impl._post_coercion(resolved, **kw) + return resolved + else: + return impl._implicit_coercions(element, resolved, **kw) + + +def expect_as_key(role, element, **kw): + kw["as_key"] = True + return expect(role, element, **kw) + + +def expect_col_expression_collection(role, expressions): + for expr in expressions: + strname = None + column = None + + resolved = expect(role, expr) + if isinstance(resolved, util.string_types): + strname = resolved = expr + else: + cols = [] + visitors.traverse(resolved, {}, {"column": cols.append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + yield resolved, column, strname, add_element + + +class RoleImpl(object): + __slots__ = ("_role_class", "name", "_use_inspection") + + def _literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion = None + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _resolve_for_clause_element(self, element, argname=None, **kw): + literal_coercion = self._literal_coercion + original_element = element + is_clause_element = False + + while hasattr(element, "__clause_element__") and not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + element = element.__clause_element__() + is_clause_element = True + + if not is_clause_element: + if self._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + self._raise_for_expected(original_element, argname) + + return self._literal_coercion(element, argname=argname, **kw) + else: + return element + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + self._raise_for_expected(element, argname) + + def _raise_for_expected(self, element, argname=None): + if argname: + raise exc.ArgumentError( + "%s expected for argument %r; got %r." + % (self.name, argname, element) + ) + else: + raise exc.ArgumentError( + "%s expected, got %r." % (self.name, element) + ) + + +class _StringOnly(object): + def _resolve_for_clause_element(self, element, argname=None, **kw): + return self._literal_coercion(element, **kw) + + +class _ReturnsStringKey(object): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return original_element + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(object): + def _warn_for_scalar_subquery_coercion(self): + util.warn_deprecated( + "coercing SELECT object to scalar subquery in a " + "column-expression context is deprecated in version 1.4; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery. This automatic coercion will be removed in a " + "future release." + ) + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_select_statement: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif ( + resolved._is_from_clause + and isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + self._warn_for_scalar_subquery_coercion() + return resolved.original.scalar_subquery() + else: + self._raise_for_expected(original_element, argname) + + +def _no_text_coercion( + element, argname=None, exc_cls=exc.ArgumentError, extra=None +): + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ) + + +class _NoTextCoercion(object): + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(object): + _coerce_consts = False + _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class ExpressionElementImpl( + _ColumnCoercions, RoleImpl, roles.ExpressionElementRole +): + def _literal_coercion(self, element, name=None, type_=None, argname=None): + if element is None: + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True + ) + except exc.ArgumentError: + self._raise_for_expected(element) + + +class BinaryElementImpl( + ExpressionElementImpl, RoleImpl, roles.BinaryElementRole +): + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError: + self._raise_for_expected(element) + + def _post_coercion(self, resolved, expr, **kw): + if ( + isinstance(resolved, elements.BindParameter) + and resolved.type._isnull + ): + resolved = resolved._clone() + resolved.type = expr.type + return resolved + + +class InElementImpl(RoleImpl, roles.InElementRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, util.string_types + ): + args = [] + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + elif o is None: + o = elements.Null() + else: + o = expr._bind_param(operator, o) + args.append(o) + + return elements.ClauseList(*args) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_statement: + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + if len(element.clauses) == 0: + op, negate_op = ( + (operators.empty_in_op, operators.empty_notin_op) + if operator is operators.in_op + else (operators.empty_notin_op, operators.empty_in_op) + ) + return element.self_group(against=op)._annotate( + dict(in_ops=(op, negate_op)) + ) + else: + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter) and element.expanding: + + if isinstance(expr, elements.Tuple): + element = element._with_expanding_in_types( + [elem.type for elem in expr] + ) + return element + else: + return element + + +class WhereHavingImpl( + _CoerceLiterals, _ColumnCoercions, RoleImpl, roles.WhereHavingRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl( + _CoerceLiterals, RoleImpl, roles.StatementOptionRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl, roles.ColumnArgumentRole): + pass + + +class ColumnArgumentOrKeyImpl( + _ReturnsStringKey, RoleImpl, roles.ColumnArgumentOrKeyRole +): + pass + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl, roles.OrderByRole): + def _post_coercion(self, resolved): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl, roles.DMLColumnRole): + def _post_coercion(self, element, as_key=False): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl, roles.ConstExprRole): + def _literal_coercion(self, element, argname=None): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl, roles.TruncatedLabelRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return resolved + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, argname=None): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_CoerceLiterals, RoleImpl, roles.DDLExpressionRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class DDLConstraintColumnImpl( + _ReturnsStringKey, RoleImpl, roles.DDLConstraintColumnRole +): + pass + + +class LimitOffsetImpl(RoleImpl, roles.LimitOffsetRole): + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if resolved is None: + return None + else: + self._raise_for_expected(element, argname) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl( + ExpressionElementImpl, roles.LabeledColumnExprRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super(LabeledColumnExprImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(original_element, argname) + + +class ColumnsClauseImpl(_CoerceLiterals, RoleImpl, roles.ColumnsClauseRole): + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +class ReturnsRowsImpl(RoleImpl, roles.ReturnsRowsRole): + pass + + +class StatementImpl(_NoTextCoercion, RoleImpl, roles.StatementRole): + pass + + +class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl, roles.StatementRole): + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class SelectStatementImpl( + _NoTextCoercion, RoleImpl, roles.SelectStatementRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(original_element, argname) + + +class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): + pass + + +class FromClauseImpl(_NoTextCoercion, RoleImpl, roles.FromClauseRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl, roles.DMLSelectRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + +class CompoundElementImpl( + _NoTextCoercion, RoleImpl, roles.CompoundElementRole +): + def _implicit_coercions(self, original_element, resolved, argname=None): + if resolved._is_from_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c7fe3dc50..8080d2cc6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,14 +27,15 @@ import contextlib import itertools import re +from . import coercions from . import crud from . import elements from . import functions from . import operators +from . import roles from . import schema from . import selectable from . import sqltypes -from . import visitors from .. import exc from .. import util @@ -400,7 +401,9 @@ class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): return type_._compiler_dispatch(self, **kw) -class _CompileLabel(visitors.Visitable): +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel(elements.ColumnElement): """lightweight label object which acts as an expression.Label.""" @@ -766,10 +769,10 @@ class SQLCompiler(Compiled): else: col = with_cols[element.element] except KeyError: - elements._no_text_coercion( + coercions._no_text_coercion( element.element, - exc.CompileError, - "Can't resolve label reference for ORDER BY / GROUP BY.", + extra="Can't resolve label reference for ORDER BY / GROUP BY.", + exc_cls=exc.CompileError, ) else: kwargs["render_label_as_label"] = col @@ -1635,7 +1638,6 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_name[cte_name] = cte - # look for embedded DML ctes and propagate autocommit if ( "autocommit" in cte.element._execution_options and "autocommit" not in self.execution_options @@ -1656,10 +1658,10 @@ class SQLCompiler(Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.original, selectable.Select): - col_source = cte.original - elif isinstance(cte.original, selectable.CompoundSelect): - col_source = cte.original.selects[0] + if isinstance(cte.element, selectable.Select): + col_source = cte.element + elif isinstance(cte.element, selectable.CompoundSelect): + col_source = cte.element.selects[0] else: assert False recur_cols = [ @@ -1810,7 +1812,7 @@ class SQLCompiler(Compiled): ): result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) elif ( @@ -1830,7 +1832,7 @@ class SQLCompiler(Compiled): # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) else: @@ -1880,7 +1882,7 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() if ( - newelem.is_selectable + newelem._is_from_clause and newelem._is_join and isinstance(newelem.right, selectable.FromGrouping) ): @@ -1933,7 +1935,7 @@ class SQLCompiler(Compiled): # marker in the stack. kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) - elif newelem.is_selectable and newelem._is_select: + elif newelem._is_returns_rows and newelem._is_select_statement: barrier_select = ( kw.get("transform_clue", None) == "select_container" ) @@ -2349,6 +2351,7 @@ class SQLCompiler(Compiled): + join_type + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + " ON " + # TODO: likely need asfrom=True here? + join.onclause._compiler_dispatch(self, **kwargs) ) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 552f61b4a..881ea9fcd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -9,14 +9,16 @@ within INSERT and UPDATE statements. """ +import functools import operator +from . import coercions from . import dml from . import elements +from . import roles from .. import exc from .. import util - REQUIRED = util.symbol( "REQUIRED", """ @@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw): if check: raise exc.CompileError( "Unconsumed column names: %s" - % (", ".join("%s" % c for c in check)) + % (", ".join("%s" % (c,) for c in check)) ) if stmt._has_multi_parameters: @@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt): # statement. _et = set(stmt._extra_froms) + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + def _column_as_key(key): - str_key = elements._column_as_key(key) + str_key = c_key_role(key) if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: @@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt): return col.key else: - _column_as_key = elements._column_as_key + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) _getattr_col_key = _col_bind_name = operator.attrgetter("key") return _column_as_key, _getattr_col_key, _col_bind_name @@ -386,7 +394,7 @@ def _append_param_parameter( kw, ): value = parameters.pop(col_key) - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -633,9 +641,8 @@ def _get_multitable_params( values, kw, ): - normalized_params = dict( - (elements._clause_element_as_expr(c), param) + (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) affected_tables = set() @@ -645,7 +652,7 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if col in row or col.key in row: key = col if col in row else col.key - if elements._is_literal(row[key]): + if coercions._is_literal(row[key]): new_param = _create_bind_param( compiler, col, @@ -730,7 +737,7 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if elements._is_literal(v): + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw ) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index d87a6a1b0..ff36a68e4 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,7 @@ to invoke them for a create/drop call. """ +from . import roles from .base import _bind_or_error from .base import _generative from .base import Executable @@ -29,7 +30,7 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) -class DDLElement(Executable, _DDLCompiles): +class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """Base class for DDL expression constructs. This class is the base for the general purpose :class:`.DDL` class, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 9a12b84cd..918f7524e 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -8,32 +8,21 @@ """Default implementation of SQL comparison operations. """ + +from . import coercions from . import operators +from . import roles from . import type_api -from .elements import _clause_element_as_expr -from .elements import _const_expr -from .elements import _is_literal -from .elements import _literal_as_text from .elements import and_ from .elements import BinaryExpression -from .elements import BindParameter -from .elements import ClauseElement from .elements import ClauseList from .elements import collate from .elements import CollectionAggregate -from .elements import ColumnElement from .elements import False_ from .elements import Null from .elements import or_ -from .elements import TextClause from .elements import True_ -from .elements import Tuple from .elements import UnaryExpression -from .elements import Visitable -from .selectable import Alias -from .selectable import ScalarSelect -from .selectable import Selectable -from .selectable import SelectBase from .. import exc from .. import util @@ -62,7 +51,7 @@ def _boolean_compare( ): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -71,7 +60,7 @@ def _boolean_compare( elif op in (operators.is_distinct_from, operators.isnot_distinct_from): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -82,7 +71,7 @@ def _boolean_compare( if op in (operators.eq, operators.is_): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.is_, negate=operators.isnot, type_=result_type, @@ -90,7 +79,7 @@ def _boolean_compare( elif op in (operators.ne, operators.isnot): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.isnot, negate=operators.is_, type_=result_type, @@ -102,7 +91,9 @@ def _boolean_compare( "operators can be used with None/True/False" ) else: - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) if reverse: return BinaryExpression( @@ -127,7 +118,9 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) if reverse: left, right = obj, expr @@ -156,77 +149,22 @@ def _scalar(expr, op, fn, **kw): def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, ClauseElement): - if ( - isinstance(seq_or_selectable, BindParameter) - and seq_or_selectable.expanding - ): - - if isinstance(expr, Tuple): - seq_or_selectable = seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) - - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op - ) - else: - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable - ) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' % o - ) - elif o is None: - o = Null() - else: - o = expr._bind_param(op, o) - args.append(o) - - if len(args) == 0: - op, negate_op = ( - (operators.empty_in_op, operators.empty_notin_op) - if op is operators.in_op - else (operators.empty_notin_op, operators.empty_in_op) - ) + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] return _boolean_compare( - expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + expr, op, seq_or_selectable, negate=negate_op, **kw ) def _getitem_impl(expr, op, other, **kw): if isinstance(expr.type, type_api.INDEXABLE): - other = _check_literal(expr, op, other) + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) return _binary_operate(expr, op, other, **kw) else: _unsupported_impl(expr, op, other, **kw) @@ -257,7 +195,12 @@ def _match_impl(expr, op, other, **kw): return _boolean_compare( expr, operators.match_op, - _check_literal(expr, operators.match_op, other), + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), result_type=type_api.MATCHTYPE, negate=operators.notmatch_op if op is operators.match_op @@ -278,8 +221,18 @@ def _between_impl(expr, op, cleft, cright, **kw): return BinaryExpression( expr, ClauseList( - _check_literal(expr, operators.and_, cleft), - _check_literal(expr, operators.and_, cright), + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), operator=operators.and_, group=False, group_contents=False, @@ -349,22 +302,3 @@ operator_lookup = { "rshift": (_unsupported_impl,), "contains": (_unsupported_impl,), } - - -def _check_literal(expr, operator, other, bindparam_type=None): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, "__clause_element__"): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, Visitable): - return expr._bind_param(operator, other, type_=bindparam_type) - else: - return other diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3c40e7914..c7d83fc12 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,18 +9,16 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ +from . import coercions +from . import roles from .base import _from_objects from .base import _generative from .base import DialectKWArgs from .base import Executable from .elements import _clone -from .elements import _column_as_key -from .elements import _literal_as_text from .elements import and_ from .elements import ClauseElement from .elements import Null -from .selectable import _interpret_as_from -from .selectable import _interpret_as_select from .selectable import HasCTE from .selectable import HasPrefixes from .. import exc @@ -28,7 +26,12 @@ from .. import util class UpdateBase( - HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement + roles.DMLRole, + HasCTE, + DialectKWArgs, + HasPrefixes, + Executable, + ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. @@ -210,7 +213,7 @@ class ValuesBase(UpdateBase): _post_values_clause = None def __init__(self, table, values, prefixes): - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self.parameters, self._has_multi_parameters = self._process_colparams( values ) @@ -604,13 +607,16 @@ class Insert(ValuesBase): ) self.parameters, self._has_multi_parameters = self._process_colparams( - {_column_as_key(n): Null() for n in names} + { + coercions.expect(roles.DMLColumnRole, n, as_key=True): Null() + for n in names + } ) self.select_names = names self.inline = True self.include_insert_from_select_defaults = include_defaults - self.select = _interpret_as_select(select) + self.select = coercions.expect(roles.DMLSelectRole, select) def _copy_internals(self, clone=_clone, **kw): # TODO: coverage @@ -678,7 +684,7 @@ class Update(ValuesBase): users.update().values(name='ed').where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) :param values: @@ -744,7 +750,7 @@ class Update(ValuesBase): users.update().values( name=select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. seealso:: @@ -759,7 +765,9 @@ class Update(ValuesBase): self._bind = bind self._returning = returning if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None self.inline = inline @@ -785,10 +793,13 @@ class Update(ValuesBase): """ if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): @@ -846,7 +857,7 @@ class Delete(UpdateBase): users.delete().where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. versionchanged:: 1.2.0 @@ -858,14 +869,16 @@ class Delete(UpdateBase): """ self._bind = bind - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self._returning = returning if prefixes: self._setup_prefixes(prefixes) if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None @@ -883,10 +896,13 @@ class Delete(UpdateBase): if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e634e5a36..a333303ec 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -13,12 +13,13 @@ from __future__ import unicode_literals import itertools -import numbers import operator import re from . import clause_compare +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _generative @@ -26,6 +27,7 @@ from .base import Executable from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT +from .coercions import _document_text_coercion from .visitors import cloned_traverse from .visitors import traverse from .visitors import Visitable @@ -38,20 +40,6 @@ def _clone(element, **kw): return element._clone() -def _document_text_coercion(paramname, meth_rst, param_rst): - return util.add_parameter_text( - paramname, - ( - ".. warning:: " - "The %s argument to %s can be passed as a Python string argument, " - "which will be treated " - "as **trusted SQL text** and rendered as given. **DO NOT PASS " - "UNTRUSTED INPUT TO THIS PARAMETER**." - ) - % (param_rst, meth_rst), - ) - - def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -71,7 +59,7 @@ def collate(expression, collation): """ - expr = _literal_as_binds(expression) + expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( expr, CollationClause(collation), operators.collate, type_=expr.type ) @@ -127,7 +115,7 @@ def between(expr, lower_bound, upper_bound, symmetric=False): :meth:`.ColumnElement.between` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return expr.between(lower_bound, upper_bound, symmetric=symmetric) @@ -172,11 +160,11 @@ def not_(clause): same result. """ - return operators.inv(_literal_as_binds(clause)) + return operators.inv(coercions.expect(roles.ExpressionElementRole, clause)) @inspection._self_inspects -class ClauseElement(Visitable): +class ClauseElement(roles.SQLRole, Visitable): """Base class for elements of a programmatically constructed SQL expression. @@ -188,13 +176,20 @@ class ClauseElement(Visitable): supports_execution = False _from_objects = [] bind = None + description = None _is_clone_of = None - is_selectable = False + is_clause_element = True + is_selectable = False - description = None - _order_by_label_element = None + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False _is_from_container = False + _is_select_statement = False + + _order_by_label_element = None def _clone(self): """Create a shallow copy of this ClauseElement. @@ -238,7 +233,7 @@ class ClauseElement(Visitable): """ - raise NotImplementedError(self.__class__) + raise NotImplementedError() @property def _constructor(self): @@ -394,6 +389,7 @@ class ClauseElement(Visitable): return [] def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement """Apply a 'grouping' to this :class:`.ClauseElement`. This method is overridden by subclasses to return a @@ -553,7 +549,20 @@ class ClauseElement(Visitable): ) -class ColumnElement(operators.ColumnOperators, ClauseElement): +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + operators.ColumnOperators, + ClauseElement, +): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -586,17 +595,13 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): :class:`.TypeEngine` objects) are applied to the value. * any special object value, typically ORM-level constructs, which - feature a method called ``__clause_element__()``. The Core + feature an accessor called ``__clause_element__()``. The Core expression system looks for this method when an object of otherwise unknown type is passed to a function that is looking to coerce the - argument into a :class:`.ColumnElement` expression. The - ``__clause_element__()`` method, if present, should return a - :class:`.ColumnElement` instance. The primary use of - ``__clause_element__()`` within SQLAlchemy is that of class-bound - attributes on ORM-mapped classes; a ``User`` class which contains a - mapped attribute named ``.name`` will have a method - ``User.name.__clause_element__()`` which when invoked returns the - :class:`.Column` called ``name`` associated with the mapped table. + argument into a :class:`.ColumnElement` and sometimes a + :class:`.SelectBase` expression. It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. * The Python ``None`` value is typically interpreted as ``NULL``, which in SQLAlchemy Core produces an instance of :func:`.null`. @@ -702,6 +707,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): _alt_names = () def self_group(self, against=None): + # type: (Module, Module, Optional[Any]) -> ClauseEleent if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity @@ -826,7 +832,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): else: key = name co = ColumnClause( - _as_truncated(name) if name_is_truncatable else name, + coercions.expect(roles.TruncatedLabelRole, name) + if name_is_truncatable + else name, type_=getattr(self, "type", None), _selectable=selectable, ) @@ -878,7 +886,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): ) -class BindParameter(ColumnElement): +class BindParameter(roles.InElementRole, ColumnElement): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -1235,7 +1243,8 @@ class BindParameter(ColumnElement): "bindparams collection argument required for _cache_key " "implementation. Bound parameter cache keys are not safe " "to use without accommodating for the value or callable " - "within the parameter itself.") + "within the parameter itself." + ) else: bindparams.append(self) return (BindParameter, self.type._cache_key, self._orig_key) @@ -1282,7 +1291,20 @@ class TypeClause(ClauseElement): return (TypeClause, self.type._cache_key) -class TextClause(Executable, ClauseElement): +class TextClause( + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.OrderByRole, + roles.FromClauseRole, + roles.SelectStatementRole, + roles.CoerceTextStatementRole, + roles.BinaryElementRole, + roles.InElementRole, + Executable, + ClauseElement, +): """Represent a literal SQL text fragment. E.g.:: @@ -1304,6 +1326,10 @@ class TextClause(Executable, ClauseElement): __visit_name__ = "textclause" + _is_text_clause = True + + _is_textual = True + _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE) _execution_options = Executable._execution_options.union( {"autocommit": PARSE_AUTOCOMMIT} @@ -1318,20 +1344,16 @@ class TextClause(Executable, ClauseElement): def _select_iterable(self): return (self,) - @property - def selectable(self): - # allows text() to be considered by - # _interpret_as_from - return self - - _hide_froms = [] - # help in those cases where text() is # interpreted in a column expression situation key = _label = _resolve_label = None _allow_label_resolve = False + @property + def _hide_froms(self): + return [] + def __init__(self, text, bind=None): self._bind = bind self._bindparams = {} @@ -1670,7 +1692,6 @@ class TextClause(Executable, ClauseElement): """ - positional_input_cols = [ ColumnClause(col.key, types.pop(col.key)) if col.key in types @@ -1696,6 +1717,7 @@ class TextClause(Executable, ClauseElement): return self.type.comparator_factory(self) def self_group(self, against=None): + # type: (Optional[Any]) -> Union[Grouping, TextClause] if against is operators.in_op: return Grouping(self) else: @@ -1715,7 +1737,7 @@ class TextClause(Executable, ClauseElement): ) -class Null(ColumnElement): +class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -1739,7 +1761,7 @@ class Null(ColumnElement): return (Null,) -class False_(ColumnElement): +class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -1798,7 +1820,7 @@ class False_(ColumnElement): return (False_,) -class True_(ColumnElement): +class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -1864,7 +1886,12 @@ class True_(ColumnElement): return (True_,) -class ClauseList(ClauseElement): +class ClauseList( + roles.InElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + ClauseElement, +): """Describe a list of clauses, separated by an operator. By default, is comma-separated, such as a column listing. @@ -1877,16 +1904,22 @@ class ClauseList(ClauseElement): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) self.group_contents = kwargs.pop("group_contents", True) - text_converter = kwargs.pop( - "_literal_as_text", _expression_literal_as_text + + self._text_converter_role = text_converter_role = kwargs.pop( + "_literal_as_text_role", roles.WhereHavingRole ) if self.group_contents: self.clauses = [ - text_converter(clause).self_group(against=self.operator) + coercions.expect(text_converter_role, clause).self_group( + against=self.operator + ) for clause in clauses ] else: - self.clauses = [text_converter(clause) for clause in clauses] + self.clauses = [ + coercions.expect(text_converter_role, clause) + for clause in clauses + ] self._is_implicitly_boolean = operators.is_boolean(self.operator) def __iter__(self): @@ -1902,10 +1935,14 @@ class ClauseList(ClauseElement): def append(self, clause): if self.group_contents: self.clauses.append( - _literal_as_text(clause).self_group(against=self.operator) + coercions.expect(self._text_converter_role, clause).self_group( + against=self.operator + ) ) else: - self.clauses.append(_literal_as_text(clause)) + self.clauses.append( + coercions.expect(self._text_converter_role, clause) + ) def _copy_internals(self, clone=_clone, **kw): self.clauses = [clone(clause, **kw) for clause in self.clauses] @@ -1923,6 +1960,7 @@ class ClauseList(ClauseElement): return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -1947,7 +1985,7 @@ class BooleanClauseList(ClauseList, ColumnElement): convert_clauses = [] clauses = [ - _expression_literal_as_text(clause) + coercions.expect(roles.WhereHavingRole, clause) for clause in util.coerce_generator_arg(clauses) ] for clause in clauses: @@ -2055,6 +2093,7 @@ class BooleanClauseList(ClauseList, ColumnElement): return (self,) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if not self.clauses: return self else: @@ -2092,7 +2131,9 @@ class Tuple(ClauseList, ColumnElement): """ - clauses = [_literal_as_binds(c) for c in clauses] + clauses = [ + coercions.expect(roles.ExpressionElementRole, c) for c in clauses + ] self._type_tuple = [arg.type for arg in clauses] self.type = kw.pop( "type_", @@ -2283,12 +2324,20 @@ class Case(ColumnElement): if value is not None: whenlist = [ - (_literal_as_binds(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect( + roles.ExpressionElementRole, c + ).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] else: whenlist = [ - (_no_literals(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect(roles.ColumnArgumentRole, c).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] @@ -2300,12 +2349,12 @@ class Case(ColumnElement): if value is None: self.value = None else: - self.value = _literal_as_binds(value) + self.value = coercions.expect(roles.ExpressionElementRole, value) self.type = type_ self.whens = whenlist if else_ is not None: - self.else_ = _literal_as_binds(else_) + self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None @@ -2455,7 +2504,9 @@ class Cast(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) self.typeclause = TypeClause(self.type) def _copy_internals(self, clone=_clone, **kw): @@ -2557,7 +2608,9 @@ class TypeCoerce(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) def _copy_internals(self, clone=_clone, **kw): self.clause = clone(self.clause, **kw) @@ -2598,7 +2651,7 @@ class Extract(ColumnElement): """ self.type = type_api.INTEGERTYPE self.field = field - self.expr = _literal_as_binds(expr, None) + self.expr = coercions.expect(roles.ExpressionElementRole, expr) def _copy_internals(self, clone=_clone, **kw): self.expr = clone(self.expr, **kw) @@ -2733,7 +2786,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullsfirst_op, wraps_column_expression=False, ) @@ -2776,7 +2829,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullslast_op, wraps_column_expression=False, ) @@ -2817,7 +2870,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.desc_op, wraps_column_expression=False, ) @@ -2857,7 +2910,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.asc_op, wraps_column_expression=False, ) @@ -2898,7 +2951,7 @@ class UnaryExpression(ColumnElement): :data:`.func` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( expr, operator=operators.distinct_op, @@ -2953,6 +3006,7 @@ class UnaryExpression(ColumnElement): return ClauseElement._negate(self) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2990,10 +3044,8 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() expr = expr.self_group() return CollectionAggregate( expr, @@ -3023,9 +3075,7 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() + expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() return CollectionAggregate( expr, @@ -3064,6 +3114,7 @@ class AsBoolean(UnaryExpression): self._is_implicitly_boolean = element._is_implicitly_boolean def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self def _cache_key(self, **kw): @@ -3155,6 +3206,8 @@ class BinaryExpression(ColumnElement): ) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement + if operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3191,6 +3244,7 @@ class Slice(ColumnElement): self.type = type_api.NULLTYPE def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement assert against is operator.getitem return self @@ -3215,6 +3269,7 @@ class Grouping(ColumnElement): self.type = getattr(element, "type", type_api.NULLTYPE) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self @util.memoized_property @@ -3363,13 +3418,12 @@ class Over(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) if partition_by is not None: self.partition_by = ClauseList( *util.to_list(partition_by), - _literal_as_text=_literal_as_label_reference + _literal_as_text_role=roles.ByOfRole ) if range_: @@ -3534,8 +3588,7 @@ class WithinGroup(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): @@ -3658,7 +3711,7 @@ class FunctionFilter(ColumnElement): """ for criterion in list(criterion): - criterion = _expression_literal_as_text(criterion) + criterion = coercions.expect(roles.WhereHavingRole, criterion) if self.criterion is not None: self.criterion = self.criterion & criterion @@ -3727,7 +3780,7 @@ class FunctionFilter(ColumnElement): ) -class Label(ColumnElement): +class Label(roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3801,6 +3854,7 @@ class Label(ColumnElement): return self._element.self_group(against=operators.as_) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): @@ -3849,7 +3903,7 @@ class Label(ColumnElement): return e -class ColumnClause(Immutable, ColumnElement): +class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): """Represents a column expression from any textual string. The :class:`.ColumnClause`, a lightweight analogue to the @@ -3985,14 +4039,14 @@ class ColumnClause(Immutable, ColumnElement): if ( self.is_literal or self.table is None - or self.table._textual + or self.table._is_textual or not hasattr(other, "proxy_set") or ( isinstance(other, ColumnClause) and ( other.is_literal or other.table is None - or other.table._textual + or other.table._is_textual ) ) ): @@ -4083,7 +4137,7 @@ class ColumnClause(Immutable, ColumnElement): counter += 1 label = _label - return _as_truncated(label) + return coercions.expect(roles.TruncatedLabelRole, label) else: return name @@ -4110,7 +4164,7 @@ class ColumnClause(Immutable, ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - _as_truncated(name or self.name) + coercions.expect(roles.TruncatedLabelRole, name or self.name) if name_is_truncatable else (name or self.name), type_=self.type, @@ -4250,6 +4304,108 @@ class quoted_name(util.MemoizedSlots, util.text_type): return "'%s'" % backslashed +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) + + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.column_set() + traverse(clause, {}, {"column": cols.add}) + return cols + + +def _type_from_args(args): + for a in args: + if not a.type._isnull: + return a.type + else: + return type_api.NULLTYPE + + +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column( + column, require_embedded=require_embedded + ) + if c is None: + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % (column, getattr(column, "table", None), fromclause.description) + ) + return c + + +class AnnotatedColumnElement(Annotated): + def __init__(self, element, values): + Annotated.__init__(self, element, values) + ColumnElement.comparator._reset(self) + for attr in ("name", "key", "table"): + if self.__dict__.get(attr, False) is None: + self.__dict__.pop(attr) + + def _with_annotations(self, values): + clone = super(AnnotatedColumnElement, self)._with_annotations(values) + ColumnElement.comparator._reset(clone) + return clone + + @util.memoized_property + def name(self): + """pull 'name' from parent, if not present""" + return self._Annotated__element.name + + @util.memoized_property + def table(self): + """pull 'table' from parent, if not present""" + return self._Annotated__element.table + + @util.memoized_property + def key(self): + """pull 'key' from parent, if not present""" + return self._Annotated__element.key + + @util.memoized_property + def info(self): + return self._Annotated__element.info + + @util.memoized_property + def anon_label(self): + return self._Annotated__element.anon_label + + class _truncated_label(quoted_name): """A unicode subclass used to identify symbolic " "names that may require truncation.""" @@ -4378,349 +4534,3 @@ class _anonymous_label(_truncated_label): else: # else skip the constructor call return self % map_ - - -def _as_truncated(value): - """coerce the given value to :class:`._truncated_label`. - - Existing :class:`._truncated_label` and - :class:`._anonymous_label` objects are passed - unchanged. - """ - - if isinstance(value, _truncated_label): - return value - else: - return _truncated_label(value) - - -def _string_or_unprintable(element): - if isinstance(element, util.string_types): - return element - else: - try: - return str(element) - except Exception: - return "unprintable element %r" % element - - -def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' - predecessors. - - """ - return itertools.chain(*[x._cloned_set for x in elements]) - - -def _select_iterables(elements): - """expand tables into individual columns in the - given list of column expressions. - - """ - return itertools.chain(*[c._select_iterable for c in elements]) - - -def _cloned_intersection(a, b): - """return the intersection of sets a and b, counting - any overlap between 'cloned' predecessors. - - The returned set is in terms of the entities present within 'a'. - - """ - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) - - -def _cloned_difference(a, b): - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) - - -@util.dependencies("sqlalchemy.sql.functions") -def _labeled(functions, element): - if not hasattr(element, "name") or isinstance( - element, functions.FunctionElement - ): - return element.label(None) - else: - return element - - -def _is_column(col): - """True if ``col`` is an instance of :class:`.ColumnElement`.""" - - return isinstance(col, ColumnElement) - - -def _find_columns(clause): - """locate Column objects within the given expression.""" - - cols = util.column_set() - traverse(clause, {}, {"column": cols.add}) - return cols - - -# there is some inconsistency here between the usage of -# inspect() vs. checking for Visitable and __clause_element__. -# Ideally all functions here would derive from inspect(), -# however the inspect() versions add significant callcount -# overhead for critical functions like _interpret_as_column_or_from(). -# Generally, the column-based functions are more performance critical -# and are fine just checking for __clause_element__(). It is only -# _interpret_as_from() where we'd like to be able to receive ORM entities -# that have no defined namespace, hence inspect() is needed there. - - -def _column_as_key(element): - if isinstance(element, util.string_types): - return element - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - try: - return element.key - except AttributeError: - return None - - -def _clause_element_as_expr(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - else: - return element - - -def _literal_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - return _literal_as_text(element) - - -def _literal_and_labels_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - if ( - isinstance(element, ColumnElement) - and element._order_by_label_element is not None - ): - return _label_reference(element) - else: - return _literal_as_text(element) - - -def _expression_literal_as_text(element): - return _literal_as_text(element) - - -def _literal_as(element, text_fallback): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, util.string_types): - return text_fallback(element) - elif isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - else: - raise exc.ArgumentError( - "SQL expression object expected, got object of type %r " - "instead" % type(element) - ) - - -def _literal_as_text(element, allow_coercion_to_text=False): - if allow_coercion_to_text: - return _literal_as(element, TextClause) - else: - return _literal_as(element, _no_text_coercion) - - -def _literal_as_column(element): - return _literal_as(element, ColumnClause) - - -def _no_column_coercion(element): - element = str(element) - guess_is_literal = not _guess_straight_column.match(element) - raise exc.ArgumentError( - "Textual column expression %(column)r should be " - "explicitly declared with text(%(column)r), " - "or use %(literal_column)s(%(column)r) " - "for more specificity" - % { - "column": util.ellipses_string(element), - "literal_column": "literal_column" - if guess_is_literal - else "column", - } - ) - - -def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None): - raise exc_cls( - "%(extra)sTextual SQL expression %(expr)r should be " - "explicitly declared as text(%(expr)r)" - % { - "expr": util.ellipses_string(element), - "extra": "%s " % extra if extra else "", - } - ) - - -def _no_literals(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - raise exc.ArgumentError( - "Ambiguous literal: %r. Use the 'text()' " - "function to indicate a SQL expression " - "literal, or 'literal()' to indicate a " - "bound value." % (element,) - ) - else: - return element - - -def _is_literal(element): - return not isinstance(element, Visitable) and not hasattr( - element, "__clause_element__" - ) - - -def _only_column_elements_or_none(element, name): - if element is None: - return None - else: - return _only_column_elements(element, name) - - -def _only_column_elements(element, name): - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - if not isinstance(element, ColumnElement): - raise exc.ArgumentError( - "Column-based expression object expected for argument " - "'%s'; got: '%s', type %s" % (name, element, type(element)) - ) - return element - - -def _literal_as_binds(element, name=None, type_=None): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - if element is None: - return Null() - else: - return BindParameter(name, element, type_=type_, unique=True) - else: - return element - - -_guess_straight_column = re.compile(r"^\w\S*$", re.I) - - -def _interpret_as_column_or_from(element): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - elif hasattr(insp, "selectable"): - return insp.selectable - - # be forgiving as this is an extremely common - # and known expression - if element == "*": - guess_is_literal = True - elif isinstance(element, (numbers.Number)): - return ColumnClause(str(element), is_literal=True) - else: - _no_column_coercion(element) - return ColumnClause(element, is_literal=guess_is_literal) - - -def _const_expr(element): - if isinstance(element, (Null, False_, True_)): - return element - elif element is None: - return Null() - elif element is False: - return False_() - elif element is True: - return True_() - else: - raise exc.ArgumentError("Expected None, False, or True") - - -def _type_from_args(args): - for a in args: - if not a.type._isnull: - return a.type - else: - return type_api.NULLTYPE - - -def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column( - column, require_embedded=require_embedded - ) - if c is None: - raise exc.InvalidRequestError( - "Given column '%s', attached to table '%s', " - "failed to locate a corresponding column from table '%s'" - % (column, getattr(column, "table", None), fromclause.description) - ) - return c - - -class AnnotatedColumnElement(Annotated): - def __init__(self, element, values): - Annotated.__init__(self, element, values) - ColumnElement.comparator._reset(self) - for attr in ("name", "key", "table"): - if self.__dict__.get(attr, False) is None: - self.__dict__.pop(attr) - - def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) - ColumnElement.comparator._reset(clone) - return clone - - @util.memoized_property - def name(self): - """pull 'name' from parent, if not present""" - return self._Annotated__element.name - - @util.memoized_property - def table(self): - """pull 'table' from parent, if not present""" - return self._Annotated__element.table - - @util.memoized_property - def key(self): - """pull 'key' from parent, if not present""" - return self._Annotated__element.key - - @util.memoized_property - def info(self): - return self._Annotated__element.info - - @util.memoized_property - def anon_label(self): - return self._Annotated__element.anon_label diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f381879ce..b04355cf5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -67,7 +67,6 @@ __all__ = [ "outerjoin", "over", "select", - "subquery", "table", "text", "tuple_", @@ -92,22 +91,7 @@ from .dml import Insert # noqa from .dml import Update # noqa from .dml import UpdateBase # noqa from .dml import ValuesBase # noqa -from .elements import _clause_element_as_expr # noqa -from .elements import _clone # noqa -from .elements import _cloned_difference # noqa -from .elements import _cloned_intersection # noqa -from .elements import _column_as_key # noqa -from .elements import _corresponding_column_or_error # noqa -from .elements import _expression_literal_as_text # noqa -from .elements import _is_column # noqa -from .elements import _labeled # noqa -from .elements import _literal_as_binds # noqa -from .elements import _literal_as_column # noqa -from .elements import _literal_as_label_reference # noqa -from .elements import _literal_as_text # noqa -from .elements import _only_column_elements # noqa from .elements import _select_iterables # noqa -from .elements import _string_or_unprintable # noqa from .elements import _truncated_label # noqa from .elements import between # noqa from .elements import BinaryExpression # noqa @@ -147,7 +131,6 @@ from .functions import func # noqa from .functions import Function # noqa from .functions import FunctionElement # noqa from .functions import modifier # noqa -from .selectable import _interpret_as_from # noqa from .selectable import Alias # noqa from .selectable import CompoundSelect # noqa from .selectable import CTE # noqa @@ -160,6 +143,7 @@ from .selectable import HasPrefixes # noqa from .selectable import HasSuffixes # noqa from .selectable import Join # noqa from .selectable import Lateral # noqa +from .selectable import ReturnsRows # noqa from .selectable import ScalarSelect # noqa from .selectable import Select # noqa from .selectable import Selectable # noqa @@ -171,7 +155,6 @@ from .selectable import TextAsFrom # noqa from .visitors import Visitable # noqa from ..util.langhelpers import public_factory # noqa - # factory functions - these pull class-bound constructors and classmethods # from SQL elements and selectables into public functions. This allows # the functions to be available in the sqlalchemy.sql.* namespace and diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d0aa23988..173789998 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -9,14 +9,15 @@ """ from . import annotation +from . import coercions from . import operators +from . import roles from . import schema from . import sqltypes from . import util as sqlutil from .base import ColumnCollection from .base import Executable from .elements import _clone -from .elements import _literal_as_binds from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -83,7 +84,12 @@ class FunctionElement(Executable, ColumnElement, FromClause): """Construct a :class:`.FunctionElement`. """ args = [ - _literal_as_binds(c, getattr(self, "name", None)) for c in clauses + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + ) + for c in clauses ] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( @@ -686,7 +692,12 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: - parsed_args = [_literal_as_binds(c, self.name) for c in args] + parsed_args = [ + coercions.expect( + roles.ExpressionElementRole, c, name=self.name + ) + for c in args + ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] self._bind = kwargs.get("bind", None) @@ -751,7 +762,10 @@ class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c, self.name) for c in args] + args = [ + coercions.expect(roles.ExpressionElementRole, c, name=self.name) + for c in args + ] kwargs.setdefault("type_", _type_from_args(args)) kwargs["_parsed_args"] = args super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) @@ -880,7 +894,7 @@ class array_agg(GenericFunction): type = sqltypes.ARRAY def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] + args = [coercions.expect(roles.ExpressionElementRole, c) for c in args] default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 8479c1d59..b8bbb4525 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1053,7 +1053,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.any_() # mysql '5 = ANY (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().any_() + expr = 5 == select([table.c.value]).scalar_subquery().any_() .. seealso:: @@ -1078,7 +1078,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.all_() # mysql '5 = ALL (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().all_() + expr = 5 == select([table.c.value]).scalar_subquery().all_() .. seealso:: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py new file mode 100644 index 000000000..2d3aaf903 --- /dev/null +++ b/lib/sqlalchemy/sql/roles.py @@ -0,0 +1,157 @@ +# sql/roles.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +class SQLRole(object): + """Define a "role" within a SQL statement structure. + + Classes within SQL Core participate within SQLRole hierarchies in order + to more accurately indicate where they may be used within SQL statements + of all types. + + .. versionadded:: 1.4 + + """ + + +class UsesInspection(object): + pass + + +class ColumnArgumentRole(SQLRole): + _role_name = "Column expression" + + +class ColumnArgumentOrKeyRole(ColumnArgumentRole): + _role_name = "Column expression or string key" + + +class ColumnListRole(SQLRole): + """Elements suitable for forming comma separated lists of expressions.""" + + +class TruncatedLabelRole(SQLRole): + _role_name = "String SQL identifier" + + +class ColumnsClauseRole(UsesInspection, ColumnListRole): + _role_name = "Column expression or FROM clause" + + @property + def _select_iterable(self): + raise NotImplementedError() + + +class LimitOffsetRole(SQLRole): + _role_name = "LIMIT / OFFSET expression" + + +class ByOfRole(ColumnListRole): + _role_name = "GROUP BY / OF / etc. expression" + + +class OrderByRole(ByOfRole): + _role_name = "ORDER BY expression" + + +class StructuralRole(SQLRole): + pass + + +class StatementOptionRole(StructuralRole): + _role_name = "statement sub-expression element" + + +class WhereHavingRole(StructuralRole): + _role_name = "SQL expression for WHERE/HAVING role" + + +class ExpressionElementRole(SQLRole): + _role_name = "SQL expression element" + + +class ConstExprRole(ExpressionElementRole): + _role_name = "Constant True/False/None expression" + + +class LabeledColumnExprRole(ExpressionElementRole): + pass + + +class BinaryElementRole(ExpressionElementRole): + _role_name = "SQL expression element or literal value" + + +class InElementRole(SQLRole): + _role_name = ( + "IN expression list, SELECT construct, or bound parameter object" + ) + + +class FromClauseRole(ColumnsClauseRole): + _role_name = "FROM expression, such as a Table or alias() object" + + @property + def _hide_froms(self): + raise NotImplementedError() + + +class CoerceTextStatementRole(SQLRole): + _role_name = "Executable SQL, text() construct, or string statement" + + +class StatementRole(CoerceTextStatementRole): + _role_name = "Executable SQL or text() construct" + + +class ReturnsRowsRole(StatementRole): + _role_name = ( + "Row returning expression such as a SELECT, or an " + "INSERT/UPDATE/DELETE with RETURNING" + ) + + +class SelectStatementRole(ReturnsRowsRole): + _role_name = "SELECT construct or equivalent text() construct" + + +class HasCTERole(ReturnsRowsRole): + pass + + +class CompoundElementRole(SQLRole): + """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc.""" + + _role_name = ( + "SELECT construct for inclusion in a UNION or other set construct" + ) + + +class DMLRole(StatementRole): + pass + + +class DMLColumnRole(SQLRole): + _role_name = "SET/VALUES column expression or string key" + + +class DMLSelectRole(SQLRole): + """A SELECT statement embedded in DML, typically INSERT from SELECT """ + + _role_name = "SELECT statement or equivalent textual object" + + +class DDLRole(StatementRole): + pass + + +class DDLExpressionRole(StructuralRole): + _role_name = "SQL expression element for DDL constraint" + + +class DDLConstraintColumnRole(SQLRole): + _role_name = "String column name or column object for DDL constraint" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b045e006e..62ff25a64 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -34,16 +34,16 @@ import collections import operator import sqlalchemy +from . import coercions from . import ddl +from . import roles from . import type_api from . import visitors from .base import _bind_or_error from .base import ColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget -from .elements import _as_truncated -from .elements import _document_text_coercion -from .elements import _literal_as_text +from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -1583,7 +1583,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) try: c = self._constructor( - _as_truncated(name or self.name) + coercions.expect( + roles.TruncatedLabelRole, name if name else self.name + ) if name_is_truncatable else (name or self.name), self.type, @@ -2109,13 +2111,19 @@ class ForeignKey(DialectKWArgs, SchemaItem): class _NotAColumnExpr(object): + # the coercions system is not used in crud.py for the values passed in + # the insert().values() and update().values() methods, so the usual + # pathways to rejecting a coercion in the unlikely case of adding defaut + # generator objects to insert() or update() constructs aren't available; + # create a quick coercion rejection here that is specific to what crud.py + # calls on value objects. def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression." % self.__class__.__name__ ) - __clause_element__ = self_group = lambda self: self._not_a_column_expr() + self_group = lambda self: self._not_a_column_expr() # noqa _from_objects = property(lambda self: self._not_a_column_expr()) @@ -2274,7 +2282,7 @@ class ColumnDefault(DefaultGenerator): return "ColumnDefault(%r)" % (self.arg,) -class Sequence(DefaultGenerator): +class Sequence(roles.StatementRole, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -2759,25 +2767,6 @@ class ColumnCollectionMixin(object): if _autoattach and self._pending_colargs: self._check_attach() - @classmethod - def _extract_col_expression_collection(cls, expressions): - for expr in expressions: - strname = None - column = None - if hasattr(expr, "__clause_element__"): - expr = expr.__clause_element__() - - if not isinstance(expr, (ColumnElement, TextClause)): - # this assumes a string - strname = expr - else: - cols = [] - visitors.traverse(expr, {}, {"column": cols.append}) - if cols: - column = cols[0] - add_element = column if column is not None else strname - yield expr, column, strname, add_element - def _check_attach(self, evt=False): col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] @@ -2960,7 +2949,7 @@ class CheckConstraint(ColumnCollectionConstraint): """ - self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True) + self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext) columns = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) @@ -3630,7 +3619,9 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): column, strname, add_element, - ) in self._extract_col_expression_collection(expressions): + ) in coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ): if add_element is not None: columns.append(add_element) processed_expressions.append(expr) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5167182fe..41be9fc5a 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -15,8 +15,9 @@ import itertools import operator from operator import attrgetter -from sqlalchemy.sql.visitors import Visitable +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _from_objects @@ -26,18 +27,12 @@ from .base import ColumnSet from .base import Executable from .base import Generative from .base import Immutable +from .coercions import _document_text_coercion from .elements import _anonymous_label -from .elements import _clause_element_as_expr from .elements import _clone from .elements import _cloned_difference from .elements import _cloned_intersection -from .elements import _document_text_coercion from .elements import _expand_cloned -from .elements import _interpret_as_column_or_from -from .elements import _literal_and_labels_as_label_reference -from .elements import _literal_as_label_reference -from .elements import _literal_as_text -from .elements import _no_text_coercion from .elements import _select_iterables from .elements import and_ from .elements import BindParameter @@ -48,75 +43,15 @@ from .elements import literal_column from .elements import True_ from .elements import UnaryExpression from .. import exc -from .. import inspection from .. import util -def _interpret_as_from(element): - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, util.string_types): - _no_text_coercion(element) - try: - return insp.selectable - except AttributeError: - raise exc.ArgumentError("FROM expression expected") - - -def _interpret_as_select(element): - element = _interpret_as_from(element) - if isinstance(element, Alias): - element = element.original - if not isinstance(element, SelectBase): - element = element.select() - return element - - class _OffsetLimitParam(BindParameter): @property def _limit_offset_value(self): return self.effective_value -def _offset_or_limit_clause(element, name=None, type_=None): - """Convert the given value to an "offset or limit" clause. - - This handles incoming integers and converts to an expression; if - an expression is already given, it is passed through. - - """ - if element is None: - return None - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, Visitable): - return element - else: - value = util.asint(element) - return _OffsetLimitParam(name, value, type_=type_, unique=True) - - -def _offset_or_limit_clause_asint(clause, attrname): - """Convert the "offset or limit" clause of a select construct to an - integer. - - This is only possible if the value is stored as a simple bound parameter. - Otherwise, a compilation error is raised. - - """ - if clause is None: - return None - try: - value = clause._limit_offset_value - except AttributeError: - raise exc.CompileError( - "This SELECT structure does not use a simple " - "integer value for %s" % attrname - ) - else: - return util.asint(value) - - def subquery(alias, *args, **kwargs): r"""Return an :class:`.Alias` object derived from a :class:`.Select`. @@ -133,8 +68,42 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) -class Selectable(ClauseElement): - """mark a class as being selectable""" +class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): + """The basemost class for Core contructs that have some concept of + columns that can represent rows. + + While the SELECT statement and TABLE are the primary things we think + of in this category, DML like INSERT, UPDATE and DELETE can also specify + RETURNING which means they can be used in CTEs and other forms, and + PostgreSQL has functions that return rows also. + + .. versionadded:: 1.4 + + """ + + _is_returns_rows = True + + # sub-elements of returns_rows + _is_from_clause = False + _is_select_statement = False + _is_lateral = False + + @property + def selectable(self): + raise NotImplementedError( + "This object is a base ReturnsRows object, but is not a " + "FromClause so has no .c. collection." + ) + + +class Selectable(ReturnsRows): + """mark a class as being selectable. + + This class is legacy as of 1.4 as the concept of a SQL construct which + "returns rows" is more generalized than one which can be the subject + of a SELECT. + + """ __visit_name__ = "selectable" @@ -190,7 +159,7 @@ class HasPrefixes(object): def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) @@ -236,13 +205,13 @@ class HasSuffixes(object): def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) -class FromClause(Selectable): +class FromClause(roles.FromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -265,16 +234,6 @@ class FromClause(Selectable): named_with_column = False _hide_froms = [] - _is_join = False - _is_select = False - _is_from_container = False - - _is_lateral = False - - _textual = False - """a marker that allows us to easily distinguish a :class:`.TextAsFrom` - or similar object from other kinds of :class:`.FromClause` objects.""" - schema = None """Define the 'schema' attribute for this :class:`.FromClause`. @@ -284,6 +243,11 @@ class FromClause(Selectable): """ + is_selectable = has_selectable = True + _is_from_clause = True + _is_text_as_from = False + _is_join = False + def _translate_schema(self, effective_schema, map_): return effective_schema @@ -726,8 +690,8 @@ class Join(FromClause): :class:`.FromClause` object. """ - self.left = _interpret_as_from(left) - self.right = _interpret_as_from(right).self_group() + self.left = coercions.expect(roles.FromClauseRole, left) + self.right = coercions.expect(roles.FromClauseRole, right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) @@ -1292,7 +1256,9 @@ class Alias(FromClause): .. versionadded:: 0.9.0 """ - return _interpret_as_from(selectable).alias(name=name, flat=flat) + return coercions.expect(roles.FromClauseRole, selectable).alias( + name=name, flat=flat + ) def _init(self, selectable, name=None): baseselectable = selectable @@ -1327,14 +1293,6 @@ class Alias(FromClause): else: return self.name.encode("ascii", "backslashreplace") - def as_scalar(self): - try: - return self.element.as_scalar() - except AttributeError: - raise AttributeError( - "Element %s does not support " "'as_scalar()'" % self.element - ) - def is_derived_from(self, fromclause): if fromclause in self._cloned_set: return True @@ -1426,7 +1384,9 @@ class Lateral(Alias): :ref:`lateral_selects` - overview of usage. """ - return _interpret_as_from(selectable).lateral(name=name) + return coercions.expect(roles.FromClauseRole, selectable).lateral( + name=name + ) class TableSample(Alias): @@ -1488,7 +1448,7 @@ class TableSample(Alias): REPEATABLE sub-clause is also rendered. """ - return _interpret_as_from(selectable).tablesample( + return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1523,7 +1483,7 @@ class CTE(Generative, HasSuffixes, Alias): Please see :meth:`.HasCte.cte` for detail on CTE usage. """ - return _interpret_as_from(selectable).cte( + return coercions.expect(roles.HasCTERole, selectable).cte( name=name, recursive=recursive ) @@ -1588,7 +1548,7 @@ class CTE(Generative, HasSuffixes, Alias): ) -class HasCTE(object): +class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 @@ -2059,13 +2019,22 @@ class ForUpdateArg(ClauseElement): self.key_share = key_share if of is not None: self.of = [ - _interpret_as_column_or_from(elem) for elem in util.to_list(of) + coercions.expect(roles.ColumnsClauseRole, elem) + for elem in util.to_list(of) ] else: self.of = None -class SelectBase(HasCTE, Executable, FromClause): +class SelectBase( + roles.SelectStatementRole, + roles.DMLSelectRole, + roles.CompoundElementRole, + roles.InElementRole, + HasCTE, + Executable, + FromClause, +): """Base class for SELECT statements. @@ -2075,15 +2044,32 @@ class SelectBase(HasCTE, Executable, FromClause): """ + _is_select_statement = True + + @util.deprecated( + "1.4", + "The :meth:`.SelectBase.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`.SelectBase.scalar_subquery`.", + ) def as_scalar(self): + return self.scalar_subquery() + + def scalar_subquery(self): """return a 'scalar' representation of this selectable, which can be used as a column expression. Typically, a select statement which has only one column in its columns - clause is eligible to be used as a scalar expression. + clause is eligible to be used as a scalar expression. The scalar + subquery can then be used in the WHERE clause or columns clause of + an enclosing SELECT. - The returned object is an instance of - :class:`ScalarSelect`. + Note that the scalar subquery differentiates from the FROM-level + subquery that can be produced using the :meth:`.SelectBase.subquery` + method. + + .. versionchanged: 1.4 - the ``.as_scalar()`` method was renamed to + :meth:`.SelectBase.scalar_subquery`. """ return ScalarSelect(self) @@ -2097,7 +2083,7 @@ class SelectBase(HasCTE, Executable, FromClause): :meth:`~.SelectBase.as_scalar`. """ - return self.as_scalar().label(name) + return self.scalar_subquery().label(name) @_generative @util.deprecated( @@ -2181,20 +2167,19 @@ class GenerativeSelect(SelectBase): {"autocommit": autocommit} ) if limit is not None: - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) if offset is not None: - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) self._bind = bind if order_by is not None: self._order_by_clause = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_and_labels_as_label_reference + _literal_as_text_role=roles.OrderByRole ) if group_by is not None: self._group_by_clause = ClauseList( - *util.to_list(group_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(group_by), _literal_as_text_role=roles.ByOfRole ) @property @@ -2287,6 +2272,37 @@ class GenerativeSelect(SelectBase): """ self.use_labels = True + def _offset_or_limit_clause(self, element, name=None, type_=None): + """Convert the given value to an "offset or limit" clause. + + This handles incoming integers and converts to an expression; if + an expression is already given, it is passed through. + + """ + return coercions.expect( + roles.LimitOffsetRole, element, name=name, type_=type_ + ) + + def _offset_or_limit_clause_asint(self, clause, attrname): + """Convert the "offset or limit" clause of a select construct to an + integer. + + This is only possible if the value is stored as a simple bound + parameter. Otherwise, a compilation error is raised. + + """ + if clause is None: + return None + try: + value = clause._limit_offset_value + except AttributeError: + raise exc.CompileError( + "This SELECT structure does not use a simple " + "integer value for %s" % attrname + ) + else: + return util.asint(value) + @property def _limit(self): """Get an integer value for the limit. This should only be used @@ -2295,7 +2311,7 @@ class GenerativeSelect(SelectBase): isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._limit_clause, "limit") + return self._offset_or_limit_clause_asint(self._limit_clause, "limit") @property def _simple_int_limit(self): @@ -2319,7 +2335,9 @@ class GenerativeSelect(SelectBase): offset isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._offset_clause, "offset") + return self._offset_or_limit_clause_asint( + self._offset_clause, "offset" + ) @_generative def limit(self, limit): @@ -2339,7 +2357,7 @@ class GenerativeSelect(SelectBase): """ - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) @_generative def offset(self, offset): @@ -2361,7 +2379,7 @@ class GenerativeSelect(SelectBase): """ - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) @_generative def order_by(self, *clauses): @@ -2403,8 +2421,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_order_by_clause", None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList( - *clauses, - _literal_as_text=_literal_and_labels_as_label_reference + *clauses, _literal_as_text_role=roles.OrderByRole ) def append_group_by(self, *clauses): @@ -2423,7 +2440,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_group_by_clause", None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList( - *clauses, _literal_as_text=_literal_as_label_reference + *clauses, _literal_as_text_role=roles.ByOfRole ) @property @@ -2478,7 +2495,7 @@ class CompoundSelect(GenerativeSelect): # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for n, s in enumerate(selects): - s = _clause_element_as_expr(s) + s = coercions.expect(roles.CompoundElementRole, s) if not numcols: numcols = len(s.c._all_columns) @@ -2741,7 +2758,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property - _is_select = True @util.deprecated_params( autocommit=( @@ -2965,12 +2981,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._distinct = True else: self._distinct = [ - _literal_as_text(e) for e in util.to_list(distinct) + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) ] if from_obj is not None: self._from_obj = util.OrderedSet( - _interpret_as_from(f) for f in util.to_list(from_obj) + coercions.expect(roles.FromClauseRole, f) + for f in util.to_list(from_obj) ) else: self._from_obj = util.OrderedSet() @@ -2986,7 +3004,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): if cols_present: self._raw_columns = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) self._raw_columns.append(c) @@ -2994,16 +3012,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._raw_columns = [] if whereclause is not None: - self._whereclause = _literal_as_text(whereclause).self_group( - against=operators._asbool - ) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ).self_group(against=operators._asbool) else: self._whereclause = None if having is not None: - self._having = _literal_as_text(having).self_group( - against=operators._asbool - ) + self._having = coercions.expect( + roles.WhereHavingRole, having + ).self_group(against=operators._asbool) else: self._having = None @@ -3202,15 +3220,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: self._hints = self._hints.union({(selectable, dialect_name): text}) - @property - def type(self): - raise exc.InvalidRequestError( - "Select objects don't have a type. " - "Call as_scalar() on this Select " - "object to return a 'scalar' version " - "of this Select." - ) - @_memoized_property.method def locate_all_froms(self): """return a Set of all FromClause elements referenced by this Select. @@ -3496,7 +3505,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._reset_exported() rc = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) @@ -3530,7 +3539,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ if expr: - expr = [_literal_as_label_reference(e) for e in expr] + expr = [coercions.expect(roles.ByOfRole, e) for e in expr] if isinstance(self._distinct, list): self._distinct = self._distinct + expr else: @@ -3618,7 +3627,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate = () else: self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @_generative @@ -3653,7 +3662,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate_except = () else: self._correlate_except = set(self._correlate_except or ()).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) def append_correlation(self, fromclause): @@ -3668,7 +3677,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._auto_correlate = False self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclause + coercions.expect(roles.FromClauseRole, f) for f in fromclause ) def append_column(self, column): @@ -3689,7 +3698,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - column = _interpret_as_column_or_from(column) + column = coercions.expect(roles.ColumnsClauseRole, column) if isinstance(column, ScalarSelect): column = column.self_group(against=operators.comma_op) @@ -3705,7 +3714,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): standard :term:`method chaining`. """ - clause = _literal_as_text(clause) + clause = coercions.expect(roles.WhereHavingRole, clause) self._prefixes = self._prefixes + (clause,) def append_whereclause(self, whereclause): @@ -3747,7 +3756,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - fromclause = _interpret_as_from(fromclause) + fromclause = coercions.expect(roles.FromClauseRole, fromclause) self._from_obj = self._from_obj.union([fromclause]) @_memoized_property @@ -3894,7 +3903,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): bind = property(bind, _set_bind) -class ScalarSelect(Generative, Grouping): +class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] _is_from_container = True _is_implicitly_boolean = False @@ -3956,7 +3965,7 @@ class Exists(UnaryExpression): else: if not args: args = ([literal_column("*")],) - s = Select(*args, **kwargs).as_scalar().self_group() + s = Select(*args, **kwargs).scalar_subquery().self_group() UnaryExpression.__init__( self, @@ -3999,6 +4008,7 @@ class Exists(UnaryExpression): return e +# TODO: rename to TextualSelect, this is not a FROM clause class TextAsFrom(SelectBase): """Wrap a :class:`.TextClause` construct within a :class:`.SelectBase` interface. @@ -4022,7 +4032,7 @@ class TextAsFrom(SelectBase): __visit_name__ = "text_as_from" - _textual = True + _is_textual = True def __init__(self, text, columns, positional=False): self.element = text diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 0d3944552..6a520a2d5 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -14,13 +14,14 @@ import datetime as dt import decimal import json +from . import coercions from . import elements from . import operators +from . import roles from . import type_api from .base import _bind_or_error from .base import SchemaEventTarget from .elements import _defer_name -from .elements import _literal_as_binds from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa @@ -2187,19 +2188,21 @@ class JSON(Indexable, TypeEngine): if not isinstance(index, util.string_types) and isinstance( index, compat.collections_abc.Sequence ): - index = default_comparator._check_literal( - self.expr, - operators.json_path_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_path_getitem_op, bindparam_type=JSON.JSONPathType, ) operator = operators.json_path_getitem_op else: - index = default_comparator._check_literal( - self.expr, - operators.json_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_getitem_op, bindparam_type=JSON.JSONIndexType, ) operator = operators.json_getitem_op @@ -2372,17 +2375,20 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if self.type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) index = Slice( - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.start, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.stop, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.step, name=self.expr.key, type_=type_api.INTEGERTYPE, @@ -2438,7 +2444,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_any(self.expr), ) @@ -2473,7 +2479,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_all(self.expr), ) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bdeae9613..5eea27e08 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -57,6 +57,9 @@ class TypeEngine(Visitable): default_comparator = None + def __clause_element__(self): + return self.expr + def __init__(self, expr): self.expr = expr self.type = expr.type |