summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorCarlos Rivas <carlos@twobitcoder.com>2016-01-26 13:45:31 -0800
committerCarlos Rivas <carlos@twobitcoder.com>2016-01-26 13:45:31 -0800
commitc6d630ca819239bf1b18bd6e51f265fb1be951c9 (patch)
treee30838e4e462d7994cc69d0c281a2d4a88b89edf /lib/sqlalchemy/sql/compiler.py
parent28365040ace29c9ceea28946ed19f07c3a4fcefc (diff)
parent8163de4cc9e01460d3476b9fb3ed14a5b3e70bae (diff)
downloadsqlalchemy-c6d630ca819239bf1b18bd6e51f265fb1be951c9.tar.gz
Merged zzzeek/sqlalchemy into master
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py190
1 files changed, 139 insertions, 51 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6766c99b7..492999d16 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -167,25 +167,39 @@ class Compiled(object):
_cached_metadata = None
def __init__(self, dialect, statement, bind=None,
+ schema_translate_map=None,
compile_kwargs=util.immutabledict()):
- """Construct a new ``Compiled`` object.
+ """Construct a new :class:`.Compiled` object.
- :param dialect: ``Dialect`` to compile against.
+ :param dialect: :class:`.Dialect` to compile against.
- :param statement: ``ClauseElement`` to be compiled.
+ :param statement: :class:`.ClauseElement` to be compiled.
:param bind: Optional Engine or Connection to compile this
statement against.
+ :param schema_translate_map: dictionary of schema names to be
+ translated when forming the resultant SQL
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
:param compile_kwargs: additional kwargs that will be
passed to the initial call to :meth:`.Compiled.process`.
- .. versionadded:: 0.8
"""
self.dialect = dialect
self.bind = bind
+ self.preparer = self.dialect.identifier_preparer
+ if schema_translate_map:
+ self.preparer = self.preparer._with_schema_translate(
+ schema_translate_map)
+
if statement is not None:
self.statement = statement
self.can_execute = statement.supports_execution
@@ -286,12 +300,11 @@ class _CompileLabel(visitors.Visitable):
def self_group(self, **kw):
return self
-class SQLCompiler(Compiled):
- """Default implementation of Compiled.
+class SQLCompiler(Compiled):
+ """Default implementation of :class:`.Compiled`.
- Compiles ClauseElements into SQL strings. Uses a similar visit
- paradigm as visitors.ClauseVisitor but implements its own traversal.
+ Compiles :class:`.ClauseElement` objects into SQL strings.
"""
@@ -305,6 +318,8 @@ class SQLCompiler(Compiled):
INSERT/UPDATE/DELETE
"""
+ isplaintext = False
+
returning = None
"""holds the "returning" collection of columns if
the statement is CRUD and defines returning columns
@@ -330,19 +345,34 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
+ _textual_ordered_columns = False
+ """tell the result object that the column names as rendered are important,
+ but they are also "ordered" vs. what is in the compiled object here.
+ """
+
+ _ordered_columns = True
+ """
+ if False, means we can't be sure the list of entries
+ in _result_columns is actually the rendered order. Usually
+ True unless using an unordered TextAsFrom.
+ """
+
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
- """Construct a new ``DefaultCompiler`` object.
+ """Construct a new :class:`.SQLCompiler` object.
- dialect
- Dialect to be used
+ :param dialect: :class:`.Dialect` to be used
- statement
- ClauseElement to be compiled
+ :param statement: :class:`.ClauseElement` to be compiled
- column_keys
- a list of column names to be compiled into an INSERT or UPDATE
- statement.
+ :param column_keys: a list of column names to be compiled into an
+ INSERT or UPDATE statement.
+
+ :param inline: whether to generate INSERT statements as "inline", e.g.
+ not formatted to return any generated defaults
+
+ :param kwargs: additional keyword arguments to be consumed by the
+ superclass.
"""
self.column_keys = column_keys
@@ -368,11 +398,6 @@ class SQLCompiler(Compiled):
# column targeting
self._result_columns = []
- # if False, means we can't be sure the list of entries
- # in _result_columns is actually the rendered order. This
- # gets flipped when we use TextAsFrom, for example.
- self._ordered_columns = True
-
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
@@ -381,8 +406,6 @@ class SQLCompiler(Compiled):
self.ctes = None
- # an IdentifierPreparer that formats the quoting of identifiers
- self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
or dialect.max_identifier_length
@@ -649,8 +672,11 @@ class SQLCompiler(Compiled):
if table is None or not include_table or not table.named_with_column:
return name
else:
- if table.schema:
- schema_prefix = self.preparer.quote_schema(table.schema) + '.'
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if effective_schema:
+ schema_prefix = self.preparer.quote_schema(
+ effective_schema) + '.'
else:
schema_prefix = ''
tablename = table.name
@@ -688,6 +714,9 @@ class SQLCompiler(Compiled):
else:
return self.bindparam_string(name, **kw)
+ if not self.stack:
+ self.isplaintext = True
+
# un-escape any \:params
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
@@ -711,7 +740,8 @@ class SQLCompiler(Compiled):
) or entry.get('need_result_map_for_nested', False)
if populate_result_map:
- self._ordered_columns = False
+ self._ordered_columns = \
+ self._textual_ordered_columns = taf.positional
for c in taf.column_args:
self.process(c, within_columns_clause=True,
add_to_result_map=self._add_to_result_map)
@@ -873,22 +903,28 @@ class SQLCompiler(Compiled):
else:
return text
+ def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
+ attrname = "visit_%s_%s%s" % (
+ operator_.__name__, qualifier1,
+ "_" + qualifier2 if qualifier2 else "")
+ return getattr(self, attrname, None)
+
def visit_unary(self, unary, **kw):
if unary.operator:
if unary.modifier:
raise exc.CompileError(
"Unary expression does not support operator "
"and modifier simultaneously")
- disp = getattr(self, "visit_%s_unary_operator" %
- unary.operator.__name__, None)
+ disp = self._get_operator_dispatch(
+ unary.operator, "unary", "operator")
if disp:
return disp(unary, unary.operator, **kw)
else:
return self._generate_generic_unary_operator(
unary, OPERATORS[unary.operator], **kw)
elif unary.modifier:
- disp = getattr(self, "visit_%s_unary_modifier" %
- unary.modifier.__name__, None)
+ disp = self._get_operator_dispatch(
+ unary.modifier, "unary", "modifier")
if disp:
return disp(unary, unary.modifier, **kw)
else:
@@ -922,7 +958,7 @@ class SQLCompiler(Compiled):
kw['literal_binds'] = True
operator_ = override_operator or binary.operator
- disp = getattr(self, "visit_%s_binary" % operator_.__name__, None)
+ disp = self._get_operator_dispatch(operator_, "binary", None)
if disp:
return disp(binary, operator_, **kw)
else:
@@ -1298,7 +1334,7 @@ class SQLCompiler(Compiled):
add_to_result_map = lambda keyname, name, objects, type_: \
self._add_to_result_map(
keyname, name,
- objects + (column,), type_)
+ (column,) + objects, type_)
else:
col_expr = column
if populate_result_map:
@@ -1386,7 +1422,7 @@ class SQLCompiler(Compiled):
"""Rewrite any "a JOIN (b JOIN c)" expression as
"a JOIN (select * from b JOIN c) AS anon", to support
databases that can't parse a parenthesized join correctly
- (i.e. sqlite the main one).
+ (i.e. sqlite < 3.7.16).
"""
cloned = {}
@@ -1801,8 +1837,10 @@ class SQLCompiler(Compiled):
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
if asfrom or ashint:
- if use_schema and getattr(table, "schema", None):
- ret = self.preparer.quote_schema(table.schema) + \
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if use_schema and effective_schema:
+ ret = self.preparer.quote_schema(effective_schema) + \
"." + self.preparer.quote(table.name)
else:
ret = self.preparer.quote(table.name)
@@ -2080,6 +2118,30 @@ class SQLCompiler(Compiled):
self.preparer.format_savepoint(savepoint_stmt)
+class StrSQLCompiler(SQLCompiler):
+ """"a compiler subclass with a few non-standard SQL features allowed.
+
+ Used for stringification of SQL statements when a real dialect is not
+ available.
+
+ """
+
+ def visit_getitem_binary(self, binary, operator, **kw):
+ return "%s[%s]" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)
+ )
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self._label_select_column(None, c, True, False, {})
+ for c in elements._select_iterables(returning_cols)
+ ]
+
+ return 'RETURNING ' + ', '.join(columns)
+
+
class DDLCompiler(Compiled):
@util.memoized_property
@@ -2090,10 +2152,6 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- @property
- def preparer(self):
- return self.dialect.identifier_preparer
-
def construct_params(self, params=None):
return None
@@ -2103,7 +2161,7 @@ class DDLCompiler(Compiled):
if isinstance(ddl.target, schema.Table):
context = context.copy()
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
table, sch = path[0], ''
@@ -2129,7 +2187,7 @@ class DDLCompiler(Compiled):
def visit_create_table(self, create):
table = create.element
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
text = "\nCREATE "
if table._prefixes:
@@ -2256,9 +2314,12 @@ class DDLCompiler(Compiled):
index, include_schema=True)
def _prepared_index_name(self, index, include_schema=False):
- if include_schema and index.table is not None and index.table.schema:
- schema = index.table.schema
- schema_name = self.preparer.quote_schema(schema)
+ if index.table is not None:
+ effective_schema = self.preparer.schema_for_object(index.table)
+ else:
+ effective_schema = None
+ if include_schema and effective_schema:
+ schema_name = self.preparer.quote_schema(effective_schema)
else:
schema_name = None
@@ -2386,7 +2447,7 @@ class DDLCompiler(Compiled):
return text
def visit_foreign_key_constraint(self, constraint):
- preparer = self.dialect.identifier_preparer
+ preparer = self.preparer
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -2603,6 +2664,17 @@ class GenericTypeCompiler(TypeCompiler):
return type_.get_col_spec(**kw)
+class StrSQLTypeCompiler(GenericTypeCompiler):
+ def __getattr__(self, key):
+ if key.startswith("visit_"):
+ return self._visit_unknown
+ else:
+ raise AttributeError(key)
+
+ def _visit_unknown(self, type_, **kw):
+ return "%s" % type_.__class__.__name__
+
+
class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
@@ -2613,6 +2685,8 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+ schema_for_object = schema._schema_getter(None)
+
def __init__(self, dialect, initial_quote='"',
final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
@@ -2637,6 +2711,12 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self._strings = {}
+ def _with_schema_translate(self, schema_translate_map):
+ prep = self.__class__.__new__(self.__class__)
+ prep.__dict__.update(self.__dict__)
+ prep.schema_for_object = schema._schema_getter(schema_translate_map)
+ return prep
+
def _escape_identifier(self, value):
"""Escape an identifier.
@@ -2709,9 +2789,12 @@ class IdentifierPreparer(object):
def format_sequence(self, sequence, use_schema=True):
name = self.quote(sequence.name)
+
+ effective_schema = self.schema_for_object(sequence)
+
if (not self.omit_schema and use_schema and
- sequence.schema is not None):
- name = self.quote_schema(sequence.schema) + "." + name
+ effective_schema is not None):
+ name = self.quote_schema(effective_schema) + "." + name
return name
def format_label(self, label, name=None):
@@ -2740,9 +2823,12 @@ class IdentifierPreparer(object):
if name is None:
name = table.name
result = self.quote(name)
+
+ effective_schema = self.schema_for_object(table)
+
if not self.omit_schema and use_schema \
- and getattr(table, "schema", None):
- result = self.quote_schema(table.schema) + "." + result
+ and effective_schema:
+ result = self.quote_schema(effective_schema) + "." + result
return result
def format_schema(self, name, quote=None):
@@ -2781,9 +2867,11 @@ class IdentifierPreparer(object):
# ('database', 'owner', etc.) could override this and return
# a longer sequence.
+ effective_schema = self.schema_for_object(table)
+
if not self.omit_schema and use_schema and \
- getattr(table, 'schema', None):
- return (self.quote_schema(table.schema),
+ effective_schema:
+ return (self.quote_schema(effective_schema),
self.format_table(table, use_schema=False))
else:
return (self.format_table(table, use_schema=False), )