summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2009-05-27 20:58:12 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2009-05-27 20:58:12 +0200
commit895f021a0515dbf948efb1dfe960c0aa63cd160d (patch)
tree7a68db321f6d58b458345cecef1b6a0ca7b9c4a5
parent9917967e25669d21e577123583a2d3a191844c62 (diff)
downloadsqlparse-895f021a0515dbf948efb1dfe960c0aa63cd160d.tar.gz
Grouping of function/procedure calls.
-rw-r--r--sqlparse/engine/grouping.py17
-rw-r--r--sqlparse/sql.py14
-rw-r--r--tests/test_grouping.py9
3 files changed, 39 insertions, 1 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 471116e..66f4df5 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -245,8 +245,25 @@ def group_typecasts(tlist):
_group_left_right(tlist, T.Punctuation, '::', Identifier)
+def group_functions(tlist):
+ [group_functions(sgroup) for sgroup in tlist.get_sublists()
+ if not isinstance(sgroup, Function)]
+ idx = 0
+ token = tlist.token_next_by_type(idx, T.Name)
+ while token:
+ next_ = tlist.token_next(token)
+ if not isinstance(next_, Parenthesis):
+ idx = tlist.token_index(token)+1
+ else:
+ func = tlist.group_tokens(Function,
+ tlist.tokens_between(token, next_))
+ idx = tlist.token_index(func)+1
+ token = tlist.token_next_by_type(idx, T.Name)
+
+
def group(tlist):
for func in [group_parenthesis,
+ group_functions,
group_comments,
group_where,
group_case,
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 7c607c4..4502fa2 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -455,3 +455,17 @@ class Case(TokenList):
elif in_value:
ret[-1][1].append(token)
return ret
+
+
+class Function(TokenList):
+ """A function or procedure call."""
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def get_parameters(self):
+ """Return a list of parameters."""
+ parenthesis = self.tokens[-1]
+ for t in parenthesis.tokens:
+ if isinstance(t, IdentifierList):
+ return t.get_identifiers()
+ return []
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 6477123..8d62ea4 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -10,7 +10,7 @@ from tests.utils import TestCaseBase
class TestGrouping(TestCaseBase):
def test_parenthesis(self):
- s ='x1 (x2 (x3) x2) foo (y2) bar'
+ s ='select (select (x3) x2) and (y2) bar'
parsed = sqlparse.parse(s)[0]
self.ndiffAssertEqual(s, str(parsed))
self.assertEqual(len(parsed.tokens), 9)
@@ -142,6 +142,13 @@ class TestGrouping(TestCaseBase):
p = sqlparse.parse('(a+1)')[0]
self.assert_(isinstance(p.tokens[0].tokens[1], Comparsion))
+ def test_function(self):
+ p = sqlparse.parse('foo()')[0]
+ self.assert_(isinstance(p.tokens[0], Function))
+ p = sqlparse.parse('foo(null, bar)')[0]
+ self.assert_(isinstance(p.tokens[0], Function))
+ self.assertEqual(len(p.tokens[0].get_parameters()), 2)
+
class TestStatement(TestCaseBase):