summaryrefslogtreecommitdiff
path: root/pyasn1/codec/cer/encoder.py
blob: dad595597c801f8c0d6f3bd5523d612ba57486cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# CER encoder
from pyasn1.type import univ
from pyasn1.type import useful
from pyasn1.codec.ber import encoder
from pyasn1.compat.octets import int2oct, str2octs, null
from pyasn1 import error

class BooleanEncoder(encoder.IntegerEncoder):
    """Encode a Boolean as a single octet: 0 for false, 255 otherwise."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return ``(substrate, 0)`` for the boolean value `client`.

        `encodeFun`, `defMode` and `maxChunkSize` are accepted for
        signature compatibility and are not used here.  The trailing 0
        is presumably the "constructed" flag of the BER encoder
        protocol -- confirm against pyasn1.codec.ber.encoder.
        """
        # Any value that does not compare equal to 0 is emitted as 0xFF.
        return int2oct(0 if client == 0 else 255), 0

class BitStringEncoder(encoder.BitStringEncoder):
    """BitString encoder that always uses a chunk size of 1000."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Delegate to the BER BitString encoder with chunk size 1000.

        The caller-supplied `maxChunkSize` is deliberately ignored.
        """
        berEncodeValue = encoder.BitStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class OctetStringEncoder(encoder.OctetStringEncoder):
    """OctetString encoder that always uses a chunk size of 1000."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Delegate to the BER OctetString encoder with chunk size 1000.

        The caller-supplied `maxChunkSize` is deliberately ignored.
        """
        berEncodeValue = encoder.OctetStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class RealEncoder(encoder.RealEncoder):
    """Real encoder whose base selection always drops the floating point."""

    def _chooseEncBase(self, value):
        """Return ``self._dropFloatingPoint(...)`` for the 3-item `value`.

        `value` is unpacked into exactly three items (presumably
        mantissa, base and exponent -- confirm against the BER
        RealEncoder) which are passed on unchanged.
        """
        mantissa, base, exponent = value
        return self._dropFloatingPoint(mantissa, base, exponent)

# specialized GeneralStringEncoder here

class GeneralizedTimeEncoder(OctetStringEncoder):
    """Validate a GeneralizedTime value, then encode it as an octet string.

    All comparisons are done octets-against-octets (one-octet slices and
    `str2octs` constants) so that the checks behave the same whether
    `asOctets()` returns a Python 2 `str` or a Python 3 `bytes` object.
    Indexing `bytes` yields an int and `'x' in bytes` raises TypeError,
    which is why plain string literals and integer indexing are avoided.
    """
    zchar = str2octs('Z')
    zero = str2octs('0')
    pluschar = str2octs('+')
    dotchar = str2octs('.')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return the encoding of `client` after validating its format.

        The caller-supplied `maxChunkSize` is ignored; 1000 is always
        passed to the BER OctetString encoder.

        Raises `error.PyAsn1Error` when the value:
          * contains '+' (an explicit offset, i.e. not UTC),
          * has no '.' (no fraction of second),
          * does not end with 'Z',
          * is shorter than 16 octets,
          * has '0' as the last fraction digit (octet right before 'Z').
        """
        octets = client.asOctets()
        if self.pluschar in octets:
            raise error.PyAsn1Error('Must be UTC time')
        if self.dotchar not in octets:
            raise error.PyAsn1Error('Format must include fraction of second')
        # One-octet slices compare equal to one-octet constants on both
        # Python 2 and Python 3.
        if octets[-1:] != self.zchar:
            raise error.PyAsn1Error('Missing timezone specifier')
        if len(octets) < 16:
            raise error.PyAsn1Error('Bad UTC time length')
        # The original test was `A or (B and A)`, which reduces to `A`:
        # the octet immediately before the trailing 'Z' must not be '0'.
        if octets[-2:-1] == self.zero:
            raise error.PyAsn1Error('Meaningless zero in fraction of second')
        return encoder.OctetStringEncoder.encodeValue(
            self, encodeFun, client, defMode, 1000
        )

class UTCTimeEncoder(encoder.OctetStringEncoder):
    """Validate a UTCTime value, then encode it as an octet string.

    All comparisons are done octets-against-octets (one-octet slices and
    `str2octs` constants) so that the checks behave the same whether
    `asOctets()` returns a Python 2 `str` or a Python 3 `bytes` object.
    """
    zchar = str2octs('Z')
    pluschar = str2octs('+')
    dotchar = str2octs('.')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return the encoding of `client` after validating its format.

        A value that does not already end with 'Z' is cloned with 'Z'
        appended before the length check.  The caller-supplied
        `maxChunkSize` is ignored; 1000 is always passed to the BER
        OctetString encoder.

        Raises `error.PyAsn1Error` when the value contains '+', contains
        '.', or is not exactly 13 items long once 'Z' is in place.
        """
        octets = client.asOctets()
        if self.pluschar in octets:
            raise error.PyAsn1Error('Must be UTC time')
        if self.dotchar in octets:
            raise error.PyAsn1Error('Must be no fraction of second')
        # Slice rather than index: on Python 3 `octets[-1]` is an int and
        # would never equal `zchar`, so 'Z' would always be appended.
        if octets and octets[-1:] != self.zchar:
            client = client.clone(octets + self.zchar)
        if len(client) != 13:
            raise error.PyAsn1Error('Bad UTC time length')
        return encoder.OctetStringEncoder.encodeValue(
            self, encodeFun, client, defMode, 1000
        )

class SetOfEncoder(encoder.SequenceOfEncoder):
    """Encoder for both Set and SetOf values, emitting sorted content.

    Set components are ordered by tag set; SetOf components are ordered
    by their encoded octets.
    """

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return ``(substrate, 1)`` with the sorted component encodings.

        `encodeFun` is called as ``encodeFun(component, defMode,
        maxChunkSize)`` for every component that gets encoded.
        """
        # This is certainly a hack, but Set and SetOf carry the same
        # tags & constraints, so the Python type is the only way to tell
        # them apart here.
        isSet = isinstance(client, univ.SequenceAndSetBase)
        if isSet:
            client.setDefaultComponents()
        client.verifySizeSpec()

        # Components are visited from the last position to the first.
        positions = range(len(client) - 1, -1, -1)

        if isSet:
            def sortKey(component):
                # A Choice sorts by its minimal tag set when that is
                # non-empty; everything else sorts by its own tag set.
                if isinstance(component, univ.Choice):
                    minTagSet = component.getMinTagSet()
                    if minTagSet:
                        return minTagSet
                return component.getTagSet()

            components = []
            for position in positions:
                component = client[position]
                if component is None:  # Optional component
                    continue
                if client.getDefaultComponentByPosition(position) == component:
                    continue  # still at its default value
                components.append(component)
            components.sort(key=sortKey)
            encodedParts = [
                encodeFun(component, defMode, maxChunkSize)
                for component in components
            ]
        else:
            encodedParts = [
                encodeFun(client[position], defMode, maxChunkSize)
                for position in positions
            ]
            encodedParts.sort()  # perhaps padding's not needed

        substrate = null
        for part in encodedParts:
            substrate += part
        return substrate, 1

# Tag set -> encoder map: a copy of the BER map with the encoders defined
# in this module installed over it.
tagMap = encoder.tagMap.copy()
tagMap[univ.Boolean.tagSet] = BooleanEncoder()
tagMap[univ.BitString.tagSet] = BitStringEncoder()
tagMap[univ.OctetString.tagSet] = OctetStringEncoder()
tagMap[univ.Real.tagSet] = RealEncoder()
tagMap[useful.GeneralizedTime.tagSet] = GeneralizedTimeEncoder()
tagMap[useful.UTCTime.tagSet] = UTCTimeEncoder()
tagMap[univ.SetOf().tagSet] = SetOfEncoder()  # conflicts with Set

# Type id -> encoder map: a copy of the BER map in which both Set and
# SetOf are handled by SetOfEncoder (one separate instance each).
typeMap = encoder.typeMap.copy()
typeMap[univ.Set.typeId] = SetOfEncoder()
typeMap[univ.SetOf.typeId] = SetOfEncoder()

class Encoder(encoder.Encoder):
    """BER Encoder subclass whose call signature has its own defaults."""

    def __call__(self, client, defMode=False, maxChunkSize=0):
        """Encode `client` by delegating to the BER ``Encoder.__call__``.

        The only difference visible here is the defaults supplied by
        this signature: ``defMode=False`` and ``maxChunkSize=0``.
        """
        berCall = encoder.Encoder.__call__
        return berCall(self, client, defMode, maxChunkSize)

# Module-level encoder callable, built from the tag and type maps above.
encode = Encoder(tagMap, typeMap)

# EncoderFactory queries class instance and builds a map of tags -> encoders