summaryrefslogtreecommitdiff
path: root/tests/test_split.py
blob: 5146bcba5fa128bcd9be3d0d72def84c096bb7c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# -*- coding: utf-8 -*-

# Tests splitting functions.

import unittest

from tests.utils import load_file, TestCaseBase

import sqlparse


class SQLSplitTest(TestCaseBase):
    """Tests statement splitting via sqlparse.parse() and sqlparse.split()."""

    _sql1 = 'select * from foo;'
    # NOTE(review): not referenced by any test in this file; kept in case
    # subclasses or other code rely on it.
    _sql2 = 'select * from bar;'

    def _assert_split(self, sql, count):
        """Parse *sql* and check it splits into *count* lossless statements.

        Asserts that sqlparse.parse() returns exactly *count* statements and
        that concatenating their unicode text reproduces *sql* unchanged
        (whitespace and comments included). Returns the parsed statements so
        callers can make additional per-statement checks.
        """
        stmts = sqlparse.parse(sql)
        self.assertEqual(len(stmts), count)
        self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
        return stmts

    def _assert_split_file(self, filename, count):
        """Like _assert_split(), but reads the SQL via load_file(filename)."""
        return self._assert_split(load_file(filename), count)

    def test_split_semicolon(self):
        # The semicolon inside the quoted literal must not end a statement.
        sql2 = 'select * from foo where bar = \'foo;bar\';'
        stmts = self._assert_split(''.join([self._sql1, sql2]), 2)
        self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
        self.ndiffAssertEqual(unicode(stmts[1]), sql2)

    def test_create_function(self):
        self._assert_split_file('function.sql', 1)

    def test_create_function_psql(self):
        self._assert_split_file('function_psql.sql', 1)

    def test_create_function_psql2(self):
        self._assert_split_file('function_psql2.sql', 1)

    def test_create_function_psql3(self):
        self._assert_split_file('function_psql3.sql', 1)

    def test_dashcomments(self):
        self._assert_split_file('dashcomment.sql', 3)

    def test_begintag(self):
        self._assert_split_file('begintag.sql', 3)

    def test_dropif(self):
        sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
        self._assert_split(sql, 2)

    def test_comment_with_umlaut(self):
        # Written as an escape (LATIN SMALL LETTER A WITH DIAERESIS) so the
        # test does not depend on how this source file is encoded or viewed.
        sql = (u'select * from foo;\n'
               u'-- Testing an umlaut: \u00e4\n'
               u'select * from bar;')
        self._assert_split(sql, 2)

    def test_comment_end_of_line(self):
        sql = ('select * from foo; -- foo\n'
               'select * from bar;')
        stmts = self._assert_split(sql, 2)
        # make sure the comment belongs to first query
        self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')

    def test_casewhen(self):
        sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
               'comment on table actor is \'The actor table.\';')
        stmts = sqlparse.split(sql)
        self.assertEqual(len(stmts), 2)