summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Pipek <jan.pipek@gmail.com>2019-09-10 17:27:55 +0200
committerJan Pipek <jan.pipek@gmail.com>2019-09-10 17:36:22 +0200
commite27f97182e859fc6048ff13b028961da578dc340 (patch)
tree8c1e509ca01da2fa3c6c96b06a2cddb6e30ae0c6
parent043d97d7ecd01da7c5ac43a0e87565ba0f3bd35b (diff)
downloadpyasn1-git-e27f97182e859fc6048ff13b028961da578dc340.tar.gz
Implement _CachedStreamWrapper
-rw-r--r--pyasn1/codec/ber/decoder.py71
-rw-r--r--tests/codec/ber/test_decoder.py8
2 files changed, 66 insertions, 13 deletions
diff --git a/pyasn1/codec/ber/decoder.py b/pyasn1/codec/ber/decoder.py
index 0bd804c..b3a6c45 100644
--- a/pyasn1/codec/ber/decoder.py
+++ b/pyasn1/codec/ber/decoder.py
@@ -6,7 +6,7 @@
#
import os
import sys
-from io import BytesIO, BufferedReader
+from io import BytesIO, BufferedReader, IOBase
from pyasn1 import debug
from pyasn1 import error
@@ -29,10 +29,68 @@ LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER)
noValue = base.noValue
-_BUFFER_SIZE = 1024
+_MAX_BUFFER_SIZE = 1024
_PY2 = sys.version_info < (3,)
+class _CachedStreamWrapper(IOBase):
+ """Wrapper around non-seekable streams."""
+ def __init__(self, raw):
+ self._raw = raw
+ self._cache = BytesIO()
+ self._marked_position_ = 0
+
+ def peek(self, n):
+ pos = self._cache.tell()
+ result = self.read(n)
+ self._cache.seek(pos, os.SEEK_SET)
+ return result
+
+ def seekable(self):
+ return True
+
+ def seek(self, n=-1, whence=os.SEEK_SET):
+ return self._cache.seek(n, whence)
+
+ def read(self, n=-1):
+ read_from_cache = self._cache.read(n)
+ if n != -1:
+ n -= len(read_from_cache)
+ read_from_raw = self._raw.read(n)
+ self._cache.write(read_from_raw)
+ return read_from_cache + read_from_raw
+
+ @property
+ def _marked_position(self):
+ # This closely corresponds with how _marked_position attribute
+ # is manipulated with in Decoder.__call__ and (indefLen)ValueDecoder's
+ return self._marked_position_
+
+ @_marked_position.setter
+ def _marked_position(self, value):
+ self._marked_position_ = value
+ self.seek(value)
+ self.reset()
+
+ def tell(self):
+ return self._cache.tell()
+
+ def reset(self):
+ """Keep the buffered data reasonably large.
+
+ Whenever we se _marked_position, we know for sure
+ that we will not return back, and thus it is
+ safe to drop all cached data.
+ """
+ if self._cache.tell() > _MAX_BUFFER_SIZE:
+ current = self._cache.read()
+ self._cache.seek(0, os.SEEK_SET)
+ self._cache.truncate()
+ self._cache.write(current)
+ self._cache.seek(0, os.SEEK_SET)
+ self._marked_position_ = 0
+
+
def asSeekableStream(substrate):
"""Convert object to seekable byte-stream.
@@ -54,13 +112,12 @@ def asSeekableStream(substrate):
elif isinstance(substrate, univ.OctetString):
return BytesIO(substrate.asOctets())
try:
- if _PY2 and isinstance(substrate, file):
- return BytesIO(substrate.read()) # Not optimal for really large files
- elif substrate.seekable():
+ if _PY2 and isinstance(substrate, file): # Special case (it is not possible to set attributes)
+ return BufferedReader(substrate, _MAX_BUFFER_SIZE)
+ elif substrate.seekable(): # Will fail for most invalid types
return substrate
else:
- # TODO: Implement for non-seekable streams
- raise UnsupportedSubstrateError("Cannot use non-seekable bit stream: " + substrate.__class__.__name__)
+ return _CachedStreamWrapper(substrate)
except AttributeError:
raise UnsupportedSubstrateError("Cannot convert " + substrate.__class__.__name__ + " to a seekable bit stream.")
diff --git a/tests/codec/ber/test_decoder.py b/tests/codec/ber/test_decoder.py
index 0686c6d..141f7c7 100644
--- a/tests/codec/ber/test_decoder.py
+++ b/tests/codec/ber/test_decoder.py
@@ -1715,12 +1715,8 @@ class CompressedFilesTestCase(BaseTestCase):
with zipfile.ZipFile(path, "r") as myzip:
with myzip.open("data", "r") as source:
- if sys.version_info < (3,):
- with self.assertRaises(UnsupportedSubstrateError):
- _ = list(decoder.decodeStream(source))
- else:
- values = list(decoder.decodeStream(source))
- assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)]
+ values = list(decoder.decodeStream(source))
+ assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)]
finally:
os.remove(path)