diff options
Diffstat (limited to 'tests/test_split.py')
-rw-r--r-- | tests/test_split.py | 48 |
1 files changed, 23 insertions, 25 deletions
diff --git a/tests/test_split.py b/tests/test_split.py index 54e8d04..8b22aad 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Tests splitting functions. - -import unittest +import types from tests.utils import load_file, TestCaseBase import sqlparse +from sqlparse import compat class SQLSplitTest(TestCaseBase): @@ -19,8 +19,8 @@ class SQLSplitTest(TestCaseBase): sql2 = 'select * from foo where bar = \'foo;bar\';' stmts = sqlparse.parse(''.join([self._sql1, sql2])) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(unicode(stmts[0]), self._sql1) - self.ndiffAssertEqual(unicode(stmts[1]), sql2) + self.ndiffAssertEqual(compat.text_type(stmts[0]), self._sql1) + self.ndiffAssertEqual(compat.text_type(stmts[1]), sql2) def test_split_backslash(self): stmts = sqlparse.parse(r"select '\\'; select '\''; select '\\\'';") @@ -30,31 +30,31 @@ class SQLSplitTest(TestCaseBase): sql = load_file('function.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(compat.text_type(stmts[0]), sql) def test_create_function_psql(self): sql = load_file('function_psql.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(compat.text_type(stmts[0]), sql) def test_create_function_psql3(self): sql = load_file('function_psql3.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(compat.text_type(stmts[0]), sql) def test_create_function_psql2(self): sql = load_file('function_psql2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(unicode(stmts[0]), sql) + self.ndiffAssertEqual(compat.text_type(stmts[0]), sql) def test_dashcomments(self): sql = load_file('dashcomment.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) def test_dashcomments_eol(self): stmts = sqlparse.parse('select foo; -- comment\n') @@ -70,36 +70,38 @@ class SQLSplitTest(TestCaseBase): sql = load_file('begintag.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 3) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) def test_begintag_2(self): sql = load_file('begintag_2.sql') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 1) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) def test_dropif(self): sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) def test_comment_with_umlaut(self): - sql = (u'select * from foo;\n' - u'-- Testing an umlaut: ä\n' - u'select * from bar;') + sql = (compat.u( + 'select * from foo;\n' + '-- Testing an umlaut: ä\n' + 'select * from bar;')) stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) def test_comment_end_of_line(self): sql = ('select * from foo; -- foo\n' 'select * from bar;') stmts = sqlparse.parse(sql) self.assertEqual(len(stmts), 2) - self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + self.ndiffAssertEqual(''.join(compat.text_type(q) for q in stmts), sql) # make sure the comment belongs to first query - self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n') + self.ndiffAssertEqual( + compat.text_type(stmts[0]), 'select * from foo; -- foo\n') def test_casewhen(self): sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' @@ -122,19 +124,15 @@ class SQLSplitTest(TestCaseBase): self.assertEqual(len(stmts), 2) def test_split_stream(self): - import types - from cStringIO import StringIO - - stream = StringIO("SELECT 1; SELECT 2;") + stream = compat.StringIO("SELECT 1; SELECT 2;") stmts = sqlparse.parsestream(stream) self.assertEqual(type(stmts), types.GeneratorType) self.assertEqual(len(list(stmts)), 2) def test_encoding_parsestream(self): - from cStringIO import StringIO - stream = StringIO("SELECT 1; SELECT 2;") + stream = compat.StringIO("SELECT 1; SELECT 2;") stmts = list(sqlparse.parsestream(stream)) - self.assertEqual(type(stmts[0].tokens[0].value), unicode) + self.assertEqual(type(stmts[0].tokens[0].value), compat.text_type) def test_split_simple(): |