summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2020-04-02 13:42:51 +0100
committerMarge Bot <ben+marge-bot@smart-cactus.org>2020-04-06 13:16:44 -0400
commitcec2c71fe91c88649628c6e83416533b816b86a5 (patch)
tree065b3a34275f9605e01fd10578fa16bd72f8ad37
parentdcfe29c8520244764146c7a5f336be1f9700db6c (diff)
downloadhaskell-cec2c71fe91c88649628c6e83416533b816b86a5.tar.gz
Fix an tricky specialiser loop
Issue #17151 was a very tricky example of a bug in which the specialiser accidentally constructs a recurive dictionary, so that everything turns into bottom. I have fixed variants of this bug at least twice before: see Note [Avoiding loops]. It was a bit of a struggle to isolate the problem, greatly aided by the work that Alexey Kuleshevich did in distilling a test case. Once I'd understood the problem, it was not difficult to fix, though it did lead me a bit of refactoring in specImports.
-rw-r--r--compiler/GHC/Core/Op/Specialise.hs329
-rw-r--r--testsuite/tests/simplCore/should_run/T17151.hs18
-rw-r--r--testsuite/tests/simplCore/should_run/T17151.stdout2
-rw-r--r--testsuite/tests/simplCore/should_run/T17151a.hs205
-rw-r--r--testsuite/tests/simplCore/should_run/all.T1
5 files changed, 432 insertions, 123 deletions
diff --git a/compiler/GHC/Core/Op/Specialise.hs b/compiler/GHC/Core/Op/Specialise.hs
index d7e1ebe654..ba16ca4347 100644
--- a/compiler/GHC/Core/Op/Specialise.hs
+++ b/compiler/GHC/Core/Op/Specialise.hs
@@ -589,19 +589,11 @@ specProgram guts@(ModGuts { mg_module = this_mod
-- Specialise the bindings of this module
; (binds', uds) <- runSpecM dflags this_mod (go binds)
- -- Specialise imported functions
- ; hpt_rules <- getRuleBase
- ; let rule_base = extendRuleBaseList hpt_rules local_rules
- ; (new_rules, spec_binds) <- specImports dflags this_mod top_env emptyVarSet
- [] rule_base uds
-
- ; let final_binds
- | null spec_binds = binds'
- | otherwise = Rec (flattenBinds spec_binds) : binds'
- -- Note [Glom the bindings if imported functions are specialised]
+ ; (spec_rules, spec_binds) <- specImports dflags this_mod top_env
+ local_rules uds
- ; return (guts { mg_binds = final_binds
- , mg_rules = new_rules ++ local_rules }) }
+ ; return (guts { mg_binds = spec_binds ++ binds'
+ , mg_rules = spec_rules ++ local_rules }) }
where
-- We need to start with a Subst that knows all the things
-- that are in scope, so that the substitution engine doesn't
@@ -645,72 +637,93 @@ See #10491
* *
********************************************************************* -}
--- | Specialise a set of calls to imported bindings
-specImports :: DynFlags
- -> Module
- -> SpecEnv -- Passed in so that all top-level Ids are in scope
- -> VarSet -- Don't specialise these ones
- -- See Note [Avoiding recursive specialisation]
- -> [Id] -- Stack of imported functions being specialised
- -> RuleBase -- Rules from this module and the home package
- -- (but not external packages, which can change)
- -> UsageDetails -- Calls for imported things, and floating bindings
- -> CoreM ( [CoreRule] -- New rules
- , [CoreBind] ) -- Specialised bindings
- -- See Note [Wrapping bindings returned by specImports]
-specImports dflags this_mod top_env done callers rule_base
+specImports :: DynFlags -> Module -> SpecEnv
+ -> [CoreRule]
+ -> UsageDetails
+ -> CoreM ([CoreRule], [CoreBind])
+specImports dflags this_mod top_env local_rules
(MkUD { ud_binds = dict_binds, ud_calls = calls })
- -- See Note [Disabling cross-module specialisation]
| not $ gopt Opt_CrossModuleSpecialise dflags
- = return ([], [])
+ -- See Note [Disabling cross-module specialisation]
+ = return ([], wrapDictBinds dict_binds [])
| otherwise
- = do { let import_calls = dVarEnvElts calls
- ; (rules, spec_binds) <- go rule_base import_calls
+ = do { hpt_rules <- getRuleBase
+ ; let rule_base = extendRuleBaseList hpt_rules local_rules
+
+ ; (spec_rules, spec_binds) <- spec_imports dflags this_mod top_env
+ [] rule_base
+ dict_binds calls
-- Don't forget to wrap the specialized bindings with
-- bindings for the needed dictionaries.
-- See Note [Wrap bindings returned by specImports]
- ; let spec_binds' = wrapDictBinds dict_binds spec_binds
+ -- and Note [Glom the bindings if imported functions are specialised]
+ ; let final_binds
+ | null spec_binds = wrapDictBinds dict_binds []
+ | otherwise = [Rec $ flattenBinds $
+ wrapDictBinds dict_binds spec_binds]
+
+ ; return (spec_rules, final_binds)
+ }
+
+-- | Specialise a set of calls to imported bindings
+spec_imports :: DynFlags
+ -> Module
+ -> SpecEnv -- Passed in so that all top-level Ids are in scope
+ -> [Id] -- Stack of imported functions being specialised
+ -- See Note [specImport call stack]
+ -> RuleBase -- Rules from this module and the home package
+ -- (but not external packages, which can change)
+ -> Bag DictBind -- Dict bindings, used /only/ for filterCalls
+ -- See Note [Avoiding loops in specImports]
+ -> CallDetails -- Calls for imported things
+ -> CoreM ( [CoreRule] -- New rules
+ , [CoreBind] ) -- Specialised bindings
+spec_imports dflags this_mod top_env
+ callers rule_base dict_binds calls
+ = do { let import_calls = dVarEnvElts calls
+ -- ; debugTraceMsg (text "specImports {" <+>
+ -- vcat [ text "calls:" <+> ppr import_calls
+ -- , text "dict_binds:" <+> ppr dict_binds ])
+ ; (rules, spec_binds) <- go rule_base import_calls
+ -- ; debugTraceMsg (text "End specImports }" <+> ppr import_calls)
- ; return (rules, spec_binds') }
+ ; return (rules, spec_binds) }
where
go :: RuleBase -> [CallInfoSet] -> CoreM ([CoreRule], [CoreBind])
go _ [] = return ([], [])
- go rb (cis@(CIS fn _) : other_calls)
- = do { let ok_calls = filterCalls cis dict_binds
- -- Drop calls that (directly or indirectly) refer to fn
- -- See Note [Avoiding loops]
--- ; debugTraceMsg (text "specImport" <+> vcat [ ppr fn
--- , text "calls" <+> ppr cis
--- , text "ud_binds =" <+> ppr dict_binds
--- , text "dump set =" <+> ppr dump_set
--- , text "filtered calls =" <+> ppr ok_calls ])
- ; (rules1, spec_binds1) <- specImport dflags this_mod top_env
- done callers rb fn ok_calls
+ go rb (cis : other_calls)
+ = do { -- debugTraceMsg (text "specImport {" <+> ppr cis)
+ ; (rules1, spec_binds1) <- spec_import dflags this_mod top_env
+ callers rb dict_binds cis
+ -- ; debugTraceMsg (text "specImport }" <+> ppr cis)
; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls
; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
-specImport :: DynFlags
- -> Module
- -> SpecEnv -- Passed in so that all top-level Ids are in scope
- -> VarSet -- Don't specialise these
- -- See Note [Avoiding recursive specialisation]
- -> [Id] -- Stack of imported functions being specialised
- -> RuleBase -- Rules from this module
- -> Id -> [CallInfo] -- Imported function and calls for it
- -> CoreM ( [CoreRule] -- New rules
- , [CoreBind] ) -- Specialised bindings
-specImport dflags this_mod top_env done callers rb fn calls_for_fn
- | fn `elemVarSet` done
+spec_import :: DynFlags
+ -> Module
+ -> SpecEnv -- Passed in so that all top-level Ids are in scope
+ -> [Id] -- Stack of imported functions being specialised
+ -- See Note [specImport call stack]
+ -> RuleBase -- Rules from this module
+ -> Bag DictBind -- Dict bindings, used /only/ for filterCalls
+ -- See Note [Avoiding loops in specImports]
+ -> CallInfoSet -- Imported function and calls for it
+ -> CoreM ( [CoreRule] -- New rules
+ , [CoreBind] ) -- Specialised bindings
+spec_import dflags this_mod top_env callers
+ rb dict_binds cis@(CIS fn _)
+ | isIn "specImport" fn callers
= return ([], []) -- No warning. This actually happens all the time
-- when specialising a recursive function, because
-- the RHS of the specialised function contains a recursive
-- call to the original function
- | null calls_for_fn -- We filtered out all the calls in deleteCallsMentioning
- = return ([], [])
+ | null good_calls
+ = do { -- debugTraceMsg (text "specImport:no valid calls")
+ ; return ([], []) }
| wantSpecImport dflags unfolding
, Just rhs <- maybeUnfoldingTemplate unfolding
@@ -723,32 +736,37 @@ specImport dflags this_mod top_env done callers rb fn calls_for_fn
; let full_rb = unionRuleBase rb (eps_rule_base eps)
rules_for_fn = getRules (RuleEnv full_rb vis_orphs) fn
- ; (rules1, spec_pairs, uds)
- <- -- pprTrace "specImport1" (vcat [ppr fn, ppr calls_for_fn, ppr rhs]) $
- runSpecM dflags this_mod $
- specCalls (Just this_mod) top_env rules_for_fn calls_for_fn fn rhs
+ ; (rules1, spec_pairs, MkUD { ud_binds = dict_binds1, ud_calls = new_calls })
+ <- do { -- debugTraceMsg (text "specImport1" <+> vcat [ppr fn, ppr good_calls, ppr rhs])
+ ; runSpecM dflags this_mod $
+ specCalls (Just this_mod) top_env rules_for_fn good_calls fn rhs }
; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs]
-- After the rules kick in we may get recursion, but
-- we rely on a global GlomBinds to sort that out later
-- See Note [Glom the bindings if imported functions are specialised]
-- Now specialise any cascaded calls
- ; (rules2, spec_binds2) <- -- pprTrace "specImport 2" (ppr fn $$ ppr rules1 $$ ppr spec_binds1) $
- specImports dflags this_mod top_env
- (extendVarSet done fn)
- (fn:callers)
- (extendRuleBaseList rb rules1)
- uds
+ -- ; debugTraceMsg (text "specImport 2" <+> (ppr fn $$ ppr rules1 $$ ppr spec_binds1))
+ ; (rules2, spec_binds2) <- spec_imports dflags this_mod top_env
+ (fn:callers)
+ (extendRuleBaseList rb rules1)
+ (dict_binds `unionBags` dict_binds1)
+ new_calls
- ; let final_binds = spec_binds2 ++ spec_binds1
+ ; let final_binds = wrapDictBinds dict_binds1 $
+ spec_binds2 ++ spec_binds1
; return (rules2 ++ rules1, final_binds) }
- | otherwise = do { tryWarnMissingSpecs dflags callers fn calls_for_fn
- ; return ([], [])}
+ | otherwise
+ = do { tryWarnMissingSpecs dflags callers fn good_calls
+ ; return ([], [])}
where
unfolding = realIdUnfolding fn -- We want to see the unfolding even for loop breakers
+ good_calls = filterCalls cis dict_binds
+ -- SUPER IMPORTANT! Drop calls that (directly or indirectly) refer to fn
+ -- See Note [Avoiding loops in specImports]
-- | Returns whether or not to show a missed-spec warning.
-- If -Wall-missed-specializations is on, show the warning.
@@ -790,8 +808,114 @@ wantSpecImport dflags unf
-- inside it that we want to specialise
| otherwise -> False -- Stable, not INLINE, hence INLINABLE
-{- Note [Warning about missed specialisations]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Avoiding loops in specImports]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must take great care when specialising instance declarations
+(functions like $fOrdList) lest we accidentally build a recursive
+dictionary. See Note [Avoiding loops].
+
+The basic strategy of Note [Avoiding loops] is to use filterCalls
+to discard loopy specialisations. But to do that we must ensure
+that the in-scope dict-binds (passed to filterCalls) contains
+all the needed dictionary bindings. In particular, in the recursive
+call to spec_imorpts in spec_import, we must include the dict-binds
+from the parent. Lacking this caused #17151, a really nasty bug.
+
+Here is what happened.
+* Class struture:
+ Source is a superclass of Mut
+ Index is a superclass of Source
+
+* We started with these dict binds
+ dSource = $fSourcePix @Int $fIndexInt
+ dIndex = sc_sel dSource
+ dMut = $fMutPix @Int dIndex
+ and these calls to specialise
+ $fMutPix @Int dIndex
+ $fSourcePix @Int $fIndexInt
+
+* We specialised the call ($fMutPix @Int dIndex)
+ ==> new call ($fSourcePix @Int dIndex)
+ (because Source is a superclass of Mut)
+
+* We specialised ($fSourcePix @Int dIndex)
+ ==> produces specialised dict $s$fSourcePix,
+ a record with dIndex as a field
+ plus RULE forall d. ($fSourcePix @Int d) = $s$fSourcePix
+ *** This is the bogus step ***
+
+* Now we decide not to specialise the call
+ $fSourcePix @Int $fIndexInt
+ because we alredy have a RULE that matches it
+
+* Finally the simplifer rewrites
+ dSource = $fSourcePix @Int $fIndexInt
+ ==> dSource = $s$fSourcePix
+
+Disaster. Now we have
+
+Rewrite dSource's RHS to $s$fSourcePix Disaster
+ dSource = $s$fSourcePix
+ dIndex = sc_sel dSource
+ $s$fSourcePix = MkSource dIndex ...
+
+Solution: filterCalls should have stopped the bogus step,
+by seeing that dIndex transitively uses $fSourcePix. But
+it can only do that if it sees all the dict_binds. Wow.
+
+--------------
+Here's another example (#13429). Suppose we have
+ class Monoid v => C v a where ...
+
+We start with a call
+ f @ [Integer] @ Integer $fC[]Integer
+
+Specialising call to 'f' gives dict bindings
+ $dMonoid_1 :: Monoid [Integer]
+ $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
+
+ $dC_1 :: C [Integer] (Node [Integer] Integer)
+ $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+
+...plus a recursive call to
+ f @ [Integer] @ (Node [Integer] Integer) $dC_1
+
+Specialising that call gives
+ $dMonoid_2 :: Monoid [Integer]
+ $dMonoid_2 = M.$p1C @ [Integer] $dC_1
+
+ $dC_2 :: C [Integer] (Node [Integer] Integer)
+ $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
+
+Now we have two calls to the imported function
+ M.$fCvNode :: Monoid v => C v a
+ M.$fCvNode @v @a m = C m some_fun
+
+But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
+for specialisation, else we get:
+
+ $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
+ $dMonoid_2 = M.$p1C @ [Integer] $dC_1
+ $s$fCvNode = C $dMonoid_2 ...
+ RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
+
+Now use the rule to rewrite the call in the RHS of $dC_1
+and we get a loop!
+
+
+Note [specImport call stack]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When specialising an imports function 'f', we may get new calls
+of an imported fuction 'g', which we want to specialise in turn,
+and similarly specialising 'g' might expose a new call to 'h'.
+
+We track the stack of enclosing functions. So when specialising 'h' we
+haev a specImport call stack of [g,f]. We do this for two reasons:
+* Note [Warning about missed specialisations]
+* Note [Avoiding recursive specialisation]
+
+Note [Warning about missed specialisations]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose
* In module Lib, you carefully mark a function 'foo' INLINABLE
* Import Lib(foo) into another module M
@@ -807,6 +931,16 @@ is what Opt_WarnAllMissedSpecs does.
ToDo: warn about missed opportunities for local functions.
+Note [Avoiding recursive specialisation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
+'f's RHS. So we want to specialise g,h. But we don't want to
+specialise f any more! It's possible that f's RHS might have a
+recursive yet-more-specialised call, so we'd diverge in that case.
+And if the call is to the same type, one specialisation is enough.
+Avoiding this recursive specialisation loop is one reason for the
+'callers' stack passed to specImports and specImport.
+
Note [Specialise imported INLINABLE things]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
What imported functions do we specialise? The basic set is
@@ -842,15 +976,6 @@ make sure that f_spec is recursive. Easiest thing is to make all
the specialisations for imported bindings recursive.
-Note [Avoiding recursive specialisation]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
-'f's RHS. So we want to specialise g,h. But we don't want to
-specialise f any more! It's possible that f's RHS might have a
-recursive yet-more-specialised call, so we'd diverge in that case.
-And if the call is to the same type, one specialisation is enough.
-Avoiding this recursive specialisation loop is the reason for the
-'done' VarSet passed to specImports and specImport.
************************************************************************
* *
@@ -1637,13 +1762,11 @@ This translates to
None of these definitions is recursive. What happened was that we
generated a specialisation:
-
RULE forall d. dfun T d = dT :: C [T]
dT = (MkD a d (meth d)) [T/a, d1/d]
= MkD T d1 (meth d1)
But now we use the RULE on the RHS of d2, to get
-
d2 = dT = MkD d1 (meth d1)
d1 = $p1 d2
@@ -1661,46 +1784,6 @@ Solution:
This is done by 'filterCalls'
--------------
-Here's another example, this time for an imported dfun, so the call
-to filterCalls is in specImports (#13429). Suppose we have
- class Monoid v => C v a where ...
-
-We start with a call
- f @ [Integer] @ Integer $fC[]Integer
-
-Specialising call to 'f' gives dict bindings
- $dMonoid_1 :: Monoid [Integer]
- $dMonoid_1 = M.$p1C @ [Integer] $fC[]Integer
-
- $dC_1 :: C [Integer] (Node [Integer] Integer)
- $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
-
-...plus a recursive call to
- f @ [Integer] @ (Node [Integer] Integer) $dC_1
-
-Specialising that call gives
- $dMonoid_2 :: Monoid [Integer]
- $dMonoid_2 = M.$p1C @ [Integer] $dC_1
-
- $dC_2 :: C [Integer] (Node [Integer] Integer)
- $dC_2 = M.$fCvNode @ [Integer] $dMonoid_2
-
-Now we have two calls to the imported function
- M.$fCvNode :: Monoid v => C v a
- M.$fCvNode @v @a m = C m some_fun
-
-But we must /not/ use the call (M.$fCvNode @ [Integer] $dMonoid_2)
-for specialisation, else we get:
-
- $dC_1 = M.$fCvNode @ [Integer] $dMonoid_1
- $dMonoid_2 = M.$p1C @ [Integer] $dC_1
- $s$fCvNode = C $dMonoid_2 ...
- RULE M.$fCvNode [Integer] _ _ = $s$fCvNode
-
-Now use the rule to rewrite the call in the RHS of $dC_1
-and we get a loop!
-
---------------
Here's yet another example
class C a where { foo,bar :: [a] -> [a] }
diff --git a/testsuite/tests/simplCore/should_run/T17151.hs b/testsuite/tests/simplCore/should_run/T17151.hs
new file mode 100644
index 0000000000..20c31ea11f
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/T17151.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE MonoLocalBinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+module Main where
+
+import T17151a
+
+main :: IO ()
+main = do
+ let ys :: Array P Int Int
+ ys = computeS (makeArray D 1 (const 5))
+ applyStencil ::
+ (Source P ix Int, Load D ix Int)
+ => Stencil ix Int Int
+ -> Array P ix Int
+ -> Array P ix Int
+ applyStencil s = computeS . mapStencil s
+ print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)
+ print (applyStencil (makeConvolutionStencilFromKernel ys) ys `unsafeIndex` 0)
diff --git a/testsuite/tests/simplCore/should_run/T17151.stdout b/testsuite/tests/simplCore/should_run/T17151.stdout
new file mode 100644
index 0000000000..7a165dae5c
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/T17151.stdout
@@ -0,0 +1,2 @@
+55
+55
diff --git a/testsuite/tests/simplCore/should_run/T17151a.hs b/testsuite/tests/simplCore/should_run/T17151a.hs
new file mode 100644
index 0000000000..b2392e242e
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/T17151a.hs
@@ -0,0 +1,205 @@
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE EmptyDataDecls #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+module T17151a
+ ( computeS
+ , Stencil
+ , P(..)
+ , D(..)
+ , makeConvolutionStencilFromKernel
+ , mapStencil
+ , Array
+ , Construct(..)
+ , Source(..)
+ , Load(..)
+ , Mutable(..)
+ ) where
+
+import Control.Monad.ST
+import Data.Functor.Identity
+import GHC.STRef
+import GHC.ST
+import GHC.Exts
+import Unsafe.Coerce
+import Data.Kind
+
+---- Hacked up stuff to simulate primitive package
+class Prim e where
+ indexByteArray :: ByteArray -> Int -> e
+ sizeOf :: e ->Int
+instance Prim Int where
+ indexByteArray _ _ = 55
+ sizeOf _ = 99
+
+data ByteArray = BA
+type MutableByteArray s = STRef s Int
+
+class Monad m => PrimMonad m where
+ type PrimState m
+ primitive :: (State# (PrimState m) -> (# State# (PrimState m), a #)) -> m a
+instance PrimMonad (ST s) where
+ type PrimState (ST s) = s
+ primitive = ST
+
+unsafeFreezeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> m ByteArray
+unsafeFreezeByteArray a = return (unsafeCoerce a)
+
+newByteArray :: PrimMonad m => Int -> m (MutableByteArray (PrimState m))
+newByteArray (I# n#)
+ = primitive (\s# -> case newMutVar# 33 s# of
+ (# s'#, arr# #) -> (# s'#, STRef arr# #))
+
+writeByteArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> e -> m ()
+writeByteArray _ _ _ = return ()
+
+----- End of hacked up stuff
+
+--------------
+newtype Stencil ix e a =
+ Stencil ((ix -> e) -> ix -> a)
+
+mapStencil :: Source r ix e => Stencil ix e a -> Array r ix e -> Array D ix a
+mapStencil (Stencil stencilF) arr = DArray (size arr) (stencilF (unsafeIndex arr))
+{-# INLINE mapStencil #-}
+
+makeConvolutionStencilFromKernel
+ :: (Source r ix e, Num e)
+ => Array r ix e
+ -> Stencil ix e e
+makeConvolutionStencilFromKernel arr = Stencil stencil
+ where
+ sz = size arr
+ sCenter = liftIndex (`quot` 2) sz
+ stencil getVal ix =
+ runIdentity $
+ loopM 0 (< totalElem sz) (+ 1) 0 $ \i a ->
+ pure $ accum a (fromLinearIndex sz i) (unsafeLinearIndex arr i)
+ where
+ ixOff = liftIndex2 (+) ix sCenter
+ accum acc kIx kVal = getVal (liftIndex2 (-) ixOff kIx) * kVal + acc
+ {-# INLINE accum #-}
+ {-# INLINE stencil #-}
+{-# INLINE makeConvolutionStencilFromKernel #-}
+
+
+computeS :: (Mutable r ix e, Load r' ix e) => Array r' ix e -> Array r ix e
+computeS arr = runST $ do
+ marr <- unsafeNew (size arr)
+ unsafeLoadIntoS marr arr
+ unsafeFreeze marr
+{-# INLINE computeS #-}
+
+
+data D = D deriving Show
+
+data instance Array D ix e = DArray{dSize :: ix,
+ dIndex :: ix -> e}
+
+instance Index ix => Construct D ix e where
+ makeArray _ = DArray
+ {-# INLINE makeArray #-}
+
+instance Index ix => Source D ix e where
+ unsafeIndex = dIndex
+ {-# INLINE unsafeIndex #-}
+
+instance Index ix => Load D ix e where
+ size = dSize
+ {-# INLINE size #-}
+ loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+ {-# INLINE loadArrayM #-}
+
+
+data P = P deriving Show
+
+data instance Array P ix e = PArray ix ByteArray
+
+instance (Prim e, Index ix) => Construct P ix e where
+ makeArray _ sz f = computeS (makeArray D sz f)
+ {-# INLINE makeArray #-}
+
+instance (Prim e, Index ix) => Source P ix e where
+ unsafeIndex (PArray sz a) = indexByteArray a . toLinearIndex sz
+ {-# INLINE unsafeIndex #-}
+
+instance (Prim e, Index ix) => Mutable P ix e where
+ data MArray s P ix e = MPArray ix (MutableByteArray s)
+ unsafeFreeze (MPArray sz a) = PArray sz <$> unsafeFreezeByteArray a
+ {-# INLINE unsafeFreeze #-}
+ unsafeNew sz = MPArray sz <$> newByteArray (totalElem sz * eSize)
+ where
+ eSize = sizeOf (undefined :: e)
+ {-# INLINE unsafeNew #-}
+ unsafeLinearWrite (MPArray _ ma) = writeByteArray ma
+ {-# INLINE unsafeLinearWrite #-}
+
+
+instance (Prim e, Index ix) => Load P ix e where
+ size (PArray sz _) = sz
+ {-# INLINE size #-}
+ loadArrayM arr = splitLinearlyWith_ (totalElem (size arr)) (unsafeLinearIndex arr)
+ {-# INLINE loadArrayM #-}
+
+
+unsafeLinearIndex :: Source r ix e => Array r ix e -> Int -> e
+unsafeLinearIndex arr = unsafeIndex arr . fromLinearIndex (size arr)
+{-# INLINE unsafeLinearIndex #-}
+
+
+loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
+loopM init' condition increment initAcc f = go init' initAcc
+ where
+ go step acc
+ | condition step = f step acc >>= go (increment step)
+ | otherwise = return acc
+{-# INLINE loopM #-}
+
+splitLinearlyWith_ ::
+ Monad m => Int -> (Int -> b) -> (Int -> b -> m ()) -> m ()
+splitLinearlyWith_ totalLength index write =
+ loopM 0 (< totalLength) (+1) () $ \i () -> write i (index i)
+{-# INLINE splitLinearlyWith_ #-}
+
+
+data family Array r ix e :: Type
+
+class Index ix => Construct r ix e where
+ makeArray :: r -> ix -> (ix -> e) -> Array r ix e
+
+class Load r ix e => Source r ix e where
+ unsafeIndex :: Array r ix e -> ix -> e
+
+class Index ix => Load r ix e where
+ size :: Array r ix e -> ix
+ loadArrayM :: Monad m => Array r ix e -> (Int -> e -> m ()) -> m ()
+ unsafeLoadIntoS ::
+ (Mutable r' ix e, PrimMonad m) => MArray (PrimState m) r' ix e -> Array r ix e -> m ()
+ unsafeLoadIntoS marr arr = loadArrayM arr (unsafeLinearWrite marr)
+ {-# INLINE unsafeLoadIntoS #-}
+
+class (Construct r ix e, Source r ix e) => Mutable r ix e where
+ data MArray s r ix e :: Type
+ unsafeFreeze :: PrimMonad m => MArray (PrimState m) r ix e -> m (Array r ix e)
+ unsafeNew :: PrimMonad m => ix -> m (MArray (PrimState m) r ix e)
+ unsafeLinearWrite :: PrimMonad m => MArray (PrimState m) r ix e -> Int -> e -> m ()
+
+
+class (Eq ix, Ord ix, Show ix) =>
+ Index ix
+ where
+ totalElem :: ix -> Int
+ liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix
+ liftIndex :: (Int -> Int) -> ix -> ix
+ toLinearIndex :: ix -> ix -> Int
+ fromLinearIndex :: ix -> Int -> ix
+
+instance Index Int where
+ totalElem = id
+ toLinearIndex _ = id
+ fromLinearIndex _ = id
+ liftIndex f = f
+ liftIndex2 f = f
diff --git a/testsuite/tests/simplCore/should_run/all.T b/testsuite/tests/simplCore/should_run/all.T
index 3a61665cdf..d101bff84b 100644
--- a/testsuite/tests/simplCore/should_run/all.T
+++ b/testsuite/tests/simplCore/should_run/all.T
@@ -92,3 +92,4 @@ test('T15840', normal, compile_and_run, [''])
test('T15840a', normal, compile_and_run, [''])
test('T16066', exit_code(1), compile_and_run, ['-O1'])
test('T17206', exit_code(1), compile_and_run, [''])
+test('T17151', [], multimod_compile_and_run, ['T17151', ''])