diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-06-25 12:42:47 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-06-25 12:42:47 -0400 |
commit | 5771ae2ebf169b6daccac0c928013f98975e4c67 (patch) | |
tree | 46a3f2266e0f376252538515f664ef994db8f570 /lib/sqlalchemy/sql/compiler.py | |
parent | 7b1a1a66cd36fdfac6541e6b771fd6c849b0bd7d (diff) | |
download | sqlalchemy-5771ae2ebf169b6daccac0c928013f98975e4c67.tar.gz |
- move cte tests into their own test/sql/test_cte.py
- rework bindtemplate system of "numbered" params by applying
the numbers last, as we now need to generate these out of order
in some cases
- add positional assertion to assert_compile
- add new cte_positional collection to track bindparams generated
within cte visits; splice this onto the beginning of self.positiontup
at cte render time, [ticket:2521]
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 52 |
1 files changed, 36 insertions, 16 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fe56b1a83..0dac0f3d6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -29,6 +29,7 @@ from . import ( operators, functions, util as sql_util, visitors, expression as sql ) import decimal +import itertools RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -59,7 +60,7 @@ BIND_TEMPLATES = { 'pyformat':"%%(%(name)s)s", 'qmark':"?", 'format':"%%s", - 'numeric':":%(position)s", + 'numeric':":[_POSITION]", 'named':":%(name)s" } @@ -252,16 +253,18 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} - # collect CTEs to tack on top of a SELECT - self.ctes = util.OrderedDict() - self.ctes_recursive = False - # true if the paramstyle is positional self.positional = dialect.positional if self.positional: self.positiontup = [] self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + if self.positional: + self.cte_positional = [] + # an IdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer self.label_length = dialect.label_length \ @@ -276,7 +279,15 @@ class SQLCompiler(engine.Compiled): self.truncated_names = {} engine.Compiled.__init__(self, dialect, statement, **kwargs) + if self.positional and dialect.paramstyle == 'numeric': + self._apply_numbered_params() + def _apply_numbered_params(self): + poscount = itertools.count(1) + self.string = re.sub( + r'\[_POSITION\]', + lambda m:str(next(poscount)), + self.string) @util.memoized_property def _bind_processors(self): @@ -455,7 +466,7 @@ class SQLCompiler(engine.Compiled): if name in textclause.bindparams: return self.process(textclause.bindparams[name]) else: - return self.bindparam_string(name) + return self.bindparam_string(name, **kwargs) # un-escape any \:params return BIND_PARAMS_ESC.sub(lambda m: m.group(1), @@ -690,7 +701,7 @@ class SQLCompiler(engine.Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string(name, quote=bindparam.quote) + return self.bindparam_string(name, quote=bindparam.quote, **kwargs) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.value @@ -760,20 +771,25 @@ class SQLCompiler(engine.Compiled): self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name, quote=None): + def bindparam_string(self, name, quote=None, + positional_names=None, **kw): if self.positional: - self.positiontup.append(name) - return self.bindtemplate % { - 'name':name, 'position':len(self.positiontup)} - else: - return self.bindtemplate % {'name':name} + if positional_names is not None: + positional_names.append(name) + else: + self.positiontup.append(name) + return self.bindtemplate % {'name':name} def visit_cte(self, cte, asfrom=False, ashint=False, fromhints=None, **kwargs): + if self.positional: + kwargs['positional_names'] = self.cte_positional + if isinstance(cte.name, sql._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: cte_name = cte.name + if cte.cte_alias: if isinstance(cte.cte_alias, sql._truncated_label): cte_alias = self._truncated_identifier("alias", cte.cte_alias) @@ -883,7 +899,8 @@ class SQLCompiler(engine.Compiled): def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, - compound_index=1, **kwargs): + compound_index=1, + positional_names=None, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -902,9 +919,10 @@ class SQLCompiler(engine.Compiled): : iswrapper}) if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map} + column_clause_args = {'result_map':self.result_map, + 'positional_names':positional_names} else: - column_clause_args = {} + column_clause_args = {'positional_names':positional_names} # the actual list of columns to print in the SELECT column list. inner_columns = [ @@ -991,6 +1009,8 @@ class SQLCompiler(engine.Compiled): return text def _render_cte_clause(self): + if self.positional: + self.positiontup = self.cte_positional + self.positiontup cte_text = self.get_cte_preamble(self.ctes_recursive) + " " cte_text += ", \n".join( [txt for txt in self.ctes.values()] |