diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-08-12 17:50:37 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-08-12 17:50:37 -0400 |
commit | f6198d9abf453182f4b111e0579a7a4ef1614e79 (patch) | |
tree | e258eafc9db70c4745d98a56b55b439732aebf91 /lib/sqlalchemy/sql/compiler.py | |
parent | e8c2a2738b6c15cb12e7571b9e12c15cc2f200c9 (diff) | |
download | sqlalchemy-f6198d9abf453182f4b111e0579a7a4ef1614e79.tar.gz |
- A large refactoring of the ``sqlalchemy.sql`` package has reorganized
the import structure of many core modules.
``sqlalchemy.schema`` and ``sqlalchemy.types``
remain in the top-level package, but are now just lists of names
that pull from within ``sqlalchemy.sql``. Their implementations
are now broken out among ``sqlalchemy.sql.type_api``, ``sqlalchemy.sql.sqltypes``,
``sqlalchemy.sql.schema`` and ``sqlalchemy.sql.ddl``, the last of which was
moved from ``sqlalchemy.engine``. ``sqlalchemy.sql.expression`` is also
a namespace now which pulls implementations mostly from ``sqlalchemy.sql.elements``,
``sqlalchemy.sql.selectable``, and ``sqlalchemy.sql.dml``.
Most of the "factory" functions
used to create SQL expression objects have been moved to classmethods
or constructors, which are exposed in ``sqlalchemy.sql.expression``
using a programmatic system. Care has been taken such that all the
original import namespaces remain intact and there should be no impact
on any existing applications. The rationale here was to break out these
very large modules into smaller ones, provide more manageable lists
of function names, to greatly reduce "import cycles" and clarify the
up-front importing of names, and to remove the need for redundant
functions and documentation throughout the expression package.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 216 |
1 files changed, 159 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index daed7c50f..a6e6987c5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -23,11 +23,9 @@ To generate user-defined SQL strings, see """ import re -import sys -from .. import schema, engine, util, exc, types -from . import ( - operators, functions, util as sql_util, visitors, expression as sql -) +from . import schema, sqltypes, operators, functions, \ + util as sql_util, visitors, elements, selectable +from .. import util, exc import decimal import itertools @@ -150,14 +148,118 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - sql.CompoundSelect.UNION: 'UNION', - sql.CompoundSelect.UNION_ALL: 'UNION ALL', - sql.CompoundSelect.EXCEPT: 'EXCEPT', - sql.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', - sql.CompoundSelect.INTERSECT: 'INTERSECT', - sql.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' + selectable.CompoundSelect.UNION: 'UNION', + selectable.CompoundSelect.UNION_ALL: 'UNION ALL', + selectable.CompoundSelect.EXCEPT: 'EXCEPT', + selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', + selectable.CompoundSelect.INTERSECT: 'INTERSECT', + selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' } +class Compiled(object): + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + def __init__(self, dialect, statement, bind=None, + compile_kwargs=util.immutabledict()): + """Construct a new ``Compiled`` object. + + :param dialect: ``Dialect`` to compile against. + + :param statement: ``ClauseElement`` to be compiled. + + :param bind: Optional Engine or Connection to compile this + statement against. + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + .. versionadded:: 0.8 + + """ + + self.dialect = dialect + self.bind = bind + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + self.string = self.process(self.statement, **compile_kwargs) + + @util.deprecated("0.7", ":class:`.Compiled` objects now compile " + "within the constructor.") + def compile(self): + """Produce the internal string representation of this element.""" + pass + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or '' + + def construct_params(self, params=None): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + def execute(self, *multiparams, **params): + """Execute this compiled object.""" + + e = self.bind + if e is None: + raise exc.UnboundExecutionError( + "This Compiled object is not bound to any Engine " + "or Connection.") + return e._execute_compiled(self, multiparams, params) + + def scalar(self, *multiparams, **params): + """Execute this compiled object and return the result's + scalar value.""" + + return self.execute(*multiparams, **params).scalar() + + +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" @@ -183,7 +285,7 @@ class _CompileLabel(visitors.Visitable): return self.element.quote -class SQLCompiler(engine.Compiled): +class SQLCompiler(Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -284,7 +386,7 @@ class SQLCompiler(engine.Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} - engine.Compiled.__init__(self, dialect, statement, **kwargs) + Compiled.__init__(self, dialect, statement, **kwargs) if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -397,7 +499,7 @@ class SQLCompiler(engine.Compiled): render_label_only = render_label_as_label is label if render_label_only or render_label_with_as: - if isinstance(label.name, sql._truncated_label): + if isinstance(label.name, elements._truncated_label): labelname = self._truncated_identifier("colident", label.name) else: labelname = label.name @@ -432,7 +534,7 @@ class SQLCompiler(engine.Compiled): "its 'name' is assigned.") is_literal = column.is_literal - if not is_literal and isinstance(name, sql._truncated_label): + if not is_literal and isinstance(name, elements._truncated_label): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: @@ -459,7 +561,7 @@ class SQLCompiler(engine.Compiled): else: schema_prefix = '' tablename = table.name - if isinstance(tablename, sql._truncated_label): + if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ @@ -687,8 +789,8 @@ class SQLCompiler(engine.Compiled): def visit_binary(self, binary, **kw): # don't allow "? = ?" to render if self.ansi_bind_rules and \ - isinstance(binary.left, sql.BindParameter) and \ - isinstance(binary.right, sql.BindParameter): + isinstance(binary.left, elements.BindParameter) and \ + isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True operator = binary.operator @@ -728,7 +830,7 @@ class SQLCompiler(engine.Compiled): @util.memoized_property def _like_percent_literal(self): - return sql.literal_column("'%'", type_=types.String()) + return elements.literal_column("'%'", type_=sqltypes.String()) def visit_contains_op_binary(self, binary, operator, **kw): binary = binary._clone() @@ -888,7 +990,7 @@ class SQLCompiler(engine.Compiled): return self.bind_names[bindparam] bind_name = bindparam.key - if isinstance(bind_name, sql._truncated_label): + if isinstance(bind_name, elements._truncated_label): bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation @@ -937,7 +1039,7 @@ class SQLCompiler(engine.Compiled): if self.positional: kwargs['positional_names'] = self.cte_positional - if isinstance(cte.name, sql._truncated_label): + if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: cte_name = cte.name @@ -966,7 +1068,7 @@ class SQLCompiler(engine.Compiled): if orig_cte not in self.ctes: self.visit_cte(orig_cte) cte_alias_name = cte._cte_alias.name - if isinstance(cte_alias_name, sql._truncated_label): + if isinstance(cte_alias_name, elements._truncated_label): cte_alias_name = self._truncated_identifier("alias", cte_alias_name) else: orig_cte = cte @@ -976,9 +1078,9 @@ class SQLCompiler(engine.Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.original, sql.Select): + if isinstance(cte.original, selectable.Select): col_source = cte.original - elif isinstance(cte.original, sql.CompoundSelect): + elif isinstance(cte.original, selectable.CompoundSelect): col_source = cte.original.selects[0] else: assert False @@ -1007,7 +1109,7 @@ class SQLCompiler(engine.Compiled): iscrud=False, fromhints=None, **kwargs): if asfrom or ashint: - if isinstance(alias.name, sql._truncated_label): + if isinstance(alias.name, elements._truncated_label): alias_name = self._truncated_identifier("alias", alias.name) else: alias_name = alias.name @@ -1065,7 +1167,7 @@ class SQLCompiler(engine.Compiled): if not within_columns_clause: result_expr = col_expr - elif isinstance(column, sql.Label): + elif isinstance(column, elements.Label): if col_expr is not column: result_expr = _CompileLabel( col_expr, @@ -1084,23 +1186,23 @@ class SQLCompiler(engine.Compiled): elif \ asfrom and \ - isinstance(column, sql.ColumnClause) and \ + isinstance(column, elements.ColumnClause) and \ not column.is_literal and \ column.table is not None and \ - not isinstance(column.table, sql.Select): + not isinstance(column.table, selectable.Select): result_expr = _CompileLabel(col_expr, - sql._as_truncated(column.name), + elements._as_truncated(column.name), alt_names=(column.key,)) elif not isinstance(column, - (sql.UnaryExpression, sql.TextClause)) \ + (elements.UnaryExpression, elements.TextClause)) \ and (not hasattr(column, 'name') or \ - isinstance(column, sql.Function)): + isinstance(column, functions.Function)): result_expr = _CompileLabel(col_expr, column.anon_label) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? - # assert isinstance(column, sql.ColumnClause) + # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel(col_expr, - sql._as_truncated(column.name), + elements._as_truncated(column.name), alt_names=(column.key,)) else: result_expr = col_expr @@ -1143,8 +1245,8 @@ class SQLCompiler(engine.Compiled): # as this whole system won't work for custom Join/Select # subclasses where compilation routines # call down to compiler.visit_join(), compiler.visit_select() - join_name = sql.Join.__visit_name__ - select_name = sql.Select.__visit_name__ + join_name = selectable.Join.__visit_name__ + select_name = selectable.Select.__visit_name__ def visit(element, **kw): if element in column_translate[-1]: @@ -1156,25 +1258,25 @@ class SQLCompiler(engine.Compiled): newelem = cloned[element] = element._clone() if newelem.__visit_name__ is join_name and \ - isinstance(newelem.right, sql.FromGrouping): + isinstance(newelem.right, selectable.FromGrouping): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) right = visit(newelem.right, **kw) - selectable = sql.select( + selectable_ = selectable.Select( [right.element], use_labels=True).alias() - for c in selectable.c: + for c in selectable_.c: c._key_label = c.key c._label = c.name translate_dict = dict( - zip(right.element.c, selectable.c) + zip(right.element.c, selectable_.c) ) - translate_dict[right.element.left] = selectable - translate_dict[right.element.right] = selectable + translate_dict[right.element.left] = selectable_ + translate_dict[right.element.right] = selectable_ # propagate translations that we've gained # from nested visit(newelem.right) outwards @@ -1190,7 +1292,7 @@ class SQLCompiler(engine.Compiled): column_translate[-1].update(translate_dict) - newelem.right = selectable + newelem.right = selectable_ newelem.onclause = visit(newelem.onclause, **kw) elif newelem.__visit_name__ is select_name: column_translate.append({}) @@ -1299,7 +1401,7 @@ class SQLCompiler(engine.Compiled): explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms) - new_correlate_froms = set(sql._from_objects(*froms)) + new_correlate_froms = set(selectable._from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) new_entry = { @@ -1461,11 +1563,11 @@ class SQLCompiler(engine.Compiled): def limit_clause(self, select): text = "" if select._limit is not None: - text += "\n LIMIT " + self.process(sql.literal(select._limit)) + text += "\n LIMIT " + self.process(elements.literal(select._limit)) if select._offset is not None: if select._limit is None: text += "\n LIMIT -1" - text += " OFFSET " + self.process(sql.literal(select._offset)) + text += " OFFSET " + self.process(elements.literal(select._offset)) return text def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, @@ -1692,7 +1794,7 @@ class SQLCompiler(engine.Compiled): def _create_crud_bind_param(self, col, value, required=False, name=None): if name is None: name = col.key - bindparam = sql.bindparam(name, value, + bindparam = elements.BindParameter(name, value, type_=col.type, required=required, quote=col.quote) bindparam._is_crud = True @@ -1732,7 +1834,7 @@ class SQLCompiler(engine.Compiled): if self.column_keys is None: parameters = {} else: - parameters = dict((sql._column_as_key(key), REQUIRED) + parameters = dict((elements._column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1742,15 +1844,15 @@ class SQLCompiler(engine.Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = sql._column_as_key(k) + colkey = elements._column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if sql._is_literal(v): - v = self.process(sql.bindparam(None, v, type_=k.type)) + if elements._is_literal(v): + v = self.process(elements.BindParameter(None, v, type_=k.type)) else: v = self.process(v.self_group()) @@ -1771,7 +1873,7 @@ class SQLCompiler(engine.Compiled): # statements if extra_tables and stmt_parameters: normalized_params = dict( - (sql._clause_element_as_expr(c), param) + (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) assert self.isupdate @@ -1782,7 +1884,7 @@ class SQLCompiler(engine.Compiled): affected_tables.add(t) check_columns[c.key] = c value = normalized_params[c] - if sql._is_literal(value): + if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED) else: @@ -1816,7 +1918,7 @@ class SQLCompiler(engine.Compiled): for c in stmt.table.columns: if c.key in parameters and c.key not in check_columns: value = parameters.pop(c.key) - if sql._is_literal(value): + if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, name=c.key @@ -1918,7 +2020,7 @@ class SQLCompiler(engine.Compiled): if parameters and stmt_parameters: check = set(parameters).intersection( - sql._column_as_key(k) for k in stmt.parameters + elements._column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( @@ -2013,7 +2115,7 @@ class SQLCompiler(engine.Compiled): self.preparer.format_savepoint(savepoint_stmt) -class DDLCompiler(engine.Compiled): +class DDLCompiler(Compiled): @util.memoized_property def sql_compiler(self): @@ -2183,7 +2285,7 @@ class DDLCompiler(engine.Compiled): schema_name = None ident = index.name - if isinstance(ident, sql._truncated_label): + if isinstance(ident, elements._truncated_label): max_ = self.dialect.max_index_name_length or \ self.dialect.max_identifier_length if len(ident) > max_: @@ -2343,7 +2445,7 @@ class DDLCompiler(engine.Compiled): return text -class GenericTypeCompiler(engine.TypeCompiler): +class GenericTypeCompiler(TypeCompiler): def visit_FLOAT(self, type_): return "FLOAT" |