1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
# -*- coding: utf-8 -*-
# Tests splitting functions.
import unittest
from tests.utils import load_file, TestCaseBase
import sqlparse
class SQLSplitTest(TestCaseBase):
"""Tests sqlparse.sqlsplit()."""
_sql1 = 'select * from foo;'
_sql2 = 'select * from bar;'
def test_split_semicolon(self):
sql2 = 'select * from foo where bar = \'foo;bar\';'
stmts = sqlparse.parse(''.join([self._sql1, sql2]))
self.assertEqual(len(stmts), 2)
self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
self.ndiffAssertEqual(unicode(stmts[1]), sql2)
def test_create_function(self):
sql = load_file('function.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
self.ndiffAssertEqual(unicode(stmts[0]), sql)
def test_create_function_psql(self):
sql = load_file('function_psql.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
self.ndiffAssertEqual(unicode(stmts[0]), sql)
def test_create_function_psql3(self):
sql = load_file('function_psql3.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
self.ndiffAssertEqual(unicode(stmts[0]), sql)
def test_create_function_psql2(self):
sql = load_file('function_psql2.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 1)
self.ndiffAssertEqual(unicode(stmts[0]), sql)
def test_dashcomments(self):
sql = load_file('dashcomment.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
def test_begintag(self):
sql = load_file('begintag.sql')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 3)
self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
def test_dropif(self):
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
def test_comment_with_umlaut(self):
sql = (u'select * from foo;\n'
u'-- Testing an umlaut: รค\n'
u'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
def test_comment_end_of_line(self):
sql = ('select * from foo; -- foo\n'
'select * from bar;')
stmts = sqlparse.parse(sql)
self.assertEqual(len(stmts), 2)
self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
# make sure the comment belongs to first query
self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')
def test_casewhen(self):
sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
'comment on table actor is \'The actor table.\';')
stmts = sqlparse.split(sql)
self.assertEqual(len(stmts), 2)
|