summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-05-14 22:25:36 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-05-14 22:25:36 +0000
commitae4b954b1a6baf5a58c0e00e382196b581a7f06a (patch)
tree562d969dcef857594a62e3730e1305aaa284e3b0 /lib/sqlalchemy/ansisql.py
parent3de128138a896bc7373aa2684c920271c4781b7d (diff)
downloadsqlalchemy-ae4b954b1a6baf5a58c0e00e382196b581a7f06a.tar.gz
- parenthesis are applied to clauses via a new _Grouping construct.
uses operator precedence to more intelligently apply parenthesis to clauses, provides cleaner nesting of clauses (doesnt mutate clauses placed in other clauses, i.e. no 'parens' flag) - added 'modifier' keyword, works like func.<foo> except does not add parenthesis. e.g. select([modifier.DISTINCT(...)]) etc.
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py80
1 files changed, 35 insertions, 45 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index ab043f3ec..28dd0866c 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -245,7 +245,10 @@ class ANSICompiler(sql.Compiled):
"""
return ""
-
+
+ def visit_grouping(self, grouping):
+ self.strings[grouping] = "(" + self.strings[grouping.elem] + ")"
+
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
@@ -298,10 +301,7 @@ class ANSICompiler(sql.Compiled):
self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
def visit_textclause(self, textclause):
- if textclause.parens and len(textclause.text):
- self.strings[textclause] = "(" + textclause.text + ")"
- else:
- self.strings[textclause] = textclause.text
+ self.strings[textclause] = textclause.text
self.froms[textclause] = textclause.text
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
@@ -309,32 +309,21 @@ class ANSICompiler(sql.Compiled):
def visit_null(self, null):
self.strings[null] = 'NULL'
- def visit_compound(self, compound):
- if compound.operator is None:
- sep = " "
- else:
- sep = " " + compound.operator + " "
-
- s = string.join([self.get_str(c) for c in compound.clauses], sep)
- if compound.parens:
- self.strings[compound] = "(" + s + ")"
- else:
- self.strings[compound] = s
-
def visit_clauselist(self, list):
- if list.parens:
- self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")"
+ sep = list.operator
+ if sep == ',':
+ sep = ', '
+ elif sep is None or sep == " ":
+ sep = " "
else:
- self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ')
+ sep = " " + sep + " "
+ self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
- def visit_calculatedclause(self, list):
- if list.parens:
- self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
- else:
- self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
+ def visit_calculatedclause(self, clause):
+ self.strings[clause] = self.get_str(clause.clause_expr)
def visit_cast(self, cast):
if len(self.select_stack):
@@ -349,7 +338,7 @@ class ANSICompiler(sql.Compiled):
self.strings[func] = ".".join(func.packagenames + [func.name])
self.froms[func] = self.strings[func]
else:
- self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+ self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
self.froms[func] = self.strings[func]
def visit_compound_select(self, cs):
@@ -359,19 +348,22 @@ class ANSICompiler(sql.Compiled):
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
text += self.visit_select_postclauses(cs)
- if cs.parens:
- self.strings[cs] = "(" + text + ")"
- else:
- self.strings[cs] = text
+ self.strings[cs] = text
self.froms[cs] = "(" + text + ")"
+ def visit_unary(self, unary):
+ s = self.get_str(unary.element)
+ if unary.operator:
+ s = unary.operator + " " + s
+ if unary.modifier:
+ s = s + " " + unary.modifier
+ self.strings[unary] = s
+
def visit_binary(self, binary):
result = self.get_str(binary.left)
if binary.operator is not None:
result += " " + self.binary_operator_string(binary)
result += " " + self.get_str(binary.right)
- if binary.parens:
- result = "(" + result + ")"
self.strings[binary] = result
def binary_operator_string(self, binary):
@@ -438,10 +430,6 @@ class ANSICompiler(sql.Compiled):
self.select_stack.append(select)
for c in select._raw_columns:
- if isinstance(c, sql.Select) and c.is_scalar:
- self.traverse(c)
- inner_columns[self.get_str(c)] = c
- continue
if hasattr(c, '_selectable'):
s = c._selectable()
else:
@@ -484,6 +472,7 @@ class ANSICompiler(sql.Compiled):
for f in select.froms:
if self.parameters is not None:
+ # TODO: whack this feature in 0.4
# look at our own parameters, see if they
# are all present in the form of BindParamClauses. if
# not, then append to the above whereclause column conditions
@@ -494,16 +483,20 @@ class ANSICompiler(sql.Compiled):
else:
continue
clause = c==value
- self.traverse(clause)
- whereclause = sql.and_(clause, whereclause)
- self.visit_compound(whereclause)
+ if whereclause is not None:
+ whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
+ else:
+ whereclause = clause
+ self.traverse(whereclause)
# special thingy used by oracle to redefine a join
w = self.get_whereclause(f)
if w is not None:
# TODO: move this more into the oracle module
- whereclause = sql.and_(w, whereclause)
- self.visit_compound(whereclause)
+ if whereclause is not None:
+ whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+ else:
+ whereclause = w
t = self.get_from_text(f)
if t is not None:
@@ -533,10 +526,7 @@ class ANSICompiler(sql.Compiled):
text += self.visit_select_postclauses(select)
text += self.for_update_clause(select)
- if getattr(select, 'parens', False):
- self.strings[select] = "(" + text + ")"
- else:
- self.strings[select] = text
+ self.strings[select] = text
self.froms[select] = "(" + text + ")"
def visit_select_precolumns(self, select):