summaryrefslogtreecommitdiff
path: root/tests/test_bio_ssl.py
blob: ed2ea1a295ca7fa0404e82dc9d32691240e5a63c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python
from __future__ import absolute_import, print_function

"""Unit tests for M2Crypto.BIO.SSLBio.

Copyright (c) 1999-2002 Ng Pheng Siong. All rights reserved."""

import socket
import sys
import threading

from M2Crypto import BIO
from M2Crypto import SSL
from M2Crypto import Err
from M2Crypto import Rand
from M2Crypto import threading as m2threading

from tests import unittest
from tests.test_ssl import srv_host, allocate_srv_port


class HandshakeClient(threading.Thread):
    """Background thread that performs the client side of an SSL handshake.

    The SSL connection is attached to in-memory BIOs: records produced by
    the SSL engine are drained from ``writebio`` and sent over a plain TCP
    socket, and bytes received from the socket are fed into ``readbio``.
    The thread finishes once ``do_handshake`` reports success.
    """

    def __init__(self, host, port):
        """Remember the address of the test server.

        :param host: host name or address of the server to connect to.
        :param port: TCP port of the server.
        """
        threading.Thread.__init__(self)
        self.host = host
        self.port = port

    def run(self):
        """Connect to the server and drive the client handshake to completion.

        Ends the thread via ``sys.exit`` on an unrecoverable SSL error or if
        the server closes the connection before the handshake finishes.  The
        TCP socket is closed in every case.
        """
        ctx = SSL.Context()
        ctx.load_cert_chain("tests/server.pem")
        conn = SSL.Connection(ctx)
        sslbio = BIO.SSLBio()
        readbio = BIO.MemoryBuffer()
        writebio = BIO.MemoryBuffer()
        sslbio.set_ssl(conn)
        conn.set_bio(readbio, writebio)
        conn.set_connect_state()

        sock = socket.socket()
        try:
            sock.connect((self.host, self.port))

            handshake_complete = False
            while not handshake_complete:
                ret = sslbio.do_handshake()
                if ret > 0:
                    handshake_complete = True
                elif not sslbio.should_retry() or not sslbio.should_read():
                    err_string = Err.get_error()
                    print(err_string)
                    sys.exit("unrecoverable error in handshake - client")
                else:
                    # Send anything the SSL engine queued first; only block
                    # on the socket when there is nothing left to send.
                    output_token = writebio.read()
                    if output_token is not None:
                        sock.sendall(output_token)
                    else:
                        input_token = sock.recv(1024)
                        if not input_token:
                            # recv() returns an empty string at EOF; feeding
                            # that to the BIO would loop here forever.
                            sys.exit("connection closed during handshake - client")
                        readbio.write(input_token)

            # Flush whatever the final do_handshake() call queued.
            output_token = writebio.read()
            if output_token is not None:
                sock.sendall(output_token)
        finally:
            sock.close()


class SSLTestCase(unittest.TestCase):
    """Tests for BIO.SSLBio wrapping an SSL.Connection."""

    # Upper bound, in seconds, for any single socket operation (and for
    # joining the client thread) in the handshake test, so a misbehaving
    # peer fails the test instead of hanging the suite.
    handshake_timeout = 10.0

    def setUp(self):
        self.sslbio = BIO.SSLBio()

    def test_pass(self):  # XXX leaks 64/24 bytes
        pass

    def test_set_ssl(self):  # XXX leaks 64/1312 bytes
        ctx = SSL.Context()
        conn = SSL.Connection(ctx)
        self.sslbio.set_ssl(conn)

    def test_do_handshake_fail(self):  # XXX leaks 64/42066 bytes
        # With no peer attached, the client handshake cannot succeed.
        ctx = SSL.Context()
        conn = SSL.Connection(ctx)
        conn.set_connect_state()
        self.sslbio.set_ssl(conn)
        ret = self.sslbio.do_handshake()
        self.assertIn(ret, (-1, 0))

    def test_should_retry_fail(self):  # XXX leaks 64/1312 bytes
        # A failed handshake on a bare connection is not retryable.
        ctx = SSL.Context()
        conn = SSL.Connection(ctx)
        self.sslbio.set_ssl(conn)
        ret = self.sslbio.do_handshake()
        self.assertIn(ret, (-1, 0))
        ret = self.sslbio.should_retry()
        self.assertEqual(ret, 0)

    def test_should_write_fail(self):  # XXX leaks 64/1312 bytes
        ctx = SSL.Context()
        conn = SSL.Connection(ctx)
        self.sslbio.set_ssl(conn)
        ret = self.sslbio.do_handshake()
        self.assertIn(ret, (-1, 0))
        ret = self.sslbio.should_write()
        self.assertEqual(ret, 0)

    def test_should_read_fail(self):  # XXX leaks 64/1312 bytes
        ctx = SSL.Context()
        conn = SSL.Connection(ctx)
        self.sslbio.set_ssl(conn)
        ret = self.sslbio.do_handshake()
        self.assertIn(ret, (-1, 0))
        ret = self.sslbio.should_read()
        self.assertEqual(ret, 0)

    def test_do_handshake_succeed(self):  # XXX leaks 196/26586 bytes
        # Server side of a full handshake against a HandshakeClient thread,
        # with the SSL connection driven through memory BIOs.
        ctx = SSL.Context()
        ctx.load_cert_chain("tests/server.pem")
        conn = SSL.Connection(ctx)
        self.sslbio.set_ssl(conn)
        readbio = BIO.MemoryBuffer()
        writebio = BIO.MemoryBuffer()
        conn.set_bio(readbio, writebio)
        conn.set_accept_state()

        srv_port = allocate_srv_port()
        sock = socket.socket()
        try:
            sock.settimeout(self.handshake_timeout)
            sock.bind((srv_host, srv_port))
            sock.listen(5)
            handshake_client = HandshakeClient(srv_host, srv_port)
            handshake_client.start()
            try:
                new_sock, _ = sock.accept()
                try:
                    new_sock.settimeout(self.handshake_timeout)
                    self._serve_handshake(new_sock, readbio, writebio)
                finally:
                    new_sock.close()
            finally:
                # Always reap the client thread, even when the server side
                # failed, so it does not outlive the test.
                handshake_client.join(self.handshake_timeout)
        finally:
            sock.close()
        self.assertFalse(handshake_client.is_alive(),
                         "handshake client thread did not finish")

    def _serve_handshake(self, conn_sock, readbio, writebio):
        """Pump handshake records between conn_sock and the memory BIOs.

        Loops until ``self.sslbio.do_handshake()`` succeeds.  Fails the test
        on an unrecoverable SSL error or if the peer closes the connection
        early.

        :param conn_sock: connected TCP socket to the client.
        :param readbio: memory BIO the SSL connection reads from.
        :param writebio: memory BIO the SSL connection writes to.
        """
        handshake_complete = False
        while not handshake_complete:
            input_token = conn_sock.recv(1024)
            if not input_token:
                # recv() returns an empty string at EOF; without this check
                # the loop would spin forever on a closed connection.
                self.fail("connection closed during handshake - server")
            readbio.write(input_token)

            ret = self.sslbio.do_handshake()
            if ret > 0:
                handshake_complete = True
            elif not self.sslbio.should_retry() or not self.sslbio.should_read():
                self.fail("unrecoverable error in handshake - server")

            output_token = writebio.read()
            if output_token is not None:
                conn_sock.sendall(output_token)


def suite():
    """Build a TestSuite holding every test method of SSLTestCase."""
    loader = unittest.TestLoader()
    ssl_tests = loader.loadTestsFromTestCase(SSLTestCase)
    return ssl_tests


if __name__ == '__main__':
    # Seed the PRNG from randpool.dat, which is written back at the end.
    Rand.load_file('randpool.dat', -1)
    m2threading.init()
    try:
        result = unittest.TextTestRunner().run(suite())
    finally:
        # Tear down M2Crypto threading support and persist the pool even if
        # the run was interrupted.
        m2threading.cleanup()
        Rand.save_file('randpool.dat')
    # Report failures through the process exit status so callers can tell.
    sys.exit(0 if result.wasSuccessful() else 1)