summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- compiler/GHC/Core/Opt/Arity.hs 6
-rw-r--r-- compiler/GHC/Core/Opt/Simplify/Iteration.hs 87
-rw-r--r-- testsuite/tests/simplCore/should_compile/T22491.hs 319
-rw-r--r-- testsuite/tests/simplCore/should_compile/all.T 1
4 files changed, 379 insertions, 34 deletions
diff --git a/compiler/GHC/Core/Opt/Arity.hs b/compiler/GHC/Core/Opt/Arity.hs
index fbbcf1c2ad..832fba354c 100644
--- a/compiler/GHC/Core/Opt/Arity.hs
+++ b/compiler/GHC/Core/Opt/Arity.hs
@@ -3103,9 +3103,13 @@ etaBodyForJoinPoint need_args body
| Just (tv, res_ty) <- splitForAllTyCoVar_maybe ty
, let (subst', tv') = substVarBndr subst tv
= go (n-1) res_ty subst' (tv' : rev_bs) (e `App` varToCoreExpr tv')
+ -- The varToCoreExpr is important: `tv` might be a coercion variable
+
| Just (_, mult, arg_ty, res_ty) <- splitFunTy_maybe ty
, let (subst', b) = freshEtaId n subst (Scaled mult arg_ty)
- = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)
+ = go (n-1) res_ty subst' (b : rev_bs) (e `App` varToCoreExpr b)
+ -- The varToCoreExpr is important: `b` might be a coercion variable
+
| otherwise
= pprPanic "etaBodyForJoinPoint" $ int need_args $$
ppr body $$ ppr (exprType body)
diff --git a/compiler/GHC/Core/Opt/Simplify/Iteration.hs b/compiler/GHC/Core/Opt/Simplify/Iteration.hs
index f8ed00c119..36c969224c 100644
--- a/compiler/GHC/Core/Opt/Simplify/Iteration.hs
+++ b/compiler/GHC/Core/Opt/Simplify/Iteration.hs
@@ -1227,11 +1227,23 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont
do { ty' <- simplType env ty
; simplExprF (extendTvSubst env bndr ty') body cont }
+ | Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
+ -- Because of the let-can-float invariant, it's ok to
+ -- inline freely, or to drop the binding if it is dead.
+ = do { tick (PreInlineUnconditionally bndr)
+ ; simplExprF env' body cont }
+
+ -- Now check for a join point. It's better to do the preInlineUnconditionally
+ -- test first, because joinPointBinding_maybe has to eta-expand, so a trivial
+ -- binding like { j = j2 |> co } would first be eta-expanded and then inlined.
+ -- Better to test preInlineUnconditionally first.
| Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs
- = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont
+ = {-#SCC "simplNonRecJoinPoint" #-}
+ simplNonRecJoinPoint env bndr' rhs' body cont
| otherwise
- = {-#SCC "simplNonRecE" #-} simplNonRecE env bndr (rhs, env) body cont
+ = {-#SCC "simplNonRecE" #-}
+ simplNonRecE env False bndr (rhs, env) body cont
{- Note [Avoiding space leaks in OutType]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1680,12 +1692,12 @@ simpl_lam env bndr body (ApplyToVal { sc_arg = arg, sc_env = arg_se
, sc_cont = cont, sc_dup = dup })
| isSimplified dup -- Don't re-simplify if we've simplified it once
-- See Note [Avoiding exponential behaviour]
- = do { tick (BetaReduction bndr)
- ; completeBindX env bndr arg body cont }
+ = do { tick (BetaReduction bndr)
+ ; completeBindX env bndr arg body cont }
| otherwise -- See Note [Avoiding exponential behaviour]
- = do { tick (BetaReduction bndr)
- ; simplNonRecE env bndr (arg, arg_se) body cont }
+ = do { tick (BetaReduction bndr)
+ ; simplNonRecE env True bndr (arg, arg_se) body cont }
-- Discard a non-counting tick on a lambda. This may change the
-- cost attribution slightly (moving the allocation of the
@@ -1717,6 +1729,8 @@ simplLamBndrs env bndrs = mapAccumLM simplLamBndr env bndrs
------------------
simplNonRecE :: SimplEnv
+ -> Bool -- True <=> from a lambda
+ -- False <=> from a let
-> InId -- The binder, always an Id
-- Never a join point
-> (InExpr, SimplEnv) -- Rhs of binding (or arg of lambda)
@@ -1735,34 +1749,46 @@ simplNonRecE :: SimplEnv
-- It deals with strict bindings, via the StrictBind continuation,
-- which may abort the whole process.
--
--- The RHS may not satisfy the let-can-float invariant yet
+-- from_lam=False => the RHS satisfies the let-can-float invariant
+-- Otherwise it may or may not satisfy it.
-simplNonRecE env bndr (rhs, rhs_se) body cont
+simplNonRecE env from_lam bndr (rhs, rhs_se) body cont
= assert (isId bndr && not (isJoinId bndr) ) $
do { (env1, bndr1) <- simplNonRecBndr env bndr
; let needs_case_binding = needsCaseBinding (idType bndr1) rhs
-- See Note [Dark corner with representation polymorphism]
- ; if | not needs_case_binding
- , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se ->
- do { tick (PreInlineUnconditionally bndr)
- ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $
- simplLam env' body cont }
-
+ -- If from_lam=False then needs_case_binding is False,
+ -- because the binding started as a let, which must
+ -- satisfy let-can-float
+
+ ; if | from_lam && not needs_case_binding
+ -- If not from_lam we are coming from a (NonRec bndr rhs) binding
+ -- and preInlineUnconditionally has been done already;
+ -- no need to repeat it. But for lambdas we must be careful about
+ -- preInlineUnconditionally: consider (\(x:Int#). 3) (error "urk")
+ -- We must not drop the (error "urk").
+ , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se
+ -> do { tick (PreInlineUnconditionally bndr)
+ ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $
+ simplLam env' body cont }
-- Deal with strict bindings
- -- See Note [Dark corner with representation polymorphism]
- | isStrictId bndr1 && seCaseCase env
- || needs_case_binding ->
- simplExprF (rhs_se `setInScopeFromE` env) rhs
- (StrictBind { sc_bndr = bndr, sc_body = body
- , sc_env = env, sc_cont = cont, sc_dup = NoDup })
+ | isStrictId bndr1 && seCaseCase env
+ || from_lam && needs_case_binding
+ -- The important bit here is needs_case_binding; but no need to
+ -- test it if from_lam is False because then needs_case_binding is False too
+ -- NB: either way, the RHS may or may not satisfy let-can-float
+ -- but that's ok for StrictBind.
+ -> simplExprF (rhs_se `setInScopeFromE` env) rhs
+ (StrictBind { sc_bndr = bndr, sc_body = body
+ , sc_env = env, sc_cont = cont, sc_dup = NoDup })
-- Deal with lazy bindings
- | otherwise ->
- do { (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive)
- ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
- ; (floats2, expr') <- simplLam env3 body cont
- ; return (floats1 `addFloats` floats2, expr') } }
+ | otherwise
+ -> do { (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive)
+ ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
+ ; (floats2, expr') <- simplLam env3 body cont
+ ; return (floats1 `addFloats` floats2, expr') } }
------------------
simplRecE :: SimplEnv
@@ -1806,7 +1832,7 @@ care here.
Note [Avoiding exponential behaviour]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
One way in which we can get exponential behaviour is if we simplify a
-big expression, and the re-simplify it -- and then this happens in a
+big expression, and then re-simplify it -- and then this happens in a
deeply-nested way. So we must be jolly careful about re-simplifying
an expression. That is why simplNonRecX does not try
preInlineUnconditionally (unlike simplNonRecE).
@@ -1864,13 +1890,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
-> InExpr -> SimplCont
-> SimplM (SimplFloats, OutExpr)
simplNonRecJoinPoint env bndr rhs body cont
- | assert (isJoinId bndr ) True
- , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env
- = do { tick (PreInlineUnconditionally bndr)
- ; simplExprF env' body cont }
-
- | otherwise
- = wrapJoinCont env cont $ \ env cont ->
+ = assert (isJoinId bndr ) $
+ wrapJoinCont env cont $ \ env cont ->
do { -- We push join_cont into the join RHS and the body;
-- and wrap wrap_cont around the whole thing
; let mult = contHoleScaling cont
diff --git a/testsuite/tests/simplCore/should_compile/T22491.hs b/testsuite/tests/simplCore/should_compile/T22491.hs
new file mode 100644
index 0000000000..ed27654979
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T22491.hs
@@ -0,0 +1,319 @@
+{-# LANGUAGE Haskell2010 #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module T22491 (heapster_add_block_hints) where
+
+import qualified Control.Exception as X
+import Control.Applicative
+import Control.Monad
+import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), catches, Handler(..))
+import Control.Monad.IO.Class
+import qualified Control.Monad.Fail as Fail
+import Control.Monad.Trans.Class (MonadTrans(..))
+import Control.Monad.Trans.Reader (ReaderT)
+import Data.Coerce (Coercible, coerce)
+import Data.IORef
+import Data.Kind (Type)
+import Data.Monoid
+import GHC.Exts (build)
+
+failOnNothing :: Fail.MonadFail m => String -> Maybe a -> m a
+failOnNothing err_str Nothing = Fail.fail err_str
+failOnNothing _ (Just a) = return a
+
+lookupLLVMSymbolModAndCFG :: HeapsterEnv -> String -> IO (Maybe (AnyCFG LLVM))
+lookupLLVMSymbolModAndCFG _ _ = pure Nothing
+
+heapster_add_block_hints :: HeapsterEnv -> String -> [Int] ->
+ (forall ext blocks ret.
+ CFG ext blocks ret ->
+ TopLevel Hint) ->
+ TopLevel ()
+heapster_add_block_hints henv nm blks hintF =
+ do env <- liftIO $ readIORef $ heapsterEnvPermEnvRef henv
+ AnyCFG cfg <-
+ failOnNothing ("Could not find symbol definition: " ++ nm) =<<
+ io (lookupLLVMSymbolModAndCFG henv nm)
+ let blocks = fmapFC blockInputs $ cfgBlockMap cfg
+ block_idxs = fmapFC (blockIDIndex . blockID) $ cfgBlockMap cfg
+ blkIDs <- case blks of
+ [] -> pure $ toListFC (Some . BlockID) block_idxs
+ _ -> forM blks $ \blk ->
+ failOnNothing ("Block ID " ++ show blk ++
+ " not found in function " ++ nm)
+ (fmapF BlockID <$> intIndex blk (size blocks))
+ env' <- foldM (\env' _ ->
+ permEnvAddHint env' <$>
+ hintF cfg)
+ env blkIDs
+ liftIO $ writeIORef (heapsterEnvPermEnvRef henv) env'
+
+-----
+
+data Some (f:: k -> Type) = forall x . Some (f x)
+
+class FunctorF m where
+ fmapF :: (forall x . f x -> g x) -> m f -> m g
+
+mapSome :: (forall tp . f tp -> g tp) -> Some f -> Some g
+mapSome f (Some x) = Some $! f x
+
+instance FunctorF Some where fmapF = mapSome
+
+type SingleCtx x = EmptyCtx ::> x
+
+data Ctx k
+ = EmptyCtx
+ | Ctx k ::> k
+
+type family (<+>) (x :: Ctx k) (y :: Ctx k) :: Ctx k where
+ x <+> EmptyCtx = x
+ x <+> (y ::> e) = (x <+> y) ::> e
+
+data Height = Zero | Succ Height
+
+data BalancedTree h (f :: k -> Type) (p :: Ctx k) where
+ BalLeaf :: !(f x) -> BalancedTree 'Zero f (SingleCtx x)
+ BalPair :: !(BalancedTree h f x)
+ -> !(BalancedTree h f y)
+ -> BalancedTree ('Succ h) f (x <+> y)
+
+data BinomialTree (h::Height) (f :: k -> Type) :: Ctx k -> Type where
+ Empty :: BinomialTree h f EmptyCtx
+
+ PlusOne :: !Int
+ -> !(BinomialTree ('Succ h) f x)
+ -> !(BalancedTree h f y)
+ -> BinomialTree h f (x <+> y)
+
+ PlusZero :: !Int
+ -> !(BinomialTree ('Succ h) f x)
+ -> BinomialTree h f x
+
+tsize :: BinomialTree h f a -> Int
+tsize Empty = 0
+tsize (PlusOne s _ _) = 2*s+1
+tsize (PlusZero s _) = 2*s
+
+fmap_bin :: (forall tp . f tp -> g tp)
+ -> BinomialTree h f c
+ -> BinomialTree h g c
+fmap_bin _ Empty = Empty
+fmap_bin f (PlusOne s t x) = PlusOne s (fmap_bin f t) (fmap_bal f x)
+fmap_bin f (PlusZero s t) = PlusZero s (fmap_bin f t)
+{-# INLINABLE fmap_bin #-}
+
+fmap_bal :: (forall tp . f tp -> g tp)
+ -> BalancedTree h f c
+ -> BalancedTree h g c
+fmap_bal = go
+ where go :: (forall tp . f tp -> g tp)
+ -> BalancedTree h f c
+ -> BalancedTree h g c
+ go f (BalLeaf x) = BalLeaf (f x)
+ go f (BalPair x y) = BalPair (go f x) (go f y)
+{-# INLINABLE fmap_bal #-}
+
+traverse_bin :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BinomialTree h f c
+ -> m (BinomialTree h g c)
+traverse_bin _ Empty = pure Empty
+traverse_bin f (PlusOne s t x) = PlusOne s <$> traverse_bin f t <*> traverse_bal f x
+traverse_bin f (PlusZero s t) = PlusZero s <$> traverse_bin f t
+{-# INLINABLE traverse_bin #-}
+
+traverse_bal :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BalancedTree h f c
+ -> m (BalancedTree h g c)
+traverse_bal = go
+ where go :: Applicative m
+ => (forall tp . f tp -> m (g tp))
+ -> BalancedTree h f c
+ -> m (BalancedTree h g c)
+ go f (BalLeaf x) = BalLeaf <$> f x
+ go f (BalPair x y) = BalPair <$> go f x <*> go f y
+{-# INLINABLE traverse_bal #-}
+
+data Assignment (f :: k -> Type) (ctx :: Ctx k)
+ = Assignment (BinomialTree 'Zero f ctx)
+
+newtype Index (ctx :: Ctx k) (tp :: k) = Index { indexVal :: Int }
+
+newtype Size (ctx :: Ctx k) = Size Int
+
+intIndex :: Int -> Size ctx -> Maybe (Some (Index ctx))
+intIndex i n | 0 <= i && i < sizeInt n = Just (Some (Index i))
+ | otherwise = Nothing
+
+size :: Assignment f ctx -> Size ctx
+size (Assignment t) = Size (tsize t)
+
+sizeInt :: Size ctx -> Int
+sizeInt (Size n) = n
+
+class FunctorFC (t :: (k -> Type) -> l -> Type) where
+ fmapFC :: forall f g. (forall x. f x -> g x) ->
+ (forall x. t f x -> t g x)
+
+(#.) :: Coercible b c => (b -> c) -> (a -> b) -> (a -> c)
+(#.) _f = coerce
+
+class FoldableFC (t :: (k -> Type) -> l -> Type) where
+ foldMapFC :: forall f m. Monoid m => (forall x. f x -> m) -> (forall x. t f x -> m)
+ foldMapFC f = foldrFC (mappend . f) mempty
+
+ foldrFC :: forall f b. (forall x. f x -> b -> b) -> (forall x. b -> t f x -> b)
+ foldrFC f z t = appEndo (foldMapFC (Endo #. f) t) z
+
+ toListFC :: forall f a. (forall x. f x -> a) -> (forall x. t f x -> [a])
+ toListFC f t = build (\c n -> foldrFC (\e v -> c (f e) v) n t)
+
+foldMapFCDefault :: (TraversableFC t, Monoid m) => (forall x. f x -> m) -> (forall x. t f x -> m)
+foldMapFCDefault = \f -> getConst . traverseFC (Const . f)
+{-# INLINE foldMapFCDefault #-}
+
+class (FunctorFC t, FoldableFC t) => TraversableFC (t :: (k -> Type) -> l -> Type) where
+ traverseFC :: forall f g m. Applicative m
+ => (forall x. f x -> m (g x))
+ -> (forall x. t f x -> m (t g x))
+
+instance FunctorFC Assignment where
+ fmapFC = \f (Assignment x) -> Assignment (fmap_bin f x)
+ {-# INLINE fmapFC #-}
+
+instance FoldableFC Assignment where
+ foldMapFC = foldMapFCDefault
+ {-# INLINE foldMapFC #-}
+
+instance TraversableFC Assignment where
+ traverseFC = \f (Assignment x) -> Assignment <$> traverse_bin f x
+ {-# INLINE traverseFC #-}
+
+data CrucibleType
+
+data TypeRepr (tp::CrucibleType) where
+
+type CtxRepr = Assignment TypeRepr
+
+data CFG (ext :: Type)
+ (blocks :: Ctx (Ctx CrucibleType))
+ (ret :: CrucibleType)
+ = CFG { cfgBlockMap :: !(BlockMap ext blocks ret)
+ }
+
+type BlockMap ext blocks ret = Assignment (Block ext blocks ret) blocks
+
+data Block ext (blocks :: Ctx (Ctx CrucibleType)) (ret :: CrucibleType) ctx
+ = Block { blockID :: !(BlockID blocks ctx)
+ , blockInputs :: !(CtxRepr ctx)
+ }
+
+newtype BlockID (blocks :: Ctx (Ctx CrucibleType)) (tp :: Ctx CrucibleType)
+ = BlockID { blockIDIndex :: Index blocks tp }
+
+data LLVM
+
+data AnyCFG ext where
+ AnyCFG :: CFG ext blocks ret
+ -> AnyCFG ext
+
+newtype StateContT s r m a
+ = StateContT { runStateContT :: (a -> s -> m r)
+ -> s
+ -> m r
+ }
+
+fmapStateContT :: (a -> b) -> StateContT s r m a -> StateContT s r m b
+fmapStateContT = \f m -> StateContT $ \c -> runStateContT m (\v s -> (c $! f v) s)
+{-# INLINE fmapStateContT #-}
+
+applyStateContT :: StateContT s r m (a -> b) -> StateContT s r m a -> StateContT s r m b
+applyStateContT = \mf mv ->
+ StateContT $ \c ->
+ runStateContT mf (\f -> runStateContT mv (\v s -> (c $! f v) s))
+{-# INLINE applyStateContT #-}
+
+returnStateContT :: a -> StateContT s r m a
+returnStateContT = \v -> seq v $ StateContT $ \c -> c v
+{-# INLINE returnStateContT #-}
+
+bindStateContT :: StateContT s r m a -> (a -> StateContT s r m b) -> StateContT s r m b
+bindStateContT = \m n -> StateContT $ \c -> runStateContT m (\a -> runStateContT (n a) c)
+{-# INLINE bindStateContT #-}
+
+instance Functor (StateContT s r m) where
+ fmap = fmapStateContT
+
+instance Applicative (StateContT s r m) where
+ pure = returnStateContT
+ (<*>) = applyStateContT
+
+instance Monad (StateContT s r m) where
+ (>>=) = bindStateContT
+
+instance MonadFail m => MonadFail (StateContT s r m) where
+ fail = \msg -> StateContT $ \_ _ -> fail msg
+
+instance MonadTrans (StateContT s r) where
+ lift = \m -> StateContT $ \c s -> m >>= \v -> seq v (c v s)
+
+instance MonadIO m => MonadIO (StateContT s r m) where
+ liftIO = lift . liftIO
+
+instance MonadThrow m => MonadThrow (StateContT s r m) where
+ throwM e = StateContT (\_k _s -> throwM e)
+
+instance MonadCatch m => MonadCatch (StateContT s r m) where
+ catch m hdl =
+ StateContT $ \k s ->
+ catch
+ (runStateContT m k s)
+ (\e -> runStateContT (hdl e) k s)
+
+data TopLevelRO
+data TopLevelRW
+data Value
+
+newtype TopLevel a =
+ TopLevel_ (ReaderT TopLevelRO (StateContT TopLevelRW (Value, TopLevelRW) IO) a)
+ deriving (Applicative, Functor, Monad, MonadFail, MonadThrow, MonadCatch)
+
+instance MonadIO TopLevel where
+ liftIO = io
+
+io :: IO a -> TopLevel a
+io f = TopLevel_ (liftIO f) `catches` [Handler handleIO]
+ where
+ rethrow :: X.Exception ex => ex -> TopLevel a
+ rethrow ex = throwM (X.SomeException ex)
+
+ handleIO :: X.IOException -> TopLevel a
+ handleIO = rethrow
+
+data HeapsterEnv = HeapsterEnv {
+ heapsterEnvPermEnvRef :: IORef PermEnv
+ }
+
+data Hint where
+
+data PermEnv = PermEnv {
+ permEnvHints :: [Hint]
+ }
+
+permEnvAddHint :: PermEnv -> Hint -> PermEnv
+permEnvAddHint env hint = env { permEnvHints = hint : permEnvHints env }
+
+type family CtxToRList (ctx :: Ctx k) :: RList k where
+ CtxToRList EmptyCtx = RNil
+ CtxToRList (ctx' ::> x) = CtxToRList ctx' :> x
+
+data RList a
+ = RNil
+ | (RList a) :> a
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index 321cb046b9..08ff9329ec 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -455,3 +455,4 @@ test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O
test('T22317', [grep_errmsg(r'ANSWER = YES') ], compile, ['-O -dinline-check m -ddebug-output'])
test('T22494', [grep_errmsg(r'case') ], compile, ['-O -ddump-simpl -dsuppress-uniques'])
+test('T22491', normal, compile, ['-O2'])