summaryrefslogtreecommitdiff
path: root/tests/test_obj.py
blob: 7cdada0b19568d4d2381ccdf0545cc0be9f93f22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python

"""Unit tests for M2Crypto.m2 obj_* functions.
"""

try:
    import unittest2 as unittest
except ImportError:
    import unittest

from M2Crypto import ASN1, BIO, Rand, X509, m2, six

"""
These helper functions should be cleaned up and moved into a proper Python
module. They were taken from the CA management code.
"""


def x509_name2list(name):
    """Lazily yield every entry of the X509 name *name*, in index order.

    The entry count is read once, when iteration starts. Each yielded
    ``X509.X509_Name_Entry`` is created with ``_pyfree=0``, so the wrapper
    never frees the underlying entry itself.
    """
    total = name.entry_count()
    index = 0
    while index < total:
        raw_entry = m2.x509_name_get_entry(name._ptr(), index)
        yield X509.X509_Name_Entry(raw_entry, _pyfree=0)
        index += 1


def x509_name_entry2tuple(entry):
    """Return ``(field_name, value)`` for an X509 name entry, both as text.

    *field_name* is produced by ``m2.obj_obj2txt(..., 0)`` (for example
    ``commonName`` rather than the dotted OID). *value* is the entry data
    as rendered by ``m2.asn1_string_print`` into a memory BIO.
    """
    value_buf = BIO.MemoryBuffer()
    value_ptr = m2.x509_name_entry_get_data(entry._ptr())
    m2.asn1_string_print(value_buf._ptr(), value_ptr)

    obj_ptr = m2.x509_name_entry_get_object(entry._ptr())
    field_name = six.ensure_text(m2.obj_obj2txt(obj_ptr, 0))
    value = six.ensure_text(value_buf.getvalue())
    return field_name, value


def tuple2x509_name_entry(tup):
    """Build an ``X509.X509_Name_Entry`` from an ``(obj, data)`` pair.

    :param tup: two-item sequence ``(obj, data)``. *obj* is a field name or
        dotted OID (e.g. ``"CN"`` or ``"2.5.4.3"``); *data* is the value.
        Either may be text or bytes.
    :return: a new ``X509.X509_Name_Entry`` created with ``_pyfree=1``, so
        the wrapper frees the underlying entry when it is collected.
    :raises ValueError: if the entry cannot be created, e.g. because *obj*
        is not a known object identifier.
    """
    obj, data = tup
    # TODO This is evil, isn't it? Shouldn't we use only official API?
    # Something like X509.X509_Name.add_entry_by_txt()
    data_str = six.ensure_str(data)
    # OpenSSL's length argument is a byte count. Measure the encoded value,
    # not the character count of the original object, so non-ASCII text is
    # not truncated.
    data_len = len(six.ensure_binary(data_str))
    _x509_ne = m2.x509_name_entry_create_by_txt(None, six.ensure_str(obj),
                                                ASN1.MBSTRING_ASC,
                                                data_str, data_len)
    if not _x509_ne:
        raise ValueError("Invalid object identifier: %s" % obj)
    return X509.X509_Name_Entry(_x509_ne, _pyfree=1)  # Prevent memory leaks


class ObjectsTestCase(unittest.TestCase):
    """Tests for the ``m2.obj_*`` wrappers and X509 name-entry round trips."""

    def callback(self, *args):
        # No-op; not referenced by any test in this class.
        pass

    def test_obj2txt(self):
        """obj2txt renders the dotted OID with no_name=1, the name with 0."""
        self.assertEqual(m2.obj_obj2txt(m2.obj_txt2obj("commonName", 0), 1),
                         b"2.5.4.3", b"2.5.4.3")
        self.assertEqual(m2.obj_obj2txt(m2.obj_txt2obj("commonName", 0), 0),
                         b"commonName", b"commonName")

    def test_nid(self):
        """The various *2nid lookups agree, and unknown names map to 0."""
        self.assertEqual(m2.obj_ln2nid("commonName"),
                         m2.obj_txt2nid("2.5.4.3"),
                         "ln2nid and txt2nid mismatch")
        self.assertEqual(m2.obj_ln2nid("CN"),
                         0, "ln2nid on sn")
        self.assertEqual(m2.obj_sn2nid("CN"),
                         m2.obj_ln2nid("commonName"),
                         "ln2nid and sn2nid mismatch")
        self.assertEqual(m2.obj_sn2nid("CN"),
                         m2.obj_obj2nid(m2.obj_txt2obj("CN", 0)), "obj2nid")
        self.assertEqual(m2.obj_txt2nid("__unknown"),
                         0, "__unknown")

    def test_tuple2tuple(self):
        """A tuple survives the tuple -> entry -> tuple round trip."""
        tup = ("CN", "someCommonName")
        tup1 = x509_name_entry2tuple(tuple2x509_name_entry(tup))
        # tup1[0] is 'commonName', not 'CN'
        self.assertEqual(tup1[1], tup[1], tup1)
        self.assertEqual(x509_name_entry2tuple(tuple2x509_name_entry(tup1)),
                         tup1, tup1)

    def test_unknown(self):
        """An unknown object name is rejected with ValueError."""
        with self.assertRaises(ValueError):
            tuple2x509_name_entry(("__unknown", "_"))

    def test_x509_name(self):
        """Copy every entry of a populated name into a fresh one and compare."""
        n = X509.X509_Name()
        # It seems this actually needs to be a real 2 letter country code
        n.C = b'US'
        n.SP = b'State or Province'
        n.L = b'locality name'
        n.O = b'orhanization name'
        n.OU = b'org unit'
        n.CN = b'common name'
        n.Email = b'bob@example.com'
        n.serialNumber = b'1234'
        n.SN = b'surname'
        n.GN = b'given name'

        # Each assignment adds an entry (GN and givenName both count),
        # hence 11 entries in total.
        n.givenName = b'name given'
        self.assertEqual(len(n), 11, len(n))

        tl = [x509_name_entry2tuple(x) for x in x509_name2list(n)]

        self.assertEqual(len(tl), len(n), len(tl))

        # Hand ownership of the new X509_NAME to the wrapper right away
        # (_pyfree=1), so it is freed even if the loop below raises.
        x509_n = m2.x509_name_new()
        n1 = X509.X509_Name(x509_n, _pyfree=1)
        for x in tl:
            o = tuple2x509_name_entry(x)
            # X509_NAME_add_entry() stores a *copy* of the entry, so the
            # wrapper keeps ownership of its own entry and frees it normally.
            m2.x509_name_add_entry(x509_n, o._ptr(), -1, 0)

        self.assertEqual(n.as_text(), n1.as_text(), n1.as_text())


def suite():
    """Return a ``unittest.TestSuite`` containing every ObjectsTestCase test.

    Uses ``TestLoader.loadTestsFromTestCase`` rather than
    ``unittest.makeSuite``, which was deprecated in Python 3.11 and removed
    in 3.13.
    """
    loader = unittest.TestLoader()
    s = unittest.TestSuite()
    s.addTest(loader.loadTestsFromTestCase(ObjectsTestCase))
    return s


if __name__ == '__main__':
    # Load randpool.dat before the run and save it afterwards via M2Crypto's
    # Rand helpers (presumably seeding/persisting the PRNG state — confirm).
    runner = unittest.TextTestRunner()
    Rand.load_file('randpool.dat', -1)
    runner.run(suite())
    Rand.save_file('randpool.dat')