summary refs log tree commit diff
path: root/pyasn1
diff options
context:
space:
mode:
author Jan Pipek <jan.pipek@gmail.com> 2019-09-10 17:27:55 +0200
committer Ilya Etingof <etingof@gmail.com> 2019-11-15 19:31:42 +0100
commit e279319d412c6d7045c8bf90d0d887ed5097ff29 (patch)
tree 18ae8ab6056f2a95fab2e535be0af6db82e2f2dc /pyasn1
parent 21b4e64d28da30d3276228db5f5dd44f493a0092 (diff)
download pyasn1-git-e279319d412c6d7045c8bf90d0d887ed5097ff29.tar.gz
Implement _CachedStreamWrapper
Diffstat (limited to 'pyasn1')
-rw-r--r-- pyasn1/codec/ber/decoder.py 71
1 files changed, 64 insertions, 7 deletions
diff --git a/pyasn1/codec/ber/decoder.py b/pyasn1/codec/ber/decoder.py
index 7a22da0..820ee14 100644
--- a/pyasn1/codec/ber/decoder.py
+++ b/pyasn1/codec/ber/decoder.py
@@ -6,7 +6,7 @@
#
import os
import sys
-from io import BytesIO, BufferedReader
+from io import BytesIO, BufferedReader, IOBase
from pyasn1 import debug
from pyasn1 import error
@@ -29,10 +29,68 @@ LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER)
noValue = base.noValue
-_BUFFER_SIZE = 1024
+_MAX_BUFFER_SIZE = 1024
_PY2 = sys.version_info < (3,)
+class _CachedStreamWrapper(IOBase):
+ """Wrapper around non-seekable streams."""
+ def __init__(self, raw):
+ self._raw = raw
+ self._cache = BytesIO()
+ self._marked_position_ = 0
+
+ def peek(self, n):
+ pos = self._cache.tell()
+ result = self.read(n)
+ self._cache.seek(pos, os.SEEK_SET)
+ return result
+
+ def seekable(self):
+ return True
+
+ def seek(self, n=-1, whence=os.SEEK_SET):
+ return self._cache.seek(n, whence)
+
+ def read(self, n=-1):
+ read_from_cache = self._cache.read(n)
+ if n != -1:
+ n -= len(read_from_cache)
+ read_from_raw = self._raw.read(n)
+ self._cache.write(read_from_raw)
+ return read_from_cache + read_from_raw
+
+ @property
+ def _marked_position(self):
+ # This closely corresponds with how the _marked_position attribute
+ # is manipulated in Decoder.__call__ and in the (indefLen)ValueDecoder methods
+ return self._marked_position_
+
+ @_marked_position.setter
+ def _marked_position(self, value):
+ self._marked_position_ = value
+ self.seek(value)
+ self.reset()
+
+ def tell(self):
+ return self._cache.tell()
+
+ def reset(self):
+ """Keep the buffered data reasonably large.
+
+ Whenever we set _marked_position, we know for sure
+ that we will not return back, and thus it is
+ safe to drop all cached data.
+ """
+ if self._cache.tell() > _MAX_BUFFER_SIZE:
+ current = self._cache.read()
+ self._cache.seek(0, os.SEEK_SET)
+ self._cache.truncate()
+ self._cache.write(current)
+ self._cache.seek(0, os.SEEK_SET)
+ self._marked_position_ = 0
+
+
def asSeekableStream(substrate):
"""Convert object to seekable byte-stream.
@@ -54,13 +112,12 @@ def asSeekableStream(substrate):
elif isinstance(substrate, univ.OctetString):
return BytesIO(substrate.asOctets())
try:
- if _PY2 and isinstance(substrate, file):
- return BytesIO(substrate.read()) # Not optimal for really large files
- elif substrate.seekable():
+ if _PY2 and isinstance(substrate, file): # Special case (it is not possible to set attributes)
+ return BufferedReader(substrate, _MAX_BUFFER_SIZE)
+ elif substrate.seekable(): # Will fail for most invalid types
return substrate
else:
- # TODO: Implement for non-seekable streams
- raise UnsupportedSubstrateError("Cannot use non-seekable bit stream: " + substrate.__class__.__name__)
+ return _CachedStreamWrapper(substrate)
except AttributeError:
raise UnsupportedSubstrateError("Cannot convert " + substrate.__class__.__name__ + " to a seekable bit stream.")