summaryrefslogtreecommitdiff
path: root/pyasn1/codec/cer/encoder.py
blob: 8a7180a802db19e6c86efa3cde1bdaeb99d93c69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#
# This file is part of pyasn1 software.
#
# Copyright (c) 2005-2015, Ilya Etingof <ilya@glas.net>
# License: http://pyasn1.sf.net/license.html
#
from pyasn1.type import univ
from pyasn1.type import useful
from pyasn1.codec.ber import encoder
from pyasn1.compat.octets import int2oct, str2octs, null
from pyasn1 import error

class BooleanEncoder(encoder.IntegerEncoder):
    """Boolean encoder that emits one fixed octet per truth value."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Return the single-octet payload for *client*.

        A value comparing equal to zero is encoded as octet 0, anything
        else as octet 255.  `encodeFun`, `defMode` and `maxChunkSize` are
        accepted for signature compatibility only and are not used.

        Returns a `(substrate, 0)` tuple.
        NOTE(review): the trailing 0 is presumably the "not constructed"
        flag consumed by the base encoder -- confirm against
        pyasn1.codec.ber.encoder.
        """
        if client == 0:
            return int2oct(0), 0
        return int2oct(255), 0

class BitStringEncoder(encoder.BitStringEncoder):
    """Bit string encoder that pins the chunk size handed to the BER encoder."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Delegate to the BER bit string encoder with a chunk size of 1000.

        The caller-supplied `maxChunkSize` is ignored; all other arguments
        are forwarded unchanged and the BER encoder's result is returned
        as is.
        """
        # Fixed chunk size used instead of whatever the caller requested.
        fixedChunkSize = 1000
        berEncodeValue = encoder.BitStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, fixedChunkSize)

class OctetStringEncoder(encoder.OctetStringEncoder):
    """Octet string encoder that pins the chunk size handed to the BER encoder."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Delegate to the BER octet string encoder with a chunk size of 1000.

        The caller-supplied `maxChunkSize` is ignored; all other arguments
        are forwarded unchanged and the BER encoder's result is returned
        as is.
        """
        # Fixed chunk size used instead of whatever the caller requested.
        fixedChunkSize = 1000
        berEncodeValue = encoder.OctetStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, fixedChunkSize)

class RealEncoder(encoder.RealEncoder):
    """Real encoder whose base selection always goes through `_dropFloatingPoint`."""

    def _chooseEncBase(self, value):
        """Return the result of `_dropFloatingPoint` applied to *value*.

        *value* must unpack into exactly three items, which are passed on
        positionally.
        NOTE(review): the items are presumably (mantissa, base, exponent)
        -- confirm against the base RealEncoder.
        """
        mantissa, base, exponent = value
        return self._dropFloatingPoint(mantissa, base, exponent)

# specialized GeneralStringEncoder here

class GeneralizedTimeEncoder(OctetStringEncoder):
    """Encoder that validates a GeneralizedTime value before encoding it."""

    zchar = str2octs('Z')
    pluschar = str2octs('+')
    minuschar = str2octs('-')
    zero = str2octs('0')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Validate the octets of *client*, then encode with chunk size 1000.

        Checks are applied in this order, each raising
        `error.PyAsn1Error` on failure:

        1. the value is at least 15 octets long;
        2. it contains neither '+' nor '-';
        3. its last octet is 'Z'.

        The caller-supplied `maxChunkSize` is ignored.  The original
        *client* (not a normalized copy) is what gets encoded.
        """
        timeOctets = client.asOctets()

        # A check requiring a fraction-of-second part ('.') is deliberately
        # not enforced here: it broke too many existing data items.

        if len(timeOctets) < 15:
            raise error.PyAsn1Error('Bad UTC time length: %r' % timeOctets)

        # Any explicit sign means a non-UTC offset was supplied.
        for signChar in (self.pluschar, self.minuschar):
            if signChar in timeOctets:
                raise error.PyAsn1Error('Must be UTC time: %r' % timeOctets)

        # Compare single elements so the test works whether indexing yields
        # a character or an integer.
        lastOctet = timeOctets[-1]
        if lastOctet != self.zchar[0]:
            raise error.PyAsn1Error(
                'Missing timezone specifier: %r' % timeOctets
            )

        berEncodeValue = encoder.OctetStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class UTCTimeEncoder(encoder.OctetStringEncoder):
    """Encoder that normalizes and validates a UTCTime value before encoding."""

    zchar = str2octs('Z')
    pluschar = str2octs('+')
    minuschar = str2octs('-')

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Validate *client*, append a missing 'Z', then encode with chunk size 1000.

        Raises `error.PyAsn1Error` when the value contains '+' or '-', or
        when its length (after a trailing 'Z' has been appended to a
        non-empty value lacking one) is not exactly 13.

        The caller-supplied `maxChunkSize` is ignored.
        """
        timeOctets = client.asOctets()

        # Any explicit sign means a non-UTC offset was supplied.
        for signChar in (self.pluschar, self.minuschar):
            if signChar in timeOctets:
                raise error.PyAsn1Error('Must be UTC time: %r' % timeOctets)

        # Only non-empty values get the 'Z' suffix added; an empty value
        # falls through to the length check below.
        if timeOctets:
            if timeOctets[-1] != self.zchar[0]:
                client = client.clone(timeOctets + self.zchar)

        if len(client) != 13:
            raise error.PyAsn1Error('Bad UTC time length: %r' % client)

        berEncodeValue = encoder.OctetStringEncoder.encodeValue
        return berEncodeValue(self, encodeFun, client, defMode, 1000)

class SetOfEncoder(encoder.SequenceOfEncoder):
    """Encoder used for both Set and SetOf values; emits components sorted."""

    def encodeValue(self, encodeFun, client, defMode, maxChunkSize):
        """Encode the components of *client* and concatenate them in sorted order.

        * Instances of `univ.SequenceAndSetBase` are handled as Set: default
          components are populated first, then components that are unset
          (None) or equal to their default are skipped, the remainder is
          sorted by tag set and encoded.
        * Anything else is handled as SetOf: every component is encoded
          first and the resulting encodings are sorted.

        `client.verifySizeSpec()` is called in both cases before encoding.

        Returns a `(substrate, 1)` tuple.
        NOTE(review): the trailing 1 is presumably the "constructed" flag
        consumed by the base encoder -- confirm against
        pyasn1.codec.ber.encoder.
        """
        # This is certainly a hack, but Set and SetOf carry the same
        # tags & constraints, so the value's class is the only thing
        # that tells them apart.
        isSet = isinstance(client, univ.SequenceAndSetBase)

        if isSet:
            client.setDefaultComponents()
        client.verifySizeSpec()

        # Components are visited from the last position down to the first.
        positions = range(len(client) - 1, -1, -1)

        if isSet:
            def orderingTagSet(component):
                # A Choice sorts by its minimal tag set when that is
                # non-empty; everything else sorts by its own tag set.
                if isinstance(component, univ.Choice):
                    minTagSet = component.getMinTagSet()
                    if minTagSet:
                        return minTagSet
                return component.getTagSet()

            members = []
            for pos in positions:
                member = client[pos]
                if member is None:  # Optional component
                    continue
                if client.getDefaultComponentByPosition(pos) == member:
                    continue
                members.append(member)
            members.sort(key=orderingTagSet)

            substrate = null.join(
                [encodeFun(member, defMode, maxChunkSize)
                 for member in members]
            )
        else:
            encodedItems = [
                encodeFun(client[pos], defMode, maxChunkSize)
                for pos in positions
            ]
            encodedItems.sort()  # perhaps padding's not needed
            substrate = null.join(encodedItems)

        return substrate, 1

# Start from a copy of the BER tag map and override the entries for which
# this module defines its own encoder.
tagMap = encoder.tagMap.copy()
tagMap[univ.Boolean.tagSet] = BooleanEncoder()
tagMap[univ.BitString.tagSet] = BitStringEncoder()
tagMap[univ.OctetString.tagSet] = OctetStringEncoder()
tagMap[univ.Real.tagSet] = RealEncoder()
tagMap[useful.GeneralizedTime.tagSet] = GeneralizedTimeEncoder()
tagMap[useful.UTCTime.tagSet] = UTCTimeEncoder()
tagMap[univ.SetOf().tagSet] = SetOfEncoder()  # conflicts with Set

# Start from a copy of the BER type map; Set and SetOf are both routed to
# SetOfEncoder (each gets its own encoder instance).
typeMap = encoder.typeMap.copy()
typeMap[univ.Set.typeId] = SetOfEncoder()
typeMap[univ.SetOf.typeId] = SetOfEncoder()

class Encoder(encoder.Encoder):
    """BER encoder subclass whose call signature defaults `defMode` to False."""

    def __call__(self, client, defMode=False, maxChunkSize=0):
        """Encode *client* by delegating to the BER encoder's `__call__`.

        Only the defaults (`defMode=False`, `maxChunkSize=0`) are set
        here; the arguments are forwarded positionally and the BER
        encoder's result is returned unchanged.
        """
        berCall = encoder.Encoder.__call__
        return berCall(self, client, defMode, maxChunkSize)

# Module-level encoder instance built from this module's tagMap and typeMap;
# being an Encoder, it is callable as encode(client, defMode, maxChunkSize).
encode = Encoder(tagMap, typeMap)

# EncoderFactory queries class instance and builds a map of tags -> encoders