diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-05-14 22:25:36 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-05-14 22:25:36 +0000 |
commit | ae4b954b1a6baf5a58c0e00e382196b581a7f06a (patch) | |
tree | 562d969dcef857594a62e3730e1305aaa284e3b0 /lib/sqlalchemy/ansisql.py | |
parent | 3de128138a896bc7373aa2684c920271c4781b7d (diff) | |
download | sqlalchemy-ae4b954b1a6baf5a58c0e00e382196b581a7f06a.tar.gz |
- parenthesis are applied to clauses via a new _Grouping construct.
uses operator precedence to more intelligently apply parenthesis
to clauses, provides cleaner nesting of clauses (doesnt mutate
clauses placed in other clauses, i.e. no 'parens' flag)
- added 'modifier' keyword, works like func.<foo> except does not
add parenthesis. e.g. select([modifier.DISTINCT(...)]) etc.
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 80 |
1 files changed, 35 insertions, 45 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index ab043f3ec..28dd0866c 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -245,7 +245,10 @@ class ANSICompiler(sql.Compiled): """ return "" - + + def visit_grouping(self, grouping): + self.strings[grouping] = "(" + self.strings[grouping.elem] + ")" + def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) @@ -298,10 +301,7 @@ class ANSICompiler(sql.Compiled): self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() def visit_textclause(self, textclause): - if textclause.parens and len(textclause.text): - self.strings[textclause] = "(" + textclause.text + ")" - else: - self.strings[textclause] = textclause.text + self.strings[textclause] = textclause.text self.froms[textclause] = textclause.text if textclause.typemap is not None: self.typemap.update(textclause.typemap) @@ -309,32 +309,21 @@ class ANSICompiler(sql.Compiled): def visit_null(self, null): self.strings[null] = 'NULL' - def visit_compound(self, compound): - if compound.operator is None: - sep = " " - else: - sep = " " + compound.operator + " " - - s = string.join([self.get_str(c) for c in compound.clauses], sep) - if compound.parens: - self.strings[compound] = "(" + s + ")" - else: - self.strings[compound] = s - def visit_clauselist(self, list): - if list.parens: - self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")" + sep = list.operator + if sep == ',': + sep = ', ' + elif sep is None or sep == " ": + sep = " " else: - self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + sep = " " + sep + " " + self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, list): - if list.parens: - self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")" - else: - self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ') + def visit_calculatedclause(self, clause): + self.strings[clause] = self.get_str(clause.clause_expr) def visit_cast(self, cast): if len(self.select_stack): @@ -349,7 +338,7 @@ class ANSICompiler(sql.Compiled): self.strings[func] = ".".join(func.packagenames + [func.name]) self.froms[func] = self.strings[func] else: - self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" + self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr) self.froms[func] = self.strings[func] def visit_compound_select(self, cs): @@ -359,19 +348,22 @@ class ANSICompiler(sql.Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs) text += self.visit_select_postclauses(cs) - if cs.parens: - self.strings[cs] = "(" + text + ")" - else: - self.strings[cs] = text + self.strings[cs] = text self.froms[cs] = "(" + text + ")" + def visit_unary(self, unary): + s = self.get_str(unary.element) + if unary.operator: + s = unary.operator + " " + s + if unary.modifier: + s = s + " " + unary.modifier + self.strings[unary] = s + def visit_binary(self, binary): result = self.get_str(binary.left) if binary.operator is not None: result += " " + self.binary_operator_string(binary) result += " " + self.get_str(binary.right) - if binary.parens: - result = "(" + result + ")" self.strings[binary] = result def binary_operator_string(self, binary): @@ -438,10 +430,6 @@ class ANSICompiler(sql.Compiled): self.select_stack.append(select) for c in select._raw_columns: - if isinstance(c, sql.Select) and c.is_scalar: - self.traverse(c) - inner_columns[self.get_str(c)] = c - continue if hasattr(c, '_selectable'): s = c._selectable() else: @@ -484,6 +472,7 @@ class ANSICompiler(sql.Compiled): for f in select.froms: if self.parameters is not None: + # TODO: whack this feature in 0.4 # look at our own parameters, see if they # are all present in the form of BindParamClauses. if # not, then append to the above whereclause column conditions @@ -494,16 +483,20 @@ class ANSICompiler(sql.Compiled): else: continue clause = c==value - self.traverse(clause) - whereclause = sql.and_(clause, whereclause) - self.visit_compound(whereclause) + if whereclause is not None: + whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause])) + else: + whereclause = clause + self.traverse(whereclause) # special thingy used by oracle to redefine a join w = self.get_whereclause(f) if w is not None: # TODO: move this more into the oracle module - whereclause = sql.and_(w, whereclause) - self.visit_compound(whereclause) + if whereclause is not None: + whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w])) + else: + whereclause = w t = self.get_from_text(f) if t is not None: @@ -533,10 +526,7 @@ class ANSICompiler(sql.Compiled): text += self.visit_select_postclauses(select) text += self.for_update_clause(select) - if getattr(select, 'parens', False): - self.strings[select] = "(" + text + ")" - else: - self.strings[select] = text + self.strings[select] = text self.froms[select] = "(" + text + ")" def visit_select_precolumns(self, select): |