diff options
-rw-r--r-- | compiler/GHC/Core/Op/Specialise.hs | 329 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_run/T17151.hs | 18 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_run/T17151.stdout | 2 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_run/T17151a.hs | 205 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_run/all.T | 1 |
5 files changed, 432 insertions, 123 deletions
diff --git a/compiler/GHC/Core/Op/Specialise.hs b/compiler/GHC/Core/Op/Specialise.hs index d7e1ebe654..ba16ca4347 100644 --- a/compiler/GHC/Core/Op/Specialise.hs +++ b/compiler/GHC/Core/Op/Specialise.hs @@ -589,19 +589,11 @@ specProgram guts@(ModGuts { mg_module = this_mod -- Specialise the bindings of this module ; (binds', uds) <- runSpecM dflags this_mod (go binds) - -- Specialise imported functions - ; hpt_rules <- getRuleBase - ; let rule_base = extendRuleBaseList hpt_rules local_rules - ; (new_rules, spec_binds) <- specImports dflags this_mod top_env emptyVarSet - [] rule_base uds - - ; let final_binds - | null spec_binds = binds' - | otherwise = Rec (flattenBinds spec_binds) : binds' - -- Note [Glom the bindings if imported functions are specialised] + ; (spec_rules, spec_binds) <- specImports dflags this_mod top_env + local_rules uds - ; return (guts { mg_binds = final_binds - , mg_rules = new_rules ++ local_rules }) } + ; return (guts { mg_binds = spec_binds ++ binds' + , mg_rules = spec_rules ++ local_rules }) } where -- We need to start with a Subst that knows all the things -- that are in scope, so that the substitution engine doesn't @@ -645,72 +637,93 @@ See #10491 * * ********************************************************************* -} --- | Specialise a set of calls to imported bindings -specImports :: DynFlags - -> Module - -> SpecEnv -- Passed in so that all top-level Ids are in scope - -> VarSet -- Don't specialise these ones - -- See Note [Avoiding recursive specialisation] - -> [Id] -- Stack of imported functions being specialised - -> RuleBase -- Rules from this module and the home package - -- (but not external packages, which can change) - -> UsageDetails -- Calls for imported things, and floating bindings - -> CoreM ( [CoreRule] -- New rules - , [CoreBind] ) -- Specialised bindings - -- See Note [Wrapping bindings returned by specImports] -specImports dflags this_mod top_env done callers rule_base +specImports :: DynFlags -> Module -> SpecEnv + -> [CoreRule] + -> UsageDetails + -> CoreM ([CoreRule], [CoreBind]) +specImports dflags this_mod top_env local_rules (MkUD { ud_binds = dict_binds, ud_calls = calls }) - -- See Note [Disabling cross-module specialisation] | not $ gopt Opt_CrossModuleSpecialise dflags - = return ([], []) + -- See Note [Disabling cross-module specialisation] + = return ([], wrapDictBinds dict_binds []) | otherwise - = do { let import_calls = dVarEnvElts calls - ; (rules, spec_binds) <- go rule_base import_calls + = do { hpt_rules <- getRuleBase + ; let rule_base = extendRuleBaseList hpt_rules local_rules + + ; (spec_rules, spec_binds) <- spec_imports dflags this_mod top_env + [] rule_base + dict_binds calls -- Don't forget to wrap the specialized bindings with -- bindings for the needed dictionaries. -- See Note [Wrap bindings returned by specImports] - ; let spec_binds' = wrapDictBinds dict_binds spec_binds + -- and Note [Glom the bindings if imported functions are specialised] + ; let final_binds + | null spec_binds = wrapDictBinds dict_binds [] + | otherwise = [Rec $ flattenBinds $ + wrapDictBinds dict_binds spec_binds] + + ; return (spec_rules, final_binds) + } + +-- | Specialise a set of calls to imported bindings +spec_imports :: DynFlags + -> Module + -> SpecEnv -- Passed in so that all top-level Ids are in scope + -> [Id] -- Stack of imported functions being specialised + -- See Note [specImport call stack] + -> RuleBase -- Rules from this module and the home package + -- (but not external packages, which can change) + -> Bag DictBind -- Dict bindings, used /only/ for filterCalls + -- See Note [Avoiding loops in specImports] + -> CallDetails -- Calls for imported things + -> CoreM ( [CoreRule] -- New rules + , [CoreBind] ) -- Specialised bindings +spec_imports dflags this_mod top_env + callers rule_base dict_binds calls + = do { let import_calls = dVarEnvElts calls + -- ; debugTraceMsg (text "specImports {" <+> + -- vcat [ text "calls:" <+> ppr import_calls + -- , text "dict_binds:" <+> ppr dict_binds ]) + ; (rules, spec_binds) <- go rule_base import_calls + -- ; debugTraceMsg (text "End specImports }" <+> ppr import_calls) - ; return (rules, spec_binds') } + ; return (rules, spec_binds) } where go :: RuleBase -> [CallInfoSet] -> CoreM ([CoreRule], [CoreBind]) go _ [] = return ([], []) - go rb (cis@(CIS fn _) : other_calls) - = do { let ok_calls = filterCalls cis dict_binds - -- Drop calls that (directly or indirectly) refer to fn - -- See Note [Avoiding loops] --- ; debugTraceMsg (text "specImport" <+> vcat [ ppr fn --- , text "calls" <+> ppr cis --- , text "ud_binds =" <+> ppr dict_binds --- , text "dump set =" <+> ppr dump_set --- , text "filtered calls =" <+> ppr ok_calls ]) - ; (rules1, spec_binds1) <- specImport dflags this_mod top_env - done callers rb fn ok_calls + go rb (cis : other_calls) + = do { -- debugTraceMsg (text "specImport {" <+> ppr cis) + ; (rules1, spec_binds1) <- spec_import dflags this_mod top_env + callers rb dict_binds cis + -- ; debugTraceMsg (text "specImport }" <+> ppr cis) ; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls ; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) } -specImport :: DynFlags - -> Module - -> SpecEnv -- Passed in so that all top-level Ids are in scope - -> VarSet -- Don't specialise these - -- See Note [Avoiding recursive specialisation] - -> [Id] -- Stack of imported functions being specialised - -> RuleBase -- Rules from this module - -> Id -> [CallInfo] -- Imported function and calls for it - -> CoreM ( [CoreRule] -- New rules - , [CoreBind] ) -- Specialised bindings -specImport dflags this_mod top_env done callers rb fn calls_for_fn - | fn `elemVarSet` done +spec_import :: DynFlags + -> Module + -> SpecEnv -- Passed in so that all top-level Ids are in scope + -> [Id] -- Stack of imported functions being specialised + -- See Note [specImport call stack] + -> RuleBase -- Rules from this module + -> Bag DictBind -- Dict bindings, used /only/ for filterCalls + -- See Note [Avoiding loops in specImports] + -> CallInfoSet -- Imported function and calls for it + -> CoreM ( [CoreRule] -- New rules + , [CoreBind] ) -- Specialised bindings +spec_import dflags this_mod top_env callers + rb dict_binds cis@(CIS fn _) + | isIn "specImport" fn callers = return ([], []) -- No warning. This actually happens all the time -- when specialising a recursive function, because -- the RHS of the specialised function contains a recursive -- call to the original function - | null calls_for_fn -- We filtered out all the calls in deleteCallsMentioning - = return ([], []) + | null good_calls + = do { -- debugTraceMsg (text "specImport:no valid calls") + ; return ([], []) } | wantSpecImport dflags unfolding , Just rhs <- maybeUnfoldingTemplate unfolding @@ -723,32 +736,37 @@ specImport dflags this_mod top_env done callers rb fn calls_for_fn ; let full_rb = unionRuleBase rb (eps_rule_base eps) rules_for_fn = getRules (RuleEnv full_rb vis_orphs) fn - ; (rules1, spec_pairs, uds) - <- -- pprTrace "specImport1" (vcat [ppr fn, ppr calls_for_fn, ppr rhs]) $ - runSpecM dflags this_mod $ - specCalls (Just this_mod) top_env rules_for_fn calls_for_fn fn rhs + ; (rules1, spec_pairs, MkUD { ud_binds = dict_binds1, ud_calls = new_calls }) + <- do { -- debugTraceMsg (text "specImport1" <+> vcat [ppr fn, ppr good_calls, ppr rhs]) + ; runSpecM dflags this_mod $ + specCalls (Just this_mod) top_env rules_for_fn good_calls fn rhs } ; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs] -- After the rules kick in we may get recursion, but -- we rely on a global GlomBinds to sort that out later -- See Note [Glom the bindings if imported functions are specialised] -- Now specialise any cascaded calls - ; (rules2, spec_binds2) <- -- pprTrace "specImport 2" (ppr fn $$ ppr rules1 $$ ppr spec_binds1) $ - specImports dflags this_mod top_env - (extendVarSet done fn) - (fn:callers) - (extendRuleBaseList rb rules1) - uds + -- ; debugTraceMsg (text "specImport 2" <+> (ppr fn $$ ppr rules1 $$ ppr spec_binds1)) + ; (rules2, spec_binds2) <- spec_imports dflags this_mod top_env + (fn:callers) + (extendRuleBaseList rb rules1) + (dict_binds `unionBags` dict_binds1) + new_calls - ; let final_binds = spec_binds2 ++ spec_binds1 + ; let final_binds = wrapDictBinds dict_binds1 $ + spec_binds2 ++ spec_binds1 ; return (rules2 ++ rules1, final_binds) } - | otherwise = do { tryWarnMissingSpecs dflags callers fn calls_for_fn - ; return ([], [])} + | otherwise + = do { tryWarnMissingSpecs dflags callers fn good_calls + ; return ([], [])} where unfolding = realIdUnfolding fn -- We want to see the unfolding even for loop breakers + good_calls = filterCalls cis dict_binds + -- SUPER IMPORTANT! Drop calls that (directly or indirectly) refer to fn + -- See Note [Avoiding loops in specImports] -- | Returns whether or not to show a missed-spec warning. -- If -Wall-missed-specializations is on, show the warning. @@ -790,8 +808,114 @@ wantSpecImport dflags unf -- inside it that we want to specialise | otherwise -> False -- Stable, not INLINE, hence INLINABLE -{- Note [Warning about missed specialisations] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +{- Note [Avoiding loops in specImports] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We must take great care when specialising instance declarations +(functions like $fOrdList) lest we accidentally build a recursive +dictionary. See Note [Avoiding loops]. + +The basic strategy of Note [Avoiding loops] is to use filterCalls +to discard loopy specialisations. But to do that we must ensure +that the in-scope dict-binds (passed to filterCalls) contains +all the needed dictionary bindings. In particular, in the recursive +call to spec_imorpts in spec_import, we must include the dict-binds +from the parent. Lacking this caused #17151, a really nasty bug. + +Here is what happened. +* Class struture: + Source is a superclass of Mut + Index is a superclass of Source + +* We started with these dict binds + dSource = $fSourcePix @Int $fIndexInt + dIndex = sc_sel dSource + dMut = $fMutPix @Int dIndex + and these calls to specialise + $fMutPix @Int dIndex + $fSourcePix @Int $fIndexInt + +* We specialised the call ($fMutPix @Int dIndex) + ==> new call ($fSourcePix @Int dIndex) + (because Source is a superclass of Mut) + +* We specialised ($fSourcePix @Int dIndex) + ==> produces specialised dict $s$fSourcePix, + a record with dIndex as a field + plus RULE forall d. ($fSourcePix @Int d) = $s$fSourcePix + *** This is the bogus step *** + +* Now we decide not to specialise the call + $fSourcePix @Int $fIndexInt + because we alredy have a RULE that matches it + +* Finally the simplifer rewrites + dSource = $fSourcePix @Int $fIndexInt + ==> dSource = $s$fSourcePix + +Disaster. Now we have + +Rewrite dSource's RHS to $s$fSourcePix Disaster + dSource = $s$fSourcePix + dIndex = sc_sel dSource + $s$fSourcePix = MkSource dIndex ... + +Solution: filterCalls should have stopped the bogus step, +by seeing that dIndex transitively uses $fSourcePix. But +it can only do that if it sees all the dict_binds. Wow. + +-------------- +Here's another example (#13429). Suppose we have + class Monoid v => C v a where ... + +We start with a call + f @ [Integer] @ Integer $fC[]Integer + +Specialising call to 'f' gives dict bindings + $dMonoid_1 :: Monoid [Integer] + $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer + + $dC_1 :: C [Integer] (Node [Integer] Integer) + $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1 + +...plus a recursive call to + f @ [Integer] @ (Node [Integer] Integer) $dC_1 + +Specialising that call gives + $dMonoid_2 :: Monoid [Integer] + $dMonoid_2 = M.$p1C @ [Integer] $dC_1 + + $dC_2 :: C [Integer] (Node [Integer] Integer) + $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2 + +Now we have two calls to the imported function + M.$fCvNode :: Monoid v => C v a + M.$fCvNode @v @a m = C m some_fun + +But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2) +for specialisation, else we get: + + $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1 + $dMonoid_2 = M.$p1C @ [Integer] $dC_1 + $s$fCvNode = C $dMonoid_2 ... + RULE M.$fCvNode [Integer] _ _ = $s$fCvNode + +Now use the rule to rewrite the call in the RHS of $dC_1 +and we get a loop! + + +Note [specImport call stack] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When specialising an imports function 'f', we may get new calls +of an imported fuction 'g', which we want to specialise in turn, +and similarly specialising 'g' might expose a new call to 'h'. + +We track the stack of enclosing functions. So when specialising 'h' we +haev a specImport call stack of [g,f]. We do this for two reasons: +* Note [Warning about missed specialisations] +* Note [Avoiding recursive specialisation] + +Note [Warning about missed specialisations] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Suppose * In module Lib, you carefully mark a function 'foo' INLINABLE * Import Lib(foo) into another module M @@ -807,6 +931,16 @@ is what Opt_WarnAllMissedSpecs does. ToDo: warn about missed opportunities for local functions. +Note [Avoiding recursive specialisation] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When we specialise 'f' we may find new overloaded calls to 'g', 'h' in +'f's RHS. So we want to specialise g,h. But we don't want to +specialise f any more! It's possible that f's RHS might have a +recursive yet-more-specialised call, so we'd diverge in that case. +And if the call is to the same type, one specialisation is enough. +Avoiding this recursive specialisation loop is one reason for the +'callers' stack passed to specImports and specImport. + Note [Specialise imported INLINABLE things] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ What imported functions do we specialise? The basic set is @@ -842,15 +976,6 @@ make sure that f_spec is recursive. Easiest thing is to make all the specialisations for imported bindings recursive. -Note [Avoiding recursive specialisation] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When we specialise 'f' we may find new overloaded calls to 'g', 'h' in -'f's RHS. So we want to specialise g,h. But we don't want to -specialise f any more! It's possible that f's RHS might have a -recursive yet-more-specialised call, so we'd diverge in that case. -And if the call is to the same type, one specialisation is enough. -Avoiding this recursive specialisation loop is the reason for the -'done' VarSet passed to specImports and specImport. ************************************************************************ * * @@ -1637,13 +1762,11 @@ This translates to None of these definitions is recursive. What happened was that we generated a specialisation: - RULE forall d. dfun T d = dT :: C [T] dT = (MkD a d (meth d)) [T/a, d1/d] = MkD T d1 (meth d1) But now we use the RULE on the RHS of d2, to get - d2 = dT = MkD d1 (meth d1) d1 = $p1 d2 @@ -1661,46 +1784,6 @@ Solution: This is done by 'filterCalls' -------------- -Here's another example, this time for an imported dfun, so the call -to filterCalls is in specImports (#13429). Suppose we have - class Monoid v => C v a where ... - -We start with a call - f @ [Integer] @ Integer $fC[]Integer - -Specialising call to 'f' gives dict bindings - $dMonoid_1 :: Monoid [Integer] - $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer - - $dC_1 :: C [Integer] (Node [Integer] Integer) - $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1 - -...plus a recursive call to - f @ [Integer] @ (Node [Integer] Integer) $dC_1 - -Specialising that call gives - $dMonoid_2 :: Monoid [Integer] - $dMonoid_2 = M.$p1C @ [Integer] $dC_1 - - $dC_2 :: C [Integer] (Node [Integer] Integer) - $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2 - -Now we have two calls to the imported function - M.$fCvNode :: Monoid v => C v a - M.$fCvNode @v @a m = C m some_fun - -But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2) -for specialisation, else we get: - - $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1 - $dMonoid_2 = M.$p1C @ [Integer] $dC_1 - $s$fCvNode = C $dMonoid_2 ... - RULE M.$fCvNode [Integer] _ _ = $s$fCvNode - -Now use the rule to rewrite the call in the RHS of $dC_1 -and we get a loop! - --------------- Here's yet another example class C a where { foo,bar :: [a] -> [a] } diff --git a/testsuite/tests/simplCore/should_run/T17151.hs b/testsuite/tests/simplCore/should_run/T17151.hs new file mode 100644 index 0000000000..20c31ea11f --- /dev/null +++ b/testsuite/tests/simplCore/should_run/T17151.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FlexibleContexts #-} +module Main where + +import T17151a + +main :: IO () +main = do + let ys :: Array P Int Int + ys = computeS (makeArray D 1 (const 5)) + applyStencil :: + (Source P ix Int, Load D ix Int) + => Stencil ix Int Int + -> Array P ix Int + -> Array P ix Int + applyStencil s = computeS . mapStencil s + print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0) + print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0) diff --git a/testsuite/tests/simplCore/should_run/T17151.stdout b/testsuite/tests/simplCore/should_run/T17151.stdout new file mode 100644 index 0000000000..7a165dae5c --- /dev/null +++ b/testsuite/tests/simplCore/should_run/T17151.stdout @@ -0,0 +1,2 @@ +55 +55 diff --git a/testsuite/tests/simplCore/should_run/T17151a.hs b/testsuite/tests/simplCore/should_run/T17151a.hs new file mode 100644 index 0000000000..b2392e242e --- /dev/null +++ b/testsuite/tests/simplCore/should_run/T17151a.hs @@ -0,0 +1,205 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +module T17151a + ( computeS + , Stencil + , P(..) + , D(..) + , makeConvolutionStencilFromKernel + , mapStencil + , Array + , Construct(..) + , Source(..) + , Load(..) + , Mutable(..) + ) where + +import Control.Monad.ST +import Data.Functor.Identity +import GHC.STRef +import GHC.ST +import GHC.Exts +import Unsafe.Coerce +import Data.Kind + +---- Hacked up stuff to simulate primitive package +class Prim e where + indexByteArray :: ByteArray -> Int -> e + sizeOf :: e ->Int +instance Prim Int where + indexByteArray _ _ = 55 + sizeOf _ = 99 + +data ByteArray = BA +type MutableByteArray s = STRef s Int + +class Monad m => PrimMonad m where + type PrimState m + primitive :: (State# (PrimState m) -> (# State# (PrimState m), a #)) -> m a +instance PrimMonad (ST s) where + type PrimState (ST s) = s + primitive = ST + +unsafeFreezeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray +unsafeFreezeByteArray a = return (unsafeCoerce a) + +newByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m)) +newByteArray (I# n#) + = primitive (\s# -> case newMutVar# 33 s# of + (# s'#, arr# #) -> (# s'#, STRef arr# #)) + +writeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> e -> m () +writeByteArray _ _ _ = return () + +----- End of hacked up stuff + +-------------- +newtype Stencil ix e a = + Stencil ((ix -> e) -> ix -> a) + +mapStencil :: Source r ix e => Stencil ix e a -> Array r ix e -> Array D ix a +mapStencil (Stencil stencilF) arr = DArray (size arr) (stencilF (unsafeIndex arr)) +{-# INLINE mapStencil #-} + +makeConvolutionStencilFromKernel + :: (Source r ix e, Num e) + => Array r ix e + -> Stencil ix e e +makeConvolutionStencilFromKernel arr = Stencil stencil + where + sz = size arr + sCenter = liftIndex (`quot` 2) sz + stencil getVal ix = + runIdentity $ + loopM 0 (< totalElem sz) (+ 1) 0 $ \i a -> + pure $ accum a (fromLinearIndex sz i) (unsafeLinearIndex arr i) + where + ixOff = liftIndex2 (+) ix sCenter + accum acc kIx kVal = getVal (liftIndex2 (-) ixOff kIx) * kVal + acc + {-# INLINE accum #-} + {-# INLINE stencil #-} +{-# INLINE makeConvolutionStencilFromKernel #-} + + +computeS :: (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e +computeS arr = runST $ do + marr <- unsafeNew (size arr) + unsafeLoadIntoS marr arr + unsafeFreeze marr +{-# INLINE computeS #-} + + +data D = D deriving Show + +data instance Array D ix e = DArray{dSize :: ix, + dIndex :: ix -> e} + +instance Index ix => Construct D ix e where + makeArray _ = DArray + {-# INLINE makeArray #-} + +instance Index ix => Source D ix e where + unsafeIndex = dIndex + {-# INLINE unsafeIndex #-} + +instance Index ix => Load D ix e where + size = dSize + {-# INLINE size #-} + loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr) + {-# INLINE loadArrayM #-} + + +data P = P deriving Show + +data instance Array P ix e = PArray ix ByteArray + +instance (Prim e, Index ix) => Construct P ix e where + makeArray _ sz f = computeS (makeArray D sz f) + {-# INLINE makeArray #-} + +instance (Prim e, Index ix) => Source P ix e where + unsafeIndex (PArray sz a) = indexByteArray a . toLinearIndex sz + {-# INLINE unsafeIndex #-} + +instance (Prim e, Index ix) => Mutable P ix e where + data MArray s P ix e = MPArray ix (MutableByteArray s) + unsafeFreeze (MPArray sz a) = PArray sz <$> unsafeFreezeByteArray a + {-# INLINE unsafeFreeze #-} + unsafeNew sz = MPArray sz <$> newByteArray (totalElem sz * eSize) + where + eSize = sizeOf (undefined :: e) + {-# INLINE unsafeNew #-} + unsafeLinearWrite (MPArray _ ma) = writeByteArray ma + {-# INLINE unsafeLinearWrite #-} + + +instance (Prim e, Index ix) => Load P ix e where + size (PArray sz _) = sz + {-# INLINE size #-} + loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr) + {-# INLINE loadArrayM #-} + + +unsafeLinearIndex :: Source r ix e => Array r ix e -> Int -> e +unsafeLinearIndex arr = unsafeIndex arr . fromLinearIndex (size arr) +{-# INLINE unsafeLinearIndex #-} + + +loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a +loopM init' condition increment initAcc f = go init' initAcc + where + go step acc + | condition step = f step acc >>= go (increment step) + | otherwise = return acc +{-# INLINE loopM #-} + +splitLinearlyWith_ :: + Monad m => Int -> (Int -> b) -> (Int -> b -> m ()) -> m () +splitLinearlyWith_ totalLength index write = + loopM 0 (< totalLength) (+1) () $ \i () -> write i (index i) +{-# INLINE splitLinearlyWith_ #-} + + +data family Array r ix e :: Type + +class Index ix => Construct r ix e where + makeArray :: r -> ix -> (ix -> e) -> Array r ix e + +class Load r ix e => Source r ix e where + unsafeIndex :: Array r ix e -> ix -> e + +class Index ix => Load r ix e where + size :: Array r ix e -> ix + loadArrayM :: Monad m => Array r ix e -> (Int -> e -> m ()) -> m () + unsafeLoadIntoS :: + (Mutable r' ix e, PrimMonad m) => MArray (PrimState m) r' ix e -> Array r ix e -> m () + unsafeLoadIntoS marr arr = loadArrayM arr (unsafeLinearWrite marr) + {-# INLINE unsafeLoadIntoS #-} + +class (Construct r ix e, Source r ix e) => Mutable r ix e where + data MArray s r ix e :: Type + unsafeFreeze :: PrimMonad m => MArray (PrimState m) r ix e -> m (Array r ix e) + unsafeNew :: PrimMonad m => ix -> m (MArray (PrimState m) r ix e) + unsafeLinearWrite :: PrimMonad m => MArray (PrimState m) r ix e -> Int -> e -> m () + + +class (Eq ix, Ord ix, Show ix) => + Index ix + where + totalElem :: ix -> Int + liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix + liftIndex :: (Int -> Int) -> ix -> ix + toLinearIndex :: ix -> ix -> Int + fromLinearIndex :: ix -> Int -> ix + +instance Index Int where + totalElem = id + toLinearIndex _ = id + fromLinearIndex _ = id + liftIndex f = f + liftIndex2 f = f diff --git a/testsuite/tests/simplCore/should_run/all.T b/testsuite/tests/simplCore/should_run/all.T index 3a61665cdf..d101bff84b 100644 --- a/testsuite/tests/simplCore/should_run/all.T +++ b/testsuite/tests/simplCore/should_run/all.T @@ -92,3 +92,4 @@ test('T15840', normal, compile_and_run, ['']) test('T15840a', normal, compile_and_run, ['']) test('T16066', exit_code(1), compile_and_run, ['-O1']) test('T17206', exit_code(1), compile_and_run, ['']) +test('T17151', [], multimod_compile_and_run, ['T17151', '']) |