# path: root/src/saml2/aes.py
# blob: 027c7a331c4fdccc916261950c07906bf061a9c7
#!/usr/bin/env python
import os
from base64 import b64encode
from base64 import b64decode

from Cryptodome import Random
from Cryptodome.Cipher import AES

__author__ = 'rolandh'

# Maps the chaining-mode suffix of an algorithm name (the last "_"-separated
# field, e.g. "cbc" in "aes_128_cbc") to the Cryptodome AES mode constant
# that AESCipher.build_cipher hands to AES.new().
POSTFIX_MODE = {
    "cbc": AES.MODE_CBC,
    "cfb": AES.MODE_CFB,
    # NOTE(review): "ecb" resolves to MODE_CFB, not AES.MODE_ECB, so requesting
    # "aes_*_ecb" actually yields CFB. This may be a typo or a deliberate
    # fallback (build_cipher always passes an IV, which presumably would not
    # suit real ECB) -- confirm before changing, since existing ciphertexts
    # may depend on the current mapping.
    "ecb": AES.MODE_CFB,
}

# Default block size, in bytes, used by AESCipher.encrypt for PKCS#7 padding.
BLOCK_SIZE = 16


class AESCipher(object):
    """
    AES encryption/decryption helper bound to a single key.

    Messages produced by :meth:`encrypt` consist of the IV followed by the
    cipher text (optionally base64 encoded); :meth:`decrypt` expects that
    same layout. Key, IV and messages are expected to be byte strings.
    """

    def __init__(self, key, iv=""):
        """

        :param key: The encryption key
        :param iv: Init vector used whenever a call does not supply its own
        :return: AESCipher instance
        """
        self.key = key
        self.iv = iv

    def build_cipher(self, iv="", alg="aes_128_cbc"):
        """
        Create a cipher object for the requested algorithm.

        :param iv: init vector; if empty the instance IV is used, and if
            that is empty as well a random one is generated
        :param alg: cipher algorithm written as "<type>_<bits>_<mode>",
            for instance "aes_128_cbc"
        :return: A tuple (Cipher instance, the IV that was used)
        :raises AssertionError: if an explicitly passed IV is not
            AES.block_size bytes long
        :raises Exception: if the key length is unsupported, the key does
            not match the requested length, or the chaining mode is unknown
        """
        _typ, bits, cmode = alg.split("_")

        if not iv:
            iv = self.iv or Random.new().read(AES.block_size)
        elif len(iv) != AES.block_size:
            # Raised explicitly rather than through ``assert`` so the check
            # is still performed when Python runs with -O.
            raise AssertionError(
                "IV must be %d bytes long" % AES.block_size)

        if bits not in ["128", "192", "256"]:
            raise Exception("Unsupported key length")
        # bits >> 3 converts the key size from bits to bytes.
        if len(self.key) != int(bits) >> 3:
            raise Exception("Wrong Key length")

        # Only the table lookup is guarded, so a KeyError coming from
        # anywhere else is not misreported as an unsupported mode.
        try:
            mode = POSTFIX_MODE[cmode]
        except KeyError:
            raise Exception("Unsupported chaining mode")

        return AES.new(self.key, mode, iv), iv

    def encrypt(self, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7",
                b64enc=True, block_size=BLOCK_SIZE):
        """
        Encrypt a message; the IV is prepended to the cipher text.

        :param msg: Message to be encrypted (byte string)
        :param iv: init vector; see :meth:`build_cipher` for the fallbacks
        :param alg: cipher algorithm, e.g. "aes_128_cbc"
        :param padding: Which padding that should be used: "PKCS#7",
            "PKCS#5" (pads to 8 bytes); any other value disables padding
        :param b64enc: Whether the result should be base64encoded
        :param block_size: If PKCS#7 padding which block size to use
        :return: The encrypted message, IV first
        """
        if padding == "PKCS#7":
            _block_size = block_size
        elif padding == "PKCS#5":
            _block_size = 8
        else:
            _block_size = 0

        if _block_size:
            # A full block of padding is added when msg is already aligned.
            plen = _block_size - (len(msg) % _block_size)
            # The pad is built as bytes (not via chr(), which is text on
            # Python 3) so it can be appended to a byte-string message.
            # A new object is created instead of ``+=`` so that a mutable
            # buffer passed by the caller is not modified in place.
            msg = msg + bytes(bytearray([plen])) * plen

        cipher, iv = self.build_cipher(iv, alg)
        cmsg = iv + cipher.encrypt(msg)
        if b64enc:
            return b64encode(cmsg)
        return cmsg

    def decrypt(self, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7",
                b64dec=True):
        """
        Decrypt a message laid out as IV followed by cipher text.

        :param msg: Message to be decrypted, base64 encoded if b64dec is set
        :param iv: init vector; if given it must equal the IV that prefixes
            the message
        :param alg: cipher algorithm, e.g. "aes_128_cbc"
        :param padding: "PKCS#5" or "PKCS#7" to strip padding from the
            result; any other value leaves the result untouched
        :param b64dec: Whether msg must be base64 decoded first
        :return: The decrypted message
        :raises AssertionError: if iv is given and differs from the IV
            embedded in the message
        """
        data = b64decode(msg) if b64dec else msg

        _iv = data[:AES.block_size]
        if iv and iv != _iv:
            # Explicit raise (same exception type as the former ``assert``)
            # so the check is not dropped under -O.
            raise AssertionError(
                "IV does not match the one embedded in the message")

        cipher, iv = self.build_cipher(iv, alg=alg)
        # The whole buffer, IV prefix included, is run through the cipher
        # and the first block of output is thrown away.
        # NOTE(review): this relies on CBC/CFB output after the first block
        # depending only on the preceding cipher text, so the IV handed to
        # the cipher does not influence the bytes that are kept -- confirm
        # before adding further chaining modes.
        res = cipher.decrypt(data)[AES.block_size:]
        if padding in ["PKCS#5", "PKCS#7"]:
            # Indexing bytes yields an int on Python 3 and a 1-char str on
            # Python 2; going through bytearray gives an int on both.
            pad_len = bytearray(res[-1:])[0]
            res = res[:-pad_len]
        return res

if __name__ == "__main__":
    # Round-trip self-test: whatever is encrypted must decrypt to itself.
    # Key and message are byte strings, which is what the cipher operates on.
    key_ = b"1234523451234545"  # 16 byte key
    # Iff padded, the message doesn't have to be multiple of 16 in length
    msg_ = b"ToBeOrNotTobe W.S."
    aes = AESCipher(key_)
    iv_ = os.urandom(16)

    # Explicit IV. The key is already bound to the instance, so only the
    # message and the IV are passed (encrypt/decrypt take msg first).
    encrypted_msg = aes.encrypt(msg_, iv_)
    txt = aes.decrypt(encrypted_msg, iv_)
    assert txt == msg_

    # No IV supplied: encrypt picks one and prepends it to the message,
    # decrypt works from the message alone.
    encrypted_msg = aes.encrypt(msg_)
    txt = aes.decrypt(encrypted_msg)
    assert txt == msg_