summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2008-02-14 20:02:10 +0000
committerJason Kirtland <jek@discorporate.us>2008-02-14 20:02:10 +0000
commit71e745e96b8c5be990b3dc949cb99310dd055609 (patch)
tree00c748e65e7e85e0231a1c7c504dec6cfcab8e87 /lib/sqlalchemy/sql/compiler.py
parent8dd5eb402ef65194af4c54a6fd33a181b7d5eaf0 (diff)
downloadsqlalchemy-71e745e96b8c5be990b3dc949cb99310dd055609.tar.gz
- Fixed a couple pyflakes, cleaned up imports & whitespace
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py172
1 files changed, 84 insertions, 88 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 3f32778d6..8d8cfa38f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -7,21 +7,20 @@
"""Base SQL and DDL compiler implementations.
Provides the [sqlalchemy.sql.compiler#DefaultCompiler] class, which is
-responsible for generating all SQL query strings, as well as
+responsible for generating all SQL query strings, as well as
[sqlalchemy.sql.compiler#SchemaGenerator] and [sqlalchemy.sql.compiler#SchemaDropper]
which issue CREATE and DROP DDL for tables, sequences, and indexes.
The elements in this module are used by public-facing constructs like
[sqlalchemy.sql.expression#ClauseElement] and [sqlalchemy.engine#Engine].
While dialect authors will want to be familiar with this module for the purpose of
-creating database-specific compilers and schema generators, the module
+creating database-specific compilers and schema generators, the module
is otherwise internal to SQLAlchemy.
"""
import string, re
from sqlalchemy import schema, engine, util, exceptions
-from sqlalchemy.sql import operators, visitors, functions
-from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
RESERVED_WORDS = util.Set([
@@ -57,7 +56,7 @@ BIND_TEMPLATES = {
'numeric':"%(position)s",
'named':":%(name)s"
}
-
+
OPERATORS = {
operators.and_ : 'AND',
@@ -96,14 +95,14 @@ OPERATORS = {
FUNCTIONS = {
functions.coalesce : 'coalesce%(expr)s',
- functions.current_date: 'CURRENT_DATE',
- functions.current_time: 'CURRENT_TIME',
+ functions.current_date: 'CURRENT_DATE',
+ functions.current_time: 'CURRENT_TIME',
functions.current_timestamp: 'CURRENT_TIMESTAMP',
- functions.current_user: 'CURRENT_USER',
- functions.localtime: 'LOCALTIME',
+ functions.current_user: 'CURRENT_USER',
+ functions.localtime: 'LOCALTIME',
functions.localtimestamp: 'LOCALTIMESTAMP',
functions.sysdate: 'sysdate',
- functions.session_user :'SESSION_USER',
+ functions.session_user :'SESSION_USER',
functions.user: 'USER'
}
@@ -118,7 +117,7 @@ class DefaultCompiler(engine.Compiled):
operators = OPERATORS
functions = FUNCTIONS
-
+
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -132,35 +131,35 @@ class DefaultCompiler(engine.Compiled):
a list of column names to be compiled into an INSERT or UPDATE
statement.
"""
-
+
super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs)
# if we are insert/update/delete. set to true when we visit an INSERT, UPDATE or DELETE
self.isdelete = self.isinsert = self.isupdate = False
-
+
# compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
self.inline = inline or getattr(statement, 'inline', False)
-
+
# a dictionary of bind parameter keys to _BindParamClause instances.
self.binds = {}
-
+
# a dictionary of _BindParamClause instances to "compiled" names that are
# actually present in the generated SQL
self.bind_names = {}
# a stack. what recursive compiler doesn't have a stack ? :)
self.stack = []
-
+
# relates label names in the final SQL to
# a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
# ResultProxy uses this for type processing and column targeting
self.result_map = {}
-
+
# a dictionary of ClauseElement subclasses to counters, which are used to
# generate truncated identifier names or "anonymous" identifiers such as
# for aliases
self.generated_ids = {}
-
+
# paramstyle from the dialect (comes from DB-API)
self.paramstyle = self.dialect.paramstyle
@@ -168,17 +167,17 @@ class DefaultCompiler(engine.Compiled):
self.positional = self.dialect.positional
self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
-
+
# a list of the compiled's bind parameter names, used to help
# formulate a positional argument list
self.positiontup = []
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = self.dialect.identifier_preparer
-
+
def compile(self):
self.string = self.process(self.statement)
-
+
def process(self, obj, stack=None, **kwargs):
if stack:
self.stack.append(stack)
@@ -189,23 +188,23 @@ class DefaultCompiler(engine.Compiled):
finally:
if stack:
self.stack.pop(-1)
-
+
def is_subquery(self, select):
return self.stack and self.stack[-1].get('is_subquery')
-
+
def get_whereclause(self, obj):
- """given a FROM clause, return an additional WHERE condition that should be
- applied to a SELECT.
-
+ """given a FROM clause, return an additional WHERE condition that should be
+ applied to a SELECT.
+
Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
constructs in non-ansi mode.
"""
-
+
return None
def construct_params(self, params=None):
"""return a dictionary of bind parameter keys and values"""
-
+
if params:
pd = {}
for bindparam, name in self.bind_names.iteritems():
@@ -218,9 +217,9 @@ class DefaultCompiler(engine.Compiled):
return pd
else:
return dict([(self.bind_names[bindparam], bindparam.value) for bindparam in self.bind_names])
-
+
params = property(construct_params)
-
+
def default_from(self):
"""Called when a SELECT statement has no froms, and no FROM clause is to be appended.
@@ -228,22 +227,22 @@ class DefaultCompiler(engine.Compiled):
"""
return ""
-
+
def visit_grouping(self, grouping, **kwargs):
return "(" + self.process(grouping.elem) + ")"
-
+
def visit_label(self, label, result_map=None):
labelname = self._truncated_identifier("colident", label.name)
-
+
if result_map is not None:
result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type)
-
+
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
-
+
def visit_column(self, column, result_map=None, use_schema=False, **kwargs):
# there is actually somewhat of a ruleset when you would *not* necessarily
- # want to truncate a column identifier, if its mapped to the name of a
- # physical column. but thats very hard to identify at this point, and
+ # want to truncate a column identifier, if its mapped to the name of a
+ # physical column. but thats very hard to identify at this point, and
# the identifier length should be greater than the id lengths of any physical
# columns so should not matter.
@@ -259,7 +258,7 @@ class DefaultCompiler(engine.Compiled):
if result_map is not None:
result_map[name.lower()] = (name, (column, ), column.type)
-
+
if column._is_oid:
n = self.dialect.oid_column_name(column)
if n is not None:
@@ -288,7 +287,7 @@ class DefaultCompiler(engine.Compiled):
# TODO: some dialects might need different behavior here
return text.replace('%', '%%')
-
+
def visit_fromclause(self, fromclause, **kwargs):
return fromclause.name
@@ -302,7 +301,7 @@ class DefaultCompiler(engine.Compiled):
if textclause.typemap is not None:
for colname, type_ in textclause.typemap.iteritems():
self.result_map[colname.lower()] = (colname, None, type_)
-
+
def do_bindparam(m):
name = m.group(1)
if name in textclause.bindparams:
@@ -311,7 +310,7 @@ class DefaultCompiler(engine.Compiled):
return self.bindparam_string(name)
# un-escape any \:params
- return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
+ return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
BIND_PARAMS.sub(do_bindparam, textclause.text)
)
@@ -339,37 +338,37 @@ class DefaultCompiler(engine.Compiled):
result_map[func.name.lower()] = (func.name, None, func.type)
name = self.function_string(func)
-
+
if callable(name):
return name(*[self.process(x) for x in func.clause_expr])
else:
return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
-
+
def function_argspec(self, func):
return self.process(func.clause_expr)
-
+
def function_string(self, func):
return self.functions.get(func.__class__, func.name + "%(expr)s")
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
stack_entry = {'select':cs}
-
+
if asfrom:
stack_entry['is_subquery'] = True
elif self.stack and self.stack[-1].get('select'):
stack_entry['is_subquery'] = True
self.stack.append(stack_entry)
-
+
text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ")
group_by = self.process(cs._group_by_clause, asfrom=asfrom)
if group_by:
text += " GROUP BY " + group_by
- text += self.order_by_clause(cs)
+ text += self.order_by_clause(cs)
text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
-
+
self.stack.pop(-1)
-
+
if asfrom and parens:
return "(" + text + ")"
else:
@@ -382,19 +381,17 @@ class DefaultCompiler(engine.Compiled):
if unary.modifier:
s = s + " " + self.operator_string(unary.modifier)
return s
-
+
def visit_binary(self, binary, **kwargs):
op = self.operator_string(binary.operator)
if callable(op):
return op(self.process(binary.left), self.process(binary.right))
else:
return self.process(binary.left) + " " + op + " " + self.process(binary.right)
-
- return ret
-
+
def operator_string(self, operator):
return self.operators.get(operator, str(operator))
-
+
def visit_bindparam(self, bindparam, **kwargs):
name = self._truncate_bindparam(bindparam)
if name in self.binds:
@@ -403,22 +400,22 @@ class DefaultCompiler(engine.Compiled):
raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
-
+
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
-
+
bind_name = bindparam.key
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
self.bind_names[bindparam] = bind_name
-
+
return bind_name
-
+
def _truncated_identifier(self, ident_class, name):
if (ident_class, name) in self.generated_ids:
return self.generated_ids[(ident_class, name)]
-
+
anonname = ANONYMOUS_LABEL.sub(self._process_anon, name)
if len(anonname) > self.dialect.max_identifier_length:
@@ -441,14 +438,14 @@ class DefaultCompiler(engine.Compiled):
self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1
self.generated_ids[key] = newname
return newname
-
+
def _anonymize(self, name):
return ANONYMOUS_LABEL.sub(self._process_anon, name)
-
+
def bindparam_string(self, name):
if self.positional:
self.positiontup.append(name)
-
+
return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
def visit_alias(self, alias, asfrom=False, **kwargs):
@@ -459,13 +456,13 @@ class DefaultCompiler(engine.Compiled):
def label_select_column(self, select, column, asfrom):
"""label columns present in a select()."""
-
+
if isinstance(column, sql._Label):
return column
-
+
if select.use_labels and getattr(column, '_label', None):
return column.label(column._label)
-
+
if \
asfrom and \
isinstance(column, sql._ColumnClause) and \
@@ -494,12 +491,12 @@ class DefaultCompiler(engine.Compiled):
stack_entry['iswrapper'] = True
else:
column_clause_args = {'result_map':self.result_map}
-
+
if self.stack and 'from' in self.stack[-1]:
existingfroms = self.stack[-1]['from']
else:
existingfroms = None
-
+
froms = select._get_display_froms(existingfroms)
correlate_froms = util.Set()
@@ -510,17 +507,17 @@ class DefaultCompiler(engine.Compiled):
# TODO: might want to propigate existing froms for select(select(select))
# where innermost select should correlate to outermost
# if existingfroms:
-# correlate_froms = correlate_froms.union(existingfroms)
+# correlate_froms = correlate_froms.union(existingfroms)
stack_entry['from'] = correlate_froms
self.stack.append(stack_entry)
# the actual list of columns to print in the SELECT column list.
inner_columns = util.OrderedSet()
-
+
for co in select.inner_columns:
l = self.label_select_column(select, co, asfrom=asfrom)
inner_columns.add(self.process(l, **column_clause_args))
-
+
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
@@ -539,7 +536,7 @@ class DefaultCompiler(engine.Compiled):
whereclause = sql.and_(w, whereclause)
else:
whereclause = w
-
+
if froms:
text += " \nFROM "
text += string.join(from_strings, ', ')
@@ -559,7 +556,7 @@ class DefaultCompiler(engine.Compiled):
t = self.process(select._having)
if t:
text += " \nHAVING " + t
-
+
text += self.order_by_clause(select)
text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
@@ -625,10 +622,10 @@ class DefaultCompiler(engine.Compiled):
', '.join([preparer.quote(c[0], c[0].name)
for c in colparams]),
', '.join([c[1] for c in colparams])))
-
+
def visit_update(self, update_stmt):
self.stack.append({'from':util.Set([update_stmt.table])})
-
+
self.isupdate = True
colparams = self._get_colparams(update_stmt)
@@ -636,15 +633,15 @@ class DefaultCompiler(engine.Compiled):
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
-
+
self.stack.pop(-1)
-
+
return text
def _get_colparams(self, stmt):
- """create a set of tuples representing column/string pairs for use
+ """create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
-
+
"""
def create_bind_param(col, value):
@@ -654,7 +651,7 @@ class DefaultCompiler(engine.Compiled):
self.postfetch = []
self.prefetch = []
-
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
@@ -688,7 +685,7 @@ class DefaultCompiler(engine.Compiled):
if (((isinstance(c.default, schema.Sequence) and
not c.default.optional) or
not self.dialect.supports_pk_autoincrement) or
- (c.default is not None and
+ (c.default is not None and
not isinstance(c.default, schema.Sequence))):
values.append((c, create_bind_param(c, None)))
self.prefetch.append(c)
@@ -732,18 +729,18 @@ class DefaultCompiler(engine.Compiled):
text += " WHERE " + self.process(delete_stmt._whereclause)
self.stack.pop(-1)
-
+
return text
-
+
def visit_savepoint(self, savepoint_stmt):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
-
+
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
-
+
def __str__(self):
return self.string or ''
@@ -1072,10 +1069,10 @@ class IdentifierPreparer(object):
def format_column(self, column, use_table=False, name=None, table_name=None):
"""Prepare a quoted column name.
-
+
deprecated. use preparer.quote(col, column.name) or combine with format_table()
"""
-
+
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
@@ -1121,7 +1118,6 @@ class IdentifierPreparer(object):
'final': final,
'escaped': escaped_final })
self._r_identifiers = r
-
+
return [self._unescape_identifier(i)
for i in [a or b for a, b in r.findall(identifiers)]]
-