summaryrefslogtreecommitdiff
path: root/pyasn1/codec/cer/encoder.py
blob: 2dd928a925aff07bc591ddaf36181cd321a1093e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# CER encoder
from pyasn1.type import univ
from pyasn1.type import useful
from pyasn1.codec.ber import encoder
from pyasn1.compat.octets import int2oct, str2octs, null
from pyasn1 import error

class BooleanEncoder(encoder.IntegerEncoder):
    """CER encoder for BOOLEAN values.

    A value comparing equal to zero is emitted as the single octet 0x00;
    any other value is emitted as the single octet 0xFF.
    """
    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return ``(substrate, 0)`` for the boolean *client*.

        *encodeFun*, *defMode* and *maxChunkSize* are accepted for
        interface compatibility and are not used here.
        """
        # NOTE(review): the trailing 0 is presumably the "constructed"
        # flag expected by the BER encoder framework -- confirm.
        if client == 0:
            return int2oct(0), 0
        return int2oct(255), 0

class BitStringEncoder(encoder.BitStringEncoder):
    """CER encoder for BIT STRING values.

    Delegates to the BER bit string encoder, always with a chunk size
    of 1000.
    """
    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Encode *client* via the BER encoder, ignoring *maxChunkSize*."""
        berEncodeValue = encoder.BitStringEncoder.encodeValue
        # The caller-supplied maxChunkSize is deliberately replaced by 1000.
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class OctetStringEncoder(encoder.OctetStringEncoder):
    """CER encoder for OCTET STRING values.

    Delegates to the BER octet string encoder, always with a chunk size
    of 1000.
    """
    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Encode *client* via the BER encoder, ignoring *maxChunkSize*."""
        berEncodeValue = encoder.OctetStringEncoder.encodeValue
        # The caller-supplied maxChunkSize is deliberately replaced by 1000.
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class RealEncoder(encoder.RealEncoder):
    """CER encoder for REAL values."""
    def _chooseEncBase(self, value):
        """Unpack the 3-item *value* and pass it to ``_dropFloatingPoint``.

        Returns whatever the inherited ``_dropFloatingPoint`` helper
        returns for the unpacked triple.
        """
        # NOTE(review): the triple is presumably (mantissa, base,
        # exponent) -- confirm against the BER RealEncoder.
        mantissa, base, exponent = value
        return self._dropFloatingPoint(mantissa, base, exponent)

# specialized GeneralStringEncoder here

class GeneralizedTimeEncoder(OctetStringEncoder):
    """CER encoder for GeneralizedTime values.

    Validates that the value is expressed in UTC (ends with 'Z', carries
    no '+'/'-' offset), is at least 15 octets long and, when a fraction
    of a second is present, does not end that fraction with a zero.  The
    validated value is then encoded by the BER octet string encoder with
    a chunk size of 1000.
    """
    # All markers are kept as octets so that comparisons work the same
    # whether asOctets() yields a Python 2 str or a Python 3 bytes object.
    zchar = str2octs('Z')
    pluschar = str2octs('+')
    minuschar = str2octs('-')
    dotchar = str2octs('.')
    zero = str2octs('0')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Validate *client* and return its CER encoding.

        Raises ``error.PyAsn1Error`` when the value is not a UTC time,
        lacks the 'Z' designator, is too short, or has a trailing zero
        in its fractional seconds.  *maxChunkSize* is ignored; 1000 is
        always used.
        """
        octets = client.asOctets()
        if self.pluschar in octets or self.minuschar in octets:
            raise error.PyAsn1Error('Must be UTC time: %r' % octets)
        # A mandatory fraction of a second is deliberately NOT enforced:
        # doing so breaks too many existing data items.
        # Slices (not indexes) are compared: indexing bytes on Python 3
        # yields an int, which never equals a one-octet bytes object.
        if octets and octets[-1:] != self.zchar:
            raise error.PyAsn1Error('Missing timezone specifier: %r' % octets)
        if len(octets) < 15:
            raise error.PyAsn1Error('Bad UTC time length: %r' % octets)
        # Only a fractional part may not end in zero; whole seconds such
        # as '...00Z' are perfectly valid and must not be rejected.
        if self.dotchar in octets and octets[-2:-1] == self.zero:
            raise error.PyAsn1Error('Meningless zero in fraction of second: %r' % octets)
        return encoder.OctetStringEncoder.encodeValue(
            self, encodeFun, client, defMode, 1000
        )

class UTCTimeEncoder(encoder.OctetStringEncoder):
    """CER encoder for UTCTime values.

    Rejects values carrying a '+'/'-' offset, appends the 'Z' designator
    when it is missing, requires the result to be exactly 13 octets long
    and then encodes it with the BER octet string encoder using a chunk
    size of 1000.
    """
    # Markers are kept as octets so that comparisons work the same
    # whether asOctets() yields a Python 2 str or a Python 3 bytes object.
    zchar = str2octs('Z')
    pluschar = str2octs('+')
    minuschar = str2octs('-')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Validate/normalise *client* and return its CER encoding.

        Raises ``error.PyAsn1Error`` when the value has a timezone
        offset or does not have the expected length.  *maxChunkSize* is
        ignored; 1000 is always used.
        """
        octets = client.asOctets()
        if self.pluschar in octets or self.minuschar in octets:
            raise error.PyAsn1Error('Must be UTC time: %r' % octets)
        # Compare a one-octet slice rather than an indexed item: on
        # Python 3 indexing bytes yields an int, which would never equal
        # zchar and would cause a second 'Z' to be appended.
        if octets and octets[-1:] != self.zchar:
            client = client.clone(octets + self.zchar)
        if len(client) != 13:
            raise error.PyAsn1Error('Bad UTC time length: %r' % client)
        return encoder.OctetStringEncoder.encodeValue(
            self, encodeFun, client, defMode, 1000
        )

class SetOfEncoder(encoder.SequenceOfEncoder):
    """CER encoder used for both Set and SetOf values.

    Set components are emitted ordered by tag set, skipping absent
    optional components and components equal to their default.  SetOf
    components are encoded first and then emitted ordered by their
    encoded octets.
    """
    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return ``(substrate, 1)`` holding the ordered component encodings."""
        # This is certainly a hack, but Set and SetOf share the same
        # tags and constraints, so the Python base class is the only
        # thing that tells them apart.
        isSet = isinstance(client, univ.SequenceAndSetBase)
        if isSet:
            client.setDefaultComponents()
        client.verifySizeSpec()

        # Components are visited from the last position down to the first.
        positions = range(len(client) - 1, -1, -1)

        if isSet:
            def tagSortKey(component):
                # Choice components sort by their minimal tag set when
                # it is non-empty; everything else by its own tag set.
                if isinstance(component, univ.Choice):
                    minTagSet = component.getMinTagSet()
                    if minTagSet:
                        return minTagSet
                return component.getTagSet()

            present = []
            for pos in positions:
                component = client[pos]
                if component is None:  # Optional component
                    continue
                if client.getDefaultComponentByPosition(pos) == component:
                    continue
                present.append(component)
            present.sort(key=tagSortKey)
            pieces = [
                encodeFun(component, defMode, maxChunkSize)
                for component in present
            ]
        else:
            pieces = [
                encodeFun(client[pos], defMode, maxChunkSize)
                for pos in positions
            ]
            pieces.sort()  # perhaps padding's not needed

        return null.join(pieces), 1

# Start from a copy of the BER tag map and override the entries that
# need CER-specific encoders.
tagMap = encoder.tagMap.copy()
tagMap[univ.Boolean.tagSet] = BooleanEncoder()
tagMap[univ.BitString.tagSet] = BitStringEncoder()
tagMap[univ.OctetString.tagSet] = OctetStringEncoder()
tagMap[univ.Real.tagSet] = RealEncoder()
tagMap[useful.GeneralizedTime.tagSet] = GeneralizedTimeEncoder()
tagMap[useful.UTCTime.tagSet] = UTCTimeEncoder()
tagMap[univ.SetOf().tagSet] = SetOfEncoder()  # conflicts with Set

# Start from a copy of the BER type map; both Set and SetOf type ids are
# routed to the CER set encoder (each gets its own encoder instance).
typeMap = encoder.typeMap.copy()
typeMap[univ.Set.typeId] = SetOfEncoder()
typeMap[univ.SetOf.typeId] = SetOfEncoder()

class Encoder(encoder.Encoder):
    """CER flavour of the BER encoder entry point.

    Differs from the base call only in its defaults: *defMode* is False
    and *maxChunkSize* is 0 unless the caller says otherwise.
    """
    def __call__(self, client, defMode=False, maxChunkSize=0):
        """Encode *client* by delegating to the BER ``Encoder.__call__``."""
        berCall = encoder.Encoder.__call__
        return berCall(self, client, defMode, maxChunkSize)

# Module-level CER encoder callable, built from the tag and type maps
# defined in this module.
encode = Encoder(tagMap, typeMap)

# EncoderFactory queries class instance and builds a map of tags -> encoders