diff options
author | Matthew Craven <5086-clyring@users.noreply.gitlab.haskell.org> | 2022-06-02 19:20:20 -0400 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-06-22 08:22:12 -0400 |
commit | 4ccefc6ea83073319c59690e340916175087dace (patch) | |
tree | e3b9ccd0e0db50b573a1007e6bd7c724b179b90b /libraries | |
parent | f89bf85fcedb595f457dee2c7ef50a15cc958c1a (diff) | |
download | haskell-4ccefc6ea83073319c59690e340916175087dace.tar.gz |
Check for Int overflows in Data.Array.Byte
Diffstat (limited to 'libraries')
-rw-r--r-- | libraries/base/Data/Array/Byte.hs | 118 |
1 files changed, 91 insertions, 27 deletions
diff --git a/libraries/base/Data/Array/Byte.hs b/libraries/base/Data/Array/Byte.hs index b0697fe28c..3029f9c178 100644 --- a/libraries/base/Data/Array/Byte.hs +++ b/libraries/base/Data/Array/Byte.hs @@ -22,9 +22,11 @@ module Data.Array.Byte ( import Data.Bits ((.&.), unsafeShiftR) import Data.Data (mkNoRepType, Data(..), Typeable) import qualified Data.Foldable as F +import Data.Maybe (fromMaybe) import Data.Semigroup -import GHC.Show (intToDigit) import GHC.Exts +import GHC.Num.Integer (Integer(..)) +import GHC.Show (intToDigit) import GHC.ST (ST(..), runST) import GHC.Word (Word8(..)) @@ -127,6 +129,22 @@ copyByteArray (MutableByteArray dst#) (I# doff#) (ByteArray src#) (I# soff#) (I# ST (\s# -> case copyByteArray# src# soff# dst# doff# sz# s# of s'# -> (# s'#, () #)) +-- | Copy a slice from one mutable byte array to another +-- or to the same mutable byte array. +-- +-- /Note:/ this function does not do bounds checking. +copyMutableByteArray + :: MutableByteArray s -- ^ destination array + -> Int -- ^ offset into destination array + -> MutableByteArray s -- ^ source array + -> Int -- ^ offset into source array + -> Int -- ^ number of bytes to copy + -> ST s () +{-# INLINE copyMutableByteArray #-} +copyMutableByteArray (MutableByteArray dst#) (I# doff#) (MutableByteArray src#) (I# soff#) (I# sz#) = + ST (\s# -> case copyMutableByteArray# src# soff# dst# doff# sz# s# of + s'# -> (# s'#, () #)) + -- | @since 4.17.0.0 instance Data ByteArray where toConstr _ = error "toConstr" @@ -206,17 +224,23 @@ instance Ord ByteArray where -- | Append two byte arrays. appendByteArray :: ByteArray -> ByteArray -> ByteArray -appendByteArray a b = runST $ do - marr <- newByteArray (sizeofByteArray a + sizeofByteArray b) - copyByteArray marr 0 a 0 (sizeofByteArray a) - copyByteArray marr (sizeofByteArray a) b 0 (sizeofByteArray b) +appendByteArray ba1 ba2 = runST $ do + let n1 = sizeofByteArray ba1 + n2 = sizeofByteArray ba2 + totSz = fromMaybe (sizeOverflowError "appendByteArray") + (checkedIntAdd n1 n2) + marr <- newByteArray totSz + copyByteArray marr 0 ba1 0 n1 + copyByteArray marr n1 ba2 0 n2 unsafeFreezeByteArray marr -- | Concatenate a list of 'ByteArray's. concatByteArray :: [ByteArray] -> ByteArray concatByteArray arrs = runST $ do - let len = calcLength arrs 0 - marr <- newByteArray len + let addLen acc arr = fromMaybe (sizeOverflowError "concatByteArray") + (checkedIntAdd acc (sizeofByteArray arr)) + totLen = F.foldl' addLen 0 arrs + marr <- newByteArray totLen pasteByteArrays marr 0 arrs unsafeFreezeByteArray marr @@ -227,36 +251,56 @@ pasteByteArrays !marr !ix (x : xs) = do copyByteArray marr ix x 0 (sizeofByteArray x) pasteByteArrays marr (ix + sizeofByteArray x) xs --- | Compute total length of 'ByteArray's, increased by accumulator. -calcLength :: [ByteArray] -> Int -> Int -calcLength [] !n = n -calcLength (x : xs) !n = calcLength xs (sizeofByteArray x + n) - -- | An array of zero length. emptyByteArray :: ByteArray emptyByteArray = runST (newByteArray 0 >>= unsafeFreezeByteArray) --- | Replicate 'ByteArray' given number of times and concatenate all together. -replicateByteArray :: Int -> ByteArray -> ByteArray -replicateByteArray n arr = runST $ do - marr <- newByteArray (n * sizeofByteArray arr) - let go i = if i < n - then do - copyByteArray marr (i * sizeofByteArray arr) arr 0 (sizeofByteArray arr) - go (i + 1) - else return () - go 0 +-- | Concatenates a given number of copies of an input ByteArray. +stimesPolymorphic :: Integral t => t -> ByteArray -> ByteArray +{-# INLINABLE stimesPolymorphic #-} +stimesPolymorphic nRaw !arr = case toInteger nRaw of + IS nInt# + | isTrue# (nInt# ># 0#) -> stimesPositiveInt (I# nInt#) arr + | isTrue# (nInt# >=# 0#) -> emptyByteArray + -- This check is redundant for unsigned types like Word. + -- Using >=# intead of ==# may make it easier for GHC to notice that. + | otherwise -> stimesNegativeErr + IP _ + | sizeofByteArray arr == 0 -> emptyByteArray + | otherwise -> stimesOverflowErr + IN _ -> stimesNegativeErr + +stimesNegativeErr :: ByteArray +stimesNegativeErr = + errorWithoutStackTrace "stimes @ByteArray: negative multiplier" + +stimesOverflowErr :: a +stimesOverflowErr = sizeOverflowError "stimes" + +stimesPositiveInt :: Int -> ByteArray -> ByteArray +{-# NOINLINE stimesPositiveInt #-} +-- NOINLINE to prevent its duplication in specialisations of stimesPolymorphic +stimesPositiveInt n arr = runST $ do + let inpSz = sizeofByteArray arr + tarSz = fromMaybe stimesOverflowErr (checkedIntMultiply n inpSz) + marr <- newByteArray tarSz + copyByteArray marr 0 arr 0 inpSz + let + halfTarSz = (tarSz - 1) `div` 2 + go copied + | copied <= halfTarSz = do + copyMutableByteArray marr copied marr 0 copied + go (copied + copied) + | otherwise = copyMutableByteArray marr copied marr 0 (tarSz - copied) + go inpSz unsafeFreezeByteArray marr -- | @since 4.17.0.0 instance Semigroup ByteArray where (<>) = appendByteArray sconcat = mconcat . F.toList - stimes i arr - | itgr < 1 = emptyByteArray - | itgr <= (fromIntegral (maxBound :: Int)) = replicateByteArray (fromIntegral itgr) arr - | otherwise = error "Data.Array.Byte#stimes: cannot allocate the requested amount of memory" - where itgr = toInteger i :: Integer + {-# INLINE stimes #-} + stimes = stimesPolymorphic -- | @since 4.17.0.0 instance Monoid ByteArray where @@ -270,3 +314,23 @@ instance IsList ByteArray where toList = byteArrayToList fromList xs = byteArrayFromListN (length xs) xs fromListN = byteArrayFromListN + + +sizeOverflowError :: String -> a +sizeOverflowError fun + = errorWithoutStackTrace $ "Data.Array.Byte." ++ fun ++ ": size overflow" + + +-- TODO: Export these from a better home. + +-- | Adds two @Int@s, returning @Nothing@ if this results in an overflow +checkedIntAdd :: Int -> Int -> Maybe Int +checkedIntAdd (I# x#) (I# y#) = case addIntC# x# y# of + (# res, 0# #) -> Just (I# res) + _ -> Nothing + +-- | Multiplies two @Int@s, returning @Nothing@ if this results in an overflow +checkedIntMultiply :: Int -> Int -> Maybe Int +checkedIntMultiply (I# x#) (I# y#) = case timesInt2# x# y# of + (# 0#, _hi, lo #) -> Just (I# lo) + _ -> Nothing |