1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
# -*- coding: utf-8 -*-
# Tests splitting functions.
import unittest
from tests.utils import load_file, TestCaseBase
import sqlparse
class SQLSplitTest(TestCaseBase):
    """Tests statement splitting.

    Exercises sqlparse.parse(), sqlparse.split() and sqlparse.parsestream().
    """

    _sql1 = 'select * from foo;'
    _sql2 = 'select * from bar;'

    def _assert_split_roundtrip(self, sql, expected_count):
        """Parse *sql* and check statement count and lossless re-joining.

        Asserts that sqlparse.parse() yields ``expected_count`` statements
        and that concatenating their unicode forms gives back *sql*.
        Returns the parsed statements so callers can add further checks.
        """
        stmts = sqlparse.parse(sql)
        self.assertEqual(len(stmts), expected_count)
        self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
        return stmts

    def test_split_semicolon(self):
        """A semicolon inside a quoted string must not split the statement."""
        sql2 = 'select * from foo where bar = \'foo;bar\';'
        stmts = sqlparse.parse(''.join([self._sql1, sql2]))
        self.assertEqual(len(stmts), 2)
        self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
        self.ndiffAssertEqual(unicode(stmts[1]), sql2)

    def test_create_function(self):
        """function.sql is parsed as one single statement."""
        self._assert_split_roundtrip(load_file('function.sql'), 1)

    def test_create_function_psql(self):
        """function_psql.sql is parsed as one single statement."""
        self._assert_split_roundtrip(load_file('function_psql.sql'), 1)

    def test_create_function_psql3(self):
        """function_psql3.sql is parsed as one single statement."""
        self._assert_split_roundtrip(load_file('function_psql3.sql'), 1)

    def test_create_function_psql2(self):
        """function_psql2.sql is parsed as one single statement."""
        self._assert_split_roundtrip(load_file('function_psql2.sql'), 1)

    def test_dashcomments(self):
        """dashcomment.sql is split into three statements."""
        self._assert_split_roundtrip(load_file('dashcomment.sql'), 3)

    def test_dashcomments_eol(self):
        """A trailing dash comment stays in the same statement.

        Checked for each line ending (LF, CR, CRLF) and for no line ending.
        """
        stmts = sqlparse.parse('select foo; -- comment\n')
        self.assertEqual(len(stmts), 1)
        stmts = sqlparse.parse('select foo; -- comment\r')
        self.assertEqual(len(stmts), 1)
        stmts = sqlparse.parse('select foo; -- comment\r\n')
        self.assertEqual(len(stmts), 1)
        stmts = sqlparse.parse('select foo; -- comment')
        self.assertEqual(len(stmts), 1)

    def test_begintag(self):
        """begintag.sql is split into three statements."""
        self._assert_split_roundtrip(load_file('begintag.sql'), 3)

    def test_begintag_2(self):
        """begintag_2.sql is parsed as one single statement."""
        self._assert_split_roundtrip(load_file('begintag_2.sql'), 1)

    def test_dropif(self):
        """DROP TABLE IF EXISTS is split from the statement that follows."""
        sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
        self._assert_split_roundtrip(sql, 2)

    def test_comment_with_umlaut(self):
        """A non-ASCII character inside a comment does not break splitting."""
        sql = (u'select * from foo;\n'
               u'-- Testing an umlaut: ä\n'
               u'select * from bar;')
        self._assert_split_roundtrip(sql, 2)

    def test_comment_end_of_line(self):
        """An end-of-line comment is kept with the statement before it."""
        sql = ('select * from foo; -- foo\n'
               'select * from bar;')
        stmts = self._assert_split_roundtrip(sql, 2)
        # make sure the comment belongs to first query
        self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')

    def test_casewhen(self):
        """The END of a CASE expression does not disturb splitting."""
        sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
               'comment on table actor is \'The actor table.\';')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_cursor_declare(self):
        """DECLARE CURSOR ... ; is split from the statement that follows."""
        sql = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
               'SELECT 2;')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_if_function(self):  # see issue 33
        """IF used as a function call is not treated as a block opener."""
        # don't let IF as a function confuse the splitter
        sql = ('CREATE TEMPORARY TABLE tmp SELECT IF(a=1, a, b) AS o FROM one; '
               'SELECT t FROM two')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)

    def test_split_stream(self):
        """parsestream() returns a generator yielding one item per statement."""
        import types
        from cStringIO import StringIO

        stream = StringIO("SELECT 1; SELECT 2;")
        stmts = sqlparse.parsestream(stream)
        self.assertEqual(type(stmts), types.GeneratorType)
        self.assertEqual(len(list(stmts)), 2)

    def test_encoding_parsestream(self):
        """Token values produced by parsestream() are unicode objects."""
        from cStringIO import StringIO

        stream = StringIO("SELECT 1; SELECT 2;")
        stmts = list(sqlparse.parsestream(stream))
        self.assertEqual(type(stmts[0].tokens[0].value), unicode)
|