summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Craven <5086-clyring@users.noreply.gitlab.haskell.org>2022-06-02 19:20:20 -0400
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-06-22 08:22:12 -0400
commit4ccefc6ea83073319c59690e340916175087dace (patch)
treee3b9ccd0e0db50b573a1007e6bd7c724b179b90b
parentf89bf85fcedb595f457dee2c7ef50a15cc958c1a (diff)
downloadhaskell-4ccefc6ea83073319c59690e340916175087dace.tar.gz
Check for Int overflows in Data.Array.Byte
-rw-r--r--libraries/base/Data/Array/Byte.hs118
1 files changed, 91 insertions, 27 deletions
diff --git a/libraries/base/Data/Array/Byte.hs b/libraries/base/Data/Array/Byte.hs
index b0697fe28c..3029f9c178 100644
--- a/libraries/base/Data/Array/Byte.hs
+++ b/libraries/base/Data/Array/Byte.hs
@@ -22,9 +22,11 @@ module Data.Array.Byte (
import Data.Bits ((.&.), unsafeShiftR)
import Data.Data (mkNoRepType, Data(..), Typeable)
import qualified Data.Foldable as F
+import Data.Maybe (fromMaybe)
import Data.Semigroup
-import GHC.Show (intToDigit)
import GHC.Exts
+import GHC.Num.Integer (Integer(..))
+import GHC.Show (intToDigit)
import GHC.ST (ST(..), runST)
import GHC.Word (Word8(..))
@@ -127,6 +129,22 @@ copyByteArray (MutableByteArray dst#) (I# doff#) (ByteArray src#) (I# soff#) (I#
ST (\s# -> case copyByteArray# src# soff# dst# doff# sz# s# of
s'# -> (# s'#, () #))
+-- | Copy a slice from one mutable byte array to another
+-- or to the same mutable byte array.
+--
+-- /Note:/ this function does not do bounds checking.
+copyMutableByteArray
+ :: MutableByteArray s -- ^ destination array
+ -> Int -- ^ offset into destination array
+ -> MutableByteArray s -- ^ source array
+ -> Int -- ^ offset into source array
+ -> Int -- ^ number of bytes to copy
+ -> ST s ()
+{-# INLINE copyMutableByteArray #-}
+copyMutableByteArray (MutableByteArray dst#) (I# doff#) (MutableByteArray src#) (I# soff#) (I# sz#) =
+ ST (\s# -> case copyMutableByteArray# src# soff# dst# doff# sz# s# of
+ s'# -> (# s'#, () #))
+
-- | @since 4.17.0.0
instance Data ByteArray where
toConstr _ = error "toConstr"
@@ -206,17 +224,23 @@ instance Ord ByteArray where
-- | Append two byte arrays.
appendByteArray :: ByteArray -> ByteArray -> ByteArray
-appendByteArray a b = runST $ do
- marr <- newByteArray (sizeofByteArray a + sizeofByteArray b)
- copyByteArray marr 0 a 0 (sizeofByteArray a)
- copyByteArray marr (sizeofByteArray a) b 0 (sizeofByteArray b)
+appendByteArray ba1 ba2 = runST $ do
+ let n1 = sizeofByteArray ba1
+ n2 = sizeofByteArray ba2
+ totSz = fromMaybe (sizeOverflowError "appendByteArray")
+ (checkedIntAdd n1 n2)
+ marr <- newByteArray totSz
+ copyByteArray marr 0 ba1 0 n1
+ copyByteArray marr n1 ba2 0 n2
unsafeFreezeByteArray marr
-- | Concatenate a list of 'ByteArray's.
concatByteArray :: [ByteArray] -> ByteArray
concatByteArray arrs = runST $ do
- let len = calcLength arrs 0
- marr <- newByteArray len
+ let addLen acc arr = fromMaybe (sizeOverflowError "concatByteArray")
+ (checkedIntAdd acc (sizeofByteArray arr))
+ totLen = F.foldl' addLen 0 arrs
+ marr <- newByteArray totLen
pasteByteArrays marr 0 arrs
unsafeFreezeByteArray marr
@@ -227,36 +251,56 @@ pasteByteArrays !marr !ix (x : xs) = do
copyByteArray marr ix x 0 (sizeofByteArray x)
pasteByteArrays marr (ix + sizeofByteArray x) xs
--- | Compute total length of 'ByteArray's, increased by accumulator.
-calcLength :: [ByteArray] -> Int -> Int
-calcLength [] !n = n
-calcLength (x : xs) !n = calcLength xs (sizeofByteArray x + n)
-
-- | An array of zero length.
emptyByteArray :: ByteArray
emptyByteArray = runST (newByteArray 0 >>= unsafeFreezeByteArray)
--- | Replicate 'ByteArray' given number of times and concatenate all together.
-replicateByteArray :: Int -> ByteArray -> ByteArray
-replicateByteArray n arr = runST $ do
- marr <- newByteArray (n * sizeofByteArray arr)
- let go i = if i < n
- then do
- copyByteArray marr (i * sizeofByteArray arr) arr 0 (sizeofByteArray arr)
- go (i + 1)
- else return ()
- go 0
+-- | Concatenates a given number of copies of an input ByteArray.
+stimesPolymorphic :: Integral t => t -> ByteArray -> ByteArray
+{-# INLINABLE stimesPolymorphic #-}
+stimesPolymorphic nRaw !arr = case toInteger nRaw of
+ IS nInt#
+ | isTrue# (nInt# ># 0#) -> stimesPositiveInt (I# nInt#) arr
+ | isTrue# (nInt# >=# 0#) -> emptyByteArray
+ -- This check is redundant for unsigned types like Word.
+ -- Using >=# intead of ==# may make it easier for GHC to notice that.
+ | otherwise -> stimesNegativeErr
+ IP _
+ | sizeofByteArray arr == 0 -> emptyByteArray
+ | otherwise -> stimesOverflowErr
+ IN _ -> stimesNegativeErr
+
+stimesNegativeErr :: ByteArray
+stimesNegativeErr =
+ errorWithoutStackTrace "stimes @ByteArray: negative multiplier"
+
+stimesOverflowErr :: a
+stimesOverflowErr = sizeOverflowError "stimes"
+
+stimesPositiveInt :: Int -> ByteArray -> ByteArray
+{-# NOINLINE stimesPositiveInt #-}
+-- NOINLINE to prevent its duplication in specialisations of stimesPolymorphic
+stimesPositiveInt n arr = runST $ do
+ let inpSz = sizeofByteArray arr
+ tarSz = fromMaybe stimesOverflowErr (checkedIntMultiply n inpSz)
+ marr <- newByteArray tarSz
+ copyByteArray marr 0 arr 0 inpSz
+ let
+ halfTarSz = (tarSz - 1) `div` 2
+ go copied
+ | copied <= halfTarSz = do
+ copyMutableByteArray marr copied marr 0 copied
+ go (copied + copied)
+ | otherwise = copyMutableByteArray marr copied marr 0 (tarSz - copied)
+ go inpSz
unsafeFreezeByteArray marr
-- | @since 4.17.0.0
instance Semigroup ByteArray where
(<>) = appendByteArray
sconcat = mconcat . F.toList
- stimes i arr
- | itgr < 1 = emptyByteArray
- | itgr <= (fromIntegral (maxBound :: Int)) = replicateByteArray (fromIntegral itgr) arr
- | otherwise = error "Data.Array.Byte#stimes: cannot allocate the requested amount of memory"
- where itgr = toInteger i :: Integer
+ {-# INLINE stimes #-}
+ stimes = stimesPolymorphic
-- | @since 4.17.0.0
instance Monoid ByteArray where
@@ -270,3 +314,23 @@ instance IsList ByteArray where
toList = byteArrayToList
fromList xs = byteArrayFromListN (length xs) xs
fromListN = byteArrayFromListN
+
+
+sizeOverflowError :: String -> a
+sizeOverflowError fun
+ = errorWithoutStackTrace $ "Data.Array.Byte." ++ fun ++ ": size overflow"
+
+
+-- TODO: Export these from a better home.
+
+-- | Adds two @Int@s, returning @Nothing@ if this results in an overflow
+checkedIntAdd :: Int -> Int -> Maybe Int
+checkedIntAdd (I# x#) (I# y#) = case addIntC# x# y# of
+ (# res, 0# #) -> Just (I# res)
+ _ -> Nothing
+
+-- | Multiplies two @Int@s, returning @Nothing@ if this results in an overflow
+checkedIntMultiply :: Int -> Int -> Maybe Int
+checkedIntMultiply (I# x#) (I# y#) = case timesInt2# x# y# of
+ (# 0#, _hi, lo #) -> Just (I# lo)
+ _ -> Nothing