summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-08-12 17:50:37 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-08-12 17:50:37 -0400
commitf6198d9abf453182f4b111e0579a7a4ef1614e79 (patch)
treee258eafc9db70c4745d98a56b55b439732aebf91 /lib/sqlalchemy/sql/compiler.py
parente8c2a2738b6c15cb12e7571b9e12c15cc2f200c9 (diff)
downloadsqlalchemy-f6198d9abf453182f4b111e0579a7a4ef1614e79.tar.gz
- A large refactoring of the ``sqlalchemy.sql`` package has reorganized
the import structure of many core modules. ``sqlalchemy.schema`` and ``sqlalchemy.types`` remain in the top-level package, but are now just lists of names that pull from within ``sqlalchemy.sql``. Their implementations are now broken out among ``sqlalchemy.sql.type_api``, ``sqlalchemy.sql.sqltypes``, ``sqlalchemy.sql.schema`` and ``sqlalchemy.sql.ddl``, the last of which was moved from ``sqlalchemy.engine``. ``sqlalchemy.sql.expression`` is also a namespace now which pulls implementations mostly from ``sqlalchemy.sql.elements``, ``sqlalchemy.sql.selectable``, and ``sqlalchemy.sql.dml``. Most of the "factory" functions used to create SQL expression objects have been moved to classmethods or constructors, which are exposed in ``sqlalchemy.sql.expression`` using a programmatic system. Care has been taken such that all the original import namespaces remain intact and there should be no impact on any existing applications. The rationale here was to break out these very large modules into smaller ones, provide more manageable lists of function names, to greatly reduce "import cycles" and clarify the up-front importing of names, and to remove the need for redundant functions and documentation throughout the expression package.
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py216
1 files changed, 159 insertions, 57 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index daed7c50f..a6e6987c5 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -23,11 +23,9 @@ To generate user-defined SQL strings, see
"""
import re
-import sys
-from .. import schema, engine, util, exc, types
-from . import (
- operators, functions, util as sql_util, visitors, expression as sql
-)
+from . import schema, sqltypes, operators, functions, \
+ util as sql_util, visitors, elements, selectable
+from .. import util, exc
import decimal
import itertools
@@ -150,14 +148,118 @@ EXTRACT_MAP = {
}
COMPOUND_KEYWORDS = {
- sql.CompoundSelect.UNION: 'UNION',
- sql.CompoundSelect.UNION_ALL: 'UNION ALL',
- sql.CompoundSelect.EXCEPT: 'EXCEPT',
- sql.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- sql.CompoundSelect.INTERSECT: 'INTERSECT',
- sql.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: 'UNION',
+ selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
+ selectable.CompoundSelect.EXCEPT: 'EXCEPT',
+ selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
+ selectable.CompoundSelect.INTERSECT: 'INTERSECT',
+ selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
}
+class Compiled(object):
+ """Represent a compiled SQL or DDL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ def __init__(self, dialect, statement, bind=None,
+ compile_kwargs=util.immutabledict()):
+ """Construct a new ``Compiled`` object.
+
+ :param dialect: ``Dialect`` to compile against.
+
+ :param statement: ``ClauseElement`` to be compiled.
+
+ :param bind: Optional Engine or Connection to compile this
+ statement against.
+
+ :param compile_kwargs: additional kwargs that will be
+ passed to the initial call to :meth:`.Compiled.process`.
+
+ .. versionadded:: 0.8
+
+ """
+
+ self.dialect = dialect
+ self.bind = bind
+ if statement is not None:
+ self.statement = statement
+ self.can_execute = statement.supports_execution
+ self.string = self.process(self.statement, **compile_kwargs)
+
+ @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
+ "within the constructor.")
+ def compile(self):
+ """Produce the internal string representation of this element."""
+ pass
+
+ @property
+ def sql_compiler(self):
+ """Return a Compiled that is capable of processing SQL expressions.
+
+ If this compiler is one, it would likely just return 'self'.
+
+ """
+
+ raise NotImplementedError()
+
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
+
+ def __str__(self):
+ """Return the string text of the generated SQL or DDL."""
+
+ return self.string or ''
+
+ def construct_params(self, params=None):
+ """Return the bind params for this compiled object.
+
+ :param params: a dict of string/object pairs whose values will
+ override bind values compiled in to the
+ statement.
+ """
+
+ raise NotImplementedError()
+
+ @property
+ def params(self):
+ """Return the bind params for this compiled object."""
+ return self.construct_params()
+
+ def execute(self, *multiparams, **params):
+ """Execute this compiled object."""
+
+ e = self.bind
+ if e is None:
+ raise exc.UnboundExecutionError(
+ "This Compiled object is not bound to any Engine "
+ "or Connection.")
+ return e._execute_compiled(self, multiparams, params)
+
+ def scalar(self, *multiparams, **params):
+ """Execute this compiled object and return the result's
+ scalar value."""
+
+ return self.execute(*multiparams, **params).scalar()
+
+
+class TypeCompiler(object):
+ """Produces DDL specification for TypeEngine objects."""
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_):
+ return type_._compiler_dispatch(self)
+
+
class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
@@ -183,7 +285,7 @@ class _CompileLabel(visitors.Visitable):
return self.element.quote
-class SQLCompiler(engine.Compiled):
+class SQLCompiler(Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
@@ -284,7 +386,7 @@ class SQLCompiler(engine.Compiled):
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
- engine.Compiled.__init__(self, dialect, statement, **kwargs)
+ Compiled.__init__(self, dialect, statement, **kwargs)
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@@ -397,7 +499,7 @@ class SQLCompiler(engine.Compiled):
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
- if isinstance(label.name, sql._truncated_label):
+ if isinstance(label.name, elements._truncated_label):
labelname = self._truncated_identifier("colident", label.name)
else:
labelname = label.name
@@ -432,7 +534,7 @@ class SQLCompiler(engine.Compiled):
"its 'name' is assigned.")
is_literal = column.is_literal
- if not is_literal and isinstance(name, sql._truncated_label):
+ if not is_literal and isinstance(name, elements._truncated_label):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
@@ -459,7 +561,7 @@ class SQLCompiler(engine.Compiled):
else:
schema_prefix = ''
tablename = table.name
- if isinstance(tablename, sql._truncated_label):
+ if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + \
@@ -687,8 +789,8 @@ class SQLCompiler(engine.Compiled):
def visit_binary(self, binary, **kw):
# don't allow "? = ?" to render
if self.ansi_bind_rules and \
- isinstance(binary.left, sql.BindParameter) and \
- isinstance(binary.right, sql.BindParameter):
+ isinstance(binary.left, elements.BindParameter) and \
+ isinstance(binary.right, elements.BindParameter):
kw['literal_binds'] = True
operator = binary.operator
@@ -728,7 +830,7 @@ class SQLCompiler(engine.Compiled):
@util.memoized_property
def _like_percent_literal(self):
- return sql.literal_column("'%'", type_=types.String())
+ return elements.literal_column("'%'", type_=sqltypes.String())
def visit_contains_op_binary(self, binary, operator, **kw):
binary = binary._clone()
@@ -888,7 +990,7 @@ class SQLCompiler(engine.Compiled):
return self.bind_names[bindparam]
bind_name = bindparam.key
- if isinstance(bind_name, sql._truncated_label):
+ if isinstance(bind_name, elements._truncated_label):
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
@@ -937,7 +1039,7 @@ class SQLCompiler(engine.Compiled):
if self.positional:
kwargs['positional_names'] = self.cte_positional
- if isinstance(cte.name, sql._truncated_label):
+ if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
@@ -966,7 +1068,7 @@ class SQLCompiler(engine.Compiled):
if orig_cte not in self.ctes:
self.visit_cte(orig_cte)
cte_alias_name = cte._cte_alias.name
- if isinstance(cte_alias_name, sql._truncated_label):
+ if isinstance(cte_alias_name, elements._truncated_label):
cte_alias_name = self._truncated_identifier("alias", cte_alias_name)
else:
orig_cte = cte
@@ -976,9 +1078,9 @@ class SQLCompiler(engine.Compiled):
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
- if isinstance(cte.original, sql.Select):
+ if isinstance(cte.original, selectable.Select):
col_source = cte.original
- elif isinstance(cte.original, sql.CompoundSelect):
+ elif isinstance(cte.original, selectable.CompoundSelect):
col_source = cte.original.selects[0]
else:
assert False
@@ -1007,7 +1109,7 @@ class SQLCompiler(engine.Compiled):
iscrud=False,
fromhints=None, **kwargs):
if asfrom or ashint:
- if isinstance(alias.name, sql._truncated_label):
+ if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
else:
alias_name = alias.name
@@ -1065,7 +1167,7 @@ class SQLCompiler(engine.Compiled):
if not within_columns_clause:
result_expr = col_expr
- elif isinstance(column, sql.Label):
+ elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
col_expr,
@@ -1084,23 +1186,23 @@ class SQLCompiler(engine.Compiled):
elif \
asfrom and \
- isinstance(column, sql.ColumnClause) and \
+ isinstance(column, elements.ColumnClause) and \
not column.is_literal and \
column.table is not None and \
- not isinstance(column.table, sql.Select):
+ not isinstance(column.table, selectable.Select):
result_expr = _CompileLabel(col_expr,
- sql._as_truncated(column.name),
+ elements._as_truncated(column.name),
alt_names=(column.key,))
elif not isinstance(column,
- (sql.UnaryExpression, sql.TextClause)) \
+ (elements.UnaryExpression, elements.TextClause)) \
and (not hasattr(column, 'name') or \
- isinstance(column, sql.Function)):
+ isinstance(column, functions.Function)):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
- # assert isinstance(column, sql.ColumnClause)
+ # assert isinstance(column, elements.ColumnClause)
result_expr = _CompileLabel(col_expr,
- sql._as_truncated(column.name),
+ elements._as_truncated(column.name),
alt_names=(column.key,))
else:
result_expr = col_expr
@@ -1143,8 +1245,8 @@ class SQLCompiler(engine.Compiled):
# as this whole system won't work for custom Join/Select
# subclasses where compilation routines
# call down to compiler.visit_join(), compiler.visit_select()
- join_name = sql.Join.__visit_name__
- select_name = sql.Select.__visit_name__
+ join_name = selectable.Join.__visit_name__
+ select_name = selectable.Select.__visit_name__
def visit(element, **kw):
if element in column_translate[-1]:
@@ -1156,25 +1258,25 @@ class SQLCompiler(engine.Compiled):
newelem = cloned[element] = element._clone()
if newelem.__visit_name__ is join_name and \
- isinstance(newelem.right, sql.FromGrouping):
+ isinstance(newelem.right, selectable.FromGrouping):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
right = visit(newelem.right, **kw)
- selectable = sql.select(
+ selectable_ = selectable.Select(
[right.element],
use_labels=True).alias()
- for c in selectable.c:
+ for c in selectable_.c:
c._key_label = c.key
c._label = c.name
translate_dict = dict(
- zip(right.element.c, selectable.c)
+ zip(right.element.c, selectable_.c)
)
- translate_dict[right.element.left] = selectable
- translate_dict[right.element.right] = selectable
+ translate_dict[right.element.left] = selectable_
+ translate_dict[right.element.right] = selectable_
# propagate translations that we've gained
# from nested visit(newelem.right) outwards
@@ -1190,7 +1292,7 @@ class SQLCompiler(engine.Compiled):
column_translate[-1].update(translate_dict)
- newelem.right = selectable
+ newelem.right = selectable_
newelem.onclause = visit(newelem.onclause, **kw)
elif newelem.__visit_name__ is select_name:
column_translate.append({})
@@ -1299,7 +1401,7 @@ class SQLCompiler(engine.Compiled):
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms)
- new_correlate_froms = set(sql._from_objects(*froms))
+ new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
@@ -1461,11 +1563,11 @@ class SQLCompiler(engine.Compiled):
def limit_clause(self, select):
text = ""
if select._limit is not None:
- text += "\n LIMIT " + self.process(sql.literal(select._limit))
+ text += "\n LIMIT " + self.process(elements.literal(select._limit))
if select._offset is not None:
if select._limit is None:
text += "\n LIMIT -1"
- text += " OFFSET " + self.process(sql.literal(select._offset))
+ text += " OFFSET " + self.process(elements.literal(select._offset))
return text
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
@@ -1692,7 +1794,7 @@ class SQLCompiler(engine.Compiled):
def _create_crud_bind_param(self, col, value, required=False, name=None):
if name is None:
name = col.key
- bindparam = sql.bindparam(name, value,
+ bindparam = elements.BindParameter(name, value,
type_=col.type, required=required,
quote=col.quote)
bindparam._is_crud = True
@@ -1732,7 +1834,7 @@ class SQLCompiler(engine.Compiled):
if self.column_keys is None:
parameters = {}
else:
- parameters = dict((sql._column_as_key(key), REQUIRED)
+ parameters = dict((elements._column_as_key(key), REQUIRED)
for key in self.column_keys
if not stmt_parameters or
key not in stmt_parameters)
@@ -1742,15 +1844,15 @@ class SQLCompiler(engine.Compiled):
if stmt_parameters is not None:
for k, v in stmt_parameters.items():
- colkey = sql._column_as_key(k)
+ colkey = elements._column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
- if sql._is_literal(v):
- v = self.process(sql.bindparam(None, v, type_=k.type))
+ if elements._is_literal(v):
+ v = self.process(elements.BindParameter(None, v, type_=k.type))
else:
v = self.process(v.self_group())
@@ -1771,7 +1873,7 @@ class SQLCompiler(engine.Compiled):
# statements
if extra_tables and stmt_parameters:
normalized_params = dict(
- (sql._clause_element_as_expr(c), param)
+ (elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
assert self.isupdate
@@ -1782,7 +1884,7 @@ class SQLCompiler(engine.Compiled):
affected_tables.add(t)
check_columns[c.key] = c
value = normalized_params[c]
- if sql._is_literal(value):
+ if elements._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is REQUIRED)
else:
@@ -1816,7 +1918,7 @@ class SQLCompiler(engine.Compiled):
for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns:
value = parameters.pop(c.key)
- if sql._is_literal(value):
+ if elements._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is REQUIRED,
name=c.key
@@ -1918,7 +2020,7 @@ class SQLCompiler(engine.Compiled):
if parameters and stmt_parameters:
check = set(parameters).intersection(
- sql._column_as_key(k) for k in stmt.parameters
+ elements._column_as_key(k) for k in stmt.parameters
).difference(check_columns)
if check:
raise exc.CompileError(
@@ -2013,7 +2115,7 @@ class SQLCompiler(engine.Compiled):
self.preparer.format_savepoint(savepoint_stmt)
-class DDLCompiler(engine.Compiled):
+class DDLCompiler(Compiled):
@util.memoized_property
def sql_compiler(self):
@@ -2183,7 +2285,7 @@ class DDLCompiler(engine.Compiled):
schema_name = None
ident = index.name
- if isinstance(ident, sql._truncated_label):
+ if isinstance(ident, elements._truncated_label):
max_ = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length
if len(ident) > max_:
@@ -2343,7 +2445,7 @@ class DDLCompiler(engine.Compiled):
return text
-class GenericTypeCompiler(engine.TypeCompiler):
+class GenericTypeCompiler(TypeCompiler):
def visit_FLOAT(self, type_):
return "FLOAT"