summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2011-09-29 11:19:26 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2011-09-29 11:19:26 +0200
commit13c1d716a7798ae1a79c524dc967e9a33c534ff4 (patch)
tree5486734ad09578e3eb538a3faf545ae0681d0190
parentff50b33074f5c276b0cff8094e85582dcd467095 (diff)
downloadsqlparse-13c1d716a7798ae1a79c524dc967e9a33c534ff4.tar.gz
Detect alias for CASE statements (targets issue46).
-rw-r--r--sqlparse/engine/grouping.py11
-rw-r--r--sqlparse/sql.py69
-rw-r--r--tests/test_grouping.py5
3 files changed, 45 insertions, 40 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index cc75de4..f92a812 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -274,21 +274,20 @@ def group_where(tlist):
def group_aliased(tlist):
+ clss = (sql.Identifier, sql.Function, sql.Case)
[group_aliased(sgroup) for sgroup in tlist.get_sublists()
- if not isinstance(sgroup, (sql.Identifier, sql.Function))]
+ if not isinstance(sgroup, clss)]
idx = 0
- token = tlist.token_next_by_instance(idx, (sql.Identifier, sql.Function))
+ token = tlist.token_next_by_instance(idx, clss)
while token:
next_ = tlist.token_next(tlist.token_index(token))
- if next_ is not None and isinstance(next_,
- (sql.Identifier, sql.Function)):
+ if next_ is not None and isinstance(next_, clss):
grp = tlist.tokens_between(token, next_)[1:]
token.tokens.extend(grp)
for t in grp:
tlist.tokens.remove(t)
idx = tlist.token_index(token) + 1
- token = tlist.token_next_by_instance(idx,
- (sql.Identifier, sql.Function))
+ token = tlist.token_next_by_instance(idx, clss)
def group_typecasts(tlist):
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 4d56bf3..244733b 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -315,38 +315,6 @@ class TokenList(Token):
"""Inserts *token* before *where*."""
self.tokens.insert(self.token_index(where), token)
-
-class Statement(TokenList):
- """Represents a SQL statement."""
-
- __slots__ = ('value', 'ttype', 'tokens')
-
- def get_type(self):
- """Returns the type of a statement.
-
- The returned value is a string holding an upper-cased reprint of
- the first DML or DDL keyword. If the first token in this group
- isn't a DML or DDL keyword "UNKNOWN" is returned.
- """
- first_token = self.token_first()
- if first_token is None:
- # An "empty" statement that either has not tokens at all
- # or only whitespace tokens.
- return 'UNKNOWN'
- elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL):
- return first_token.value.upper()
- else:
- return 'UNKNOWN'
-
-
-class Identifier(TokenList):
- """Represents an identifier.
-
- Identifiers may have aliases or typecasts.
- """
-
- __slots__ = ('value', 'ttype', 'tokens')
-
def has_alias(self):
"""Returns ``True`` if an alias is present."""
return self.get_alias() is not None
@@ -359,8 +327,8 @@ class Identifier(TokenList):
if alias is None:
return None
else:
- next_ = self.token_next(0)
- if next_ is None or not isinstance(next_, Identifier):
+ next_ = self.token_next_by_instance(0, Identifier)
+ if next_ is None:
return None
alias = next_
if isinstance(alias, Identifier):
@@ -393,6 +361,39 @@ class Identifier(TokenList):
return None
return next_.value
+
+
+class Statement(TokenList):
+ """Represents a SQL statement."""
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ def get_type(self):
+ """Returns the type of a statement.
+
+ The returned value is a string holding an upper-cased reprint of
+ the first DML or DDL keyword. If the first token in this group
+ isn't a DML or DDL keyword "UNKNOWN" is returned.
+ """
+ first_token = self.token_first()
+ if first_token is None:
+ # An "empty" statement that either has not tokens at all
+ # or only whitespace tokens.
+ return 'UNKNOWN'
+ elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL):
+ return first_token.value.upper()
+ else:
+ return 'UNKNOWN'
+
+
+class Identifier(TokenList):
+ """Represents an identifier.
+
+ Identifiers may have aliases or typecasts.
+ """
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
def get_parent_name(self):
"""Return name of the parent object if any.
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 134409e..a66d39d 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -157,6 +157,11 @@ class TestGrouping(TestCaseBase):
self.ndiffAssertEqual(s, p.to_unicode())
self.assertEqual(p.tokens[4].get_alias(), 'view')
+ def test_alias_case(self): # see issue46
+ p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0]
+ self.assertEqual(len(p.tokens), 1)
+ self.assertEqual(p.tokens[0].get_alias(), 'foo')
+
def test_idlist_function(self): # see issue10 too
p = sqlparse.parse('foo(1) x, bar')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))