summaryrefslogtreecommitdiff
path: root/tests/test_split.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_split.py')
-rw-r--r--tests/test_split.py48
1 files changed, 23 insertions, 25 deletions
diff --git a/tests/test_split.py b/tests/test_split.py
index 54e8d04..8b22aad 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
# Tests splitting functions.
-
-import unittest
+import types
from tests.utils import load_file, TestCaseBase
import sqlparse
+from sqlparse import compat
class SQLSplitTest(TestCaseBase):
@@ -19,8 +19,8 @@ class SQLSplitTest(TestCaseBase):
sql2 = 'select * from foo where bar = \'foo;bar\';'
stmts = sqlparse.parse(''.join([self._sql1, sql2]))
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
- self.ndiffAssertEqual(unicode(stmts[1]), sql2)
+ self.ndiffAssertEqual(compat.text_type(stmts[0]), self._sql1)
+ self.ndiffAssertEqual(compat.text_type(stmts[1]), sql2)
def test_split_backslash(self):
stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';")
@@ -30,31 +30,31 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('function.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(compat.text_type(stmts[0]), sql)
def test_create_function_psql(self):
sql = load_file('function_psql.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(compat.text_type(stmts[0]), sql)
def test_create_function_psql3(self):
sql = load_file('function_psql3.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(compat.text_type(stmts[0]), sql)
def test_create_function_psql2(self):
sql = load_file('function_psql2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(unicode(stmts[0]), sql)
+ self.ndiffAssertEqual(compat.text_type(stmts[0]), sql)
def test_dashcomments(self):
sql = load_file('dashcomment.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
def test_dashcomments_eol(self):
stmts = sqlparse.parse('select foo; -- comment\n')
@@ -70,36 +70,38 @@ class SQLSplitTest(TestCaseBase):
sql = load_file('begintag.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
def test_begintag_2(self):
sql = load_file('begintag_2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
def test_dropif(self):
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
def test_comment_with_umlaut(self):
- sql = (u'select * from foo;\n'
- u'-- Testing an umlaut: ä\n'
- u'select * from bar;')
+ sql = (compat.u(
+ 'select * from foo;\n'
+ '-- Testing an umlaut: ä\n'
+ 'select * from bar;'))
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
def test_comment_end_of_line(self):
sql = ('select * from foo; -- foo\n'
'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
- self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql)
# make sure the comment belongs to first query
- self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')
+ self.ndiffAssertEqual(
+ compat.text_type(stmts[0]), 'select * from foo; -- foo\n')
def test_casewhen(self):
sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
@@ -122,19 +124,15 @@ class SQLSplitTest(TestCaseBase):
self.assertEqual(len(stmts), 2)
def test_split_stream(self):
- import types
- from cStringIO import StringIO
-
- stream = StringIO("SELECT 1; SELECT 2;")
+ stream = compat.StringIO("SELECT 1; SELECT 2;")
stmts = sqlparse.parsestream(stream)
self.assertEqual(type(stmts), types.GeneratorType)
self.assertEqual(len(list(stmts)), 2)
def test_encoding_parsestream(self):
- from cStringIO import StringIO
- stream = StringIO("SELECT 1; SELECT 2;")
+ stream = compat.StringIO("SELECT 1; SELECT 2;")
stmts = list(sqlparse.parsestream(stream))
- self.assertEqual(type(stmts[0].tokens[0].value), unicode)
+ self.assertEqual(type(stmts[0].tokens[0].value), compat.text_type)
def test_split_simple():