diff options
-rw-r--r-- | compiler/GHC/Core/Opt/Arity.hs | 6 | ||||
-rw-r--r-- | compiler/GHC/Core/Opt/Simplify/Iteration.hs | 87 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T22491.hs | 319 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 1 |
4 files changed, 379 insertions, 34 deletions
diff --git a/compiler/GHC/Core/Opt/Arity.hs b/compiler/GHC/Core/Opt/Arity.hs index fbbcf1c2ad..832fba354c 100644 --- a/compiler/GHC/Core/Opt/Arity.hs +++ b/compiler/GHC/Core/Opt/Arity.hs @@ -3103,9 +3103,13 @@ etaBodyForJoinPoint need_args body | Just (tv, res_ty) <- splitForAllTyCoVar_maybe ty , let (subst', tv') = substVarBndr subst tv = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` varToCoreExpr tv') + -- The varToCoreExpr is important: `tv` might be a coercion variable + | Just (_, mult, arg_ty, res_ty) <- splitFunTy_maybe ty , let (subst', b) = freshEtaId n subst (Scaled mult arg_ty) - = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b) + = go (n-1) res_ty subst' (b : rev_bs) (e `App` varToCoreExpr b) + -- The varToCoreExpr is important: `b` might be a coercion variable + | otherwise = pprPanic "etaBodyForJoinPoint" $ int need_args $$ ppr body $$ ppr (exprType body) diff --git a/compiler/GHC/Core/Opt/Simplify/Iteration.hs b/compiler/GHC/Core/Opt/Simplify/Iteration.hs index f8ed00c119..36c969224c 100644 --- a/compiler/GHC/Core/Opt/Simplify/Iteration.hs +++ b/compiler/GHC/Core/Opt/Simplify/Iteration.hs @@ -1227,11 +1227,23 @@ simplExprF1 env (Let (NonRec bndr rhs) body) cont do { ty' <- simplType env ty ; simplExprF (extendTvSubst env bndr ty') body cont } + | Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env + -- Because of the let-can-float invariant, it's ok to + -- inline freely, or to drop the binding if it is dead. + = do { tick (PreInlineUnconditionally bndr) + ; simplExprF env' body cont } + + -- Now check for a join point. It's better to do the preInlineUnconditionally + -- test first, because joinPointBinding_maybe has to eta-expand, so a trivial + -- binding like { j = j2 |> co } would first be eta-expanded and then inlined + -- Better to test preInlineUnconditionally first. | Just (bndr', rhs') <- joinPointBinding_maybe bndr rhs - = {-#SCC "simplNonRecJoinPoint" #-} simplNonRecJoinPoint env bndr' rhs' body cont + = {-#SCC "simplNonRecJoinPoint" #-} + simplNonRecJoinPoint env bndr' rhs' body cont | otherwise - = {-#SCC "simplNonRecE" #-} simplNonRecE env bndr (rhs, env) body cont + = {-#SCC "simplNonRecE" #-} + simplNonRecE env False bndr (rhs, env) body cont {- Note [Avoiding space leaks in OutType] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1680,12 +1692,12 @@ simpl_lam env bndr body (ApplyToVal { sc_arg = arg, sc_env = arg_se , sc_cont = cont, sc_dup = dup }) | isSimplified dup -- Don't re-simplify if we've simplified it once -- See Note [Avoiding exponential behaviour] - = do { tick (BetaReduction bndr) - ; completeBindX env bndr arg body cont } + = do { tick (BetaReduction bndr) + ; completeBindX env bndr arg body cont } | otherwise -- See Note [Avoiding exponential behaviour] - = do { tick (BetaReduction bndr) - ; simplNonRecE env bndr (arg, arg_se) body cont } + = do { tick (BetaReduction bndr) + ; simplNonRecE env True bndr (arg, arg_se) body cont } -- Discard a non-counting tick on a lambda. This may change the -- cost attribution slightly (moving the allocation of the @@ -1717,6 +1729,8 @@ simplLamBndrs env bndrs = mapAccumLM simplLamBndr env bndrs ------------------ simplNonRecE :: SimplEnv + -> Bool -- True <=> from a lambda + -- False <=> from a let -> InId -- The binder, always an Id -- Never a join point -> (InExpr, SimplEnv) -- Rhs of binding (or arg of lambda) @@ -1735,34 +1749,46 @@ simplNonRecE :: SimplEnv -- It deals with strict bindings, via the StrictBind continuation, -- which may abort the whole process. -- --- The RHS may not satisfy the let-can-float invariant yet +-- from_lam=False => the RHS satisfies the let-can-float invariant +-- Otherwise it may or may not satisfy it. -simplNonRecE env bndr (rhs, rhs_se) body cont +simplNonRecE env from_lam bndr (rhs, rhs_se) body cont = assert (isId bndr && not (isJoinId bndr) ) $ do { (env1, bndr1) <- simplNonRecBndr env bndr ; let needs_case_binding = needsCaseBinding (idType bndr1) rhs -- See Note [Dark corner with representation polymorphism] - ; if | not needs_case_binding - , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se -> - do { tick (PreInlineUnconditionally bndr) - ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $ - simplLam env' body cont } - + -- If from_lam=False then needs_case_binding is False, + -- because the binding started as a let, which must + -- satisfy let-can-float + + ; if | from_lam && not needs_case_binding + -- If not from_lam we are coming from a (NonRec bndr rhs) binding + -- and preInlineUnconditionally has been done already; + -- no need to repeat it. But for lambdas we must be careful about + -- preInlineUndonditionally: consider (\(x:Int#). 3) (error "urk") + -- We must not drop the (error "urk"). + , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs rhs_se + -> do { tick (PreInlineUnconditionally bndr) + ; -- pprTrace "preInlineUncond" (ppr bndr <+> ppr rhs) $ + simplLam env' body cont } -- Deal with strict bindings - -- See Note [Dark corner with representation polymorphism] - | isStrictId bndr1 && seCaseCase env - || needs_case_binding -> - simplExprF (rhs_se `setInScopeFromE` env) rhs - (StrictBind { sc_bndr = bndr, sc_body = body - , sc_env = env, sc_cont = cont, sc_dup = NoDup }) + | isStrictId bndr1 && seCaseCase env + || from_lam && needs_case_binding + -- The important bit here is needs_case_binds; but no need to + -- test it if from_lam is False because then needs_case_binding is False too + -- NB: either way, the RHS may or may not satisfy let-can-float + -- but that's ok for StrictBind. + -> simplExprF (rhs_se `setInScopeFromE` env) rhs + (StrictBind { sc_bndr = bndr, sc_body = body + , sc_env = env, sc_cont = cont, sc_dup = NoDup }) -- Deal with lazy bindings - | otherwise -> - do { (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive) - ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se - ; (floats2, expr') <- simplLam env3 body cont - ; return (floats1 `addFloats` floats2, expr') } } + | otherwise + -> do { (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Let NotTopLevel NonRecursive) + ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se + ; (floats2, expr') <- simplLam env3 body cont + ; return (floats1 `addFloats` floats2, expr') } } ------------------ simplRecE :: SimplEnv @@ -1806,7 +1832,7 @@ care here. Note [Avoiding exponential behaviour] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ One way in which we can get exponential behaviour is if we simplify a -big expression, and the re-simplify it -- and then this happens in a +big expression, and then re-simplify it -- and then this happens in a deeply-nested way. So we must be jolly careful about re-simplifying an expression. That is why simplNonRecX does not try preInlineUnconditionally (unlike simplNonRecE). @@ -1864,13 +1890,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) simplNonRecJoinPoint env bndr rhs body cont - | assert (isJoinId bndr ) True - , Just env' <- preInlineUnconditionally env NotTopLevel bndr rhs env - = do { tick (PreInlineUnconditionally bndr) - ; simplExprF env' body cont } - - | otherwise - = wrapJoinCont env cont $ \ env cont -> + = assert (isJoinId bndr ) $ + wrapJoinCont env cont $ \ env cont -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing ; let mult = contHoleScaling cont diff --git a/testsuite/tests/simplCore/should_compile/T22491.hs b/testsuite/tests/simplCore/should_compile/T22491.hs new file mode 100644 index 0000000000..ed27654979 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T22491.hs @@ -0,0 +1,319 @@ +{-# LANGUAGE Haskell2010 #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module T22491 (heapster_add_block_hints) where + +import qualified Control.Exception as X +import Control.Applicative +import Control.Monad +import Control.Monad.Catch (MonadThrow(..), MonadCatch(..), catches, Handler(..)) +import Control.Monad.IO.Class +import qualified Control.Monad.Fail as Fail +import Control.Monad.Trans.Class (MonadTrans(..)) +import Control.Monad.Trans.Reader (ReaderT) +import Data.Coerce (Coercible, coerce) +import Data.IORef +import Data.Kind (Type) +import Data.Monoid +import GHC.Exts (build) + +failOnNothing :: Fail.MonadFail m => String -> Maybe a -> m a +failOnNothing err_str Nothing = Fail.fail err_str +failOnNothing _ (Just a) = return a + +lookupLLVMSymbolModAndCFG :: HeapsterEnv -> String -> IO (Maybe (AnyCFG LLVM)) +lookupLLVMSymbolModAndCFG _ _ = pure Nothing + +heapster_add_block_hints :: HeapsterEnv -> String -> [Int] -> + (forall ext blocks ret. + CFG ext blocks ret -> + TopLevel Hint) -> + TopLevel () +heapster_add_block_hints henv nm blks hintF = + do env <- liftIO $ readIORef $ heapsterEnvPermEnvRef henv + AnyCFG cfg <- + failOnNothing ("Could not find symbol definition: " ++ nm) =<< + io (lookupLLVMSymbolModAndCFG henv nm) + let blocks = fmapFC blockInputs $ cfgBlockMap cfg + block_idxs = fmapFC (blockIDIndex . blockID) $ cfgBlockMap cfg + blkIDs <- case blks of + [] -> pure $ toListFC (Some . BlockID) block_idxs + _ -> forM blks $ \blk -> + failOnNothing ("Block ID " ++ show blk ++ + " not found in function " ++ nm) + (fmapF BlockID <$> intIndex blk (size blocks)) + env' <- foldM (\env' _ -> + permEnvAddHint env' <$> + hintF cfg) + env blkIDs + liftIO $ writeIORef (heapsterEnvPermEnvRef henv) env' + +----- + +data Some (f:: k -> Type) = forall x . Some (f x) + +class FunctorF m where + fmapF :: (forall x . f x -> g x) -> m f -> m g + +mapSome :: (forall tp . f tp -> g tp) -> Some f -> Some g +mapSome f (Some x) = Some $! f x + +instance FunctorF Some where fmapF = mapSome + +type SingleCtx x = EmptyCtx ::> x + +data Ctx k + = EmptyCtx + | Ctx k ::> k + +type family (<+>) (x :: Ctx k) (y :: Ctx k) :: Ctx k where + x <+> EmptyCtx = x + x <+> (y ::> e) = (x <+> y) ::> e + +data Height = Zero | Succ Height + +data BalancedTree h (f :: k -> Type) (p :: Ctx k) where + BalLeaf :: !(f x) -> BalancedTree 'Zero f (SingleCtx x) + BalPair :: !(BalancedTree h f x) + -> !(BalancedTree h f y) + -> BalancedTree ('Succ h) f (x <+> y) + +data BinomialTree (h::Height) (f :: k -> Type) :: Ctx k -> Type where + Empty :: BinomialTree h f EmptyCtx + + PlusOne :: !Int + -> !(BinomialTree ('Succ h) f x) + -> !(BalancedTree h f y) + -> BinomialTree h f (x <+> y) + + PlusZero :: !Int + -> !(BinomialTree ('Succ h) f x) + -> BinomialTree h f x + +tsize :: BinomialTree h f a -> Int +tsize Empty = 0 +tsize (PlusOne s _ _) = 2*s+1 +tsize (PlusZero s _) = 2*s + +fmap_bin :: (forall tp . f tp -> g tp) + -> BinomialTree h f c + -> BinomialTree h g c +fmap_bin _ Empty = Empty +fmap_bin f (PlusOne s t x) = PlusOne s (fmap_bin f t) (fmap_bal f x) +fmap_bin f (PlusZero s t) = PlusZero s (fmap_bin f t) +{-# INLINABLE fmap_bin #-} + +fmap_bal :: (forall tp . f tp -> g tp) + -> BalancedTree h f c + -> BalancedTree h g c +fmap_bal = go + where go :: (forall tp . f tp -> g tp) + -> BalancedTree h f c + -> BalancedTree h g c + go f (BalLeaf x) = BalLeaf (f x) + go f (BalPair x y) = BalPair (go f x) (go f y) +{-# INLINABLE fmap_bal #-} + +traverse_bin :: Applicative m + => (forall tp . f tp -> m (g tp)) + -> BinomialTree h f c + -> m (BinomialTree h g c) +traverse_bin _ Empty = pure Empty +traverse_bin f (PlusOne s t x) = PlusOne s <$> traverse_bin f t <*> traverse_bal f x +traverse_bin f (PlusZero s t) = PlusZero s <$> traverse_bin f t +{-# INLINABLE traverse_bin #-} + +traverse_bal :: Applicative m + => (forall tp . f tp -> m (g tp)) + -> BalancedTree h f c + -> m (BalancedTree h g c) +traverse_bal = go + where go :: Applicative m + => (forall tp . f tp -> m (g tp)) + -> BalancedTree h f c + -> m (BalancedTree h g c) + go f (BalLeaf x) = BalLeaf <$> f x + go f (BalPair x y) = BalPair <$> go f x <*> go f y +{-# INLINABLE traverse_bal #-} + +data Assignment (f :: k -> Type) (ctx :: Ctx k) + = Assignment (BinomialTree 'Zero f ctx) + +newtype Index (ctx :: Ctx k) (tp :: k) = Index { indexVal :: Int } + +newtype Size (ctx :: Ctx k) = Size Int + +intIndex :: Int -> Size ctx -> Maybe (Some (Index ctx)) +intIndex i n | 0 <= i && i < sizeInt n = Just (Some (Index i)) + | otherwise = Nothing + +size :: Assignment f ctx -> Size ctx +size (Assignment t) = Size (tsize t) + +sizeInt :: Size ctx -> Int +sizeInt (Size n) = n + +class FunctorFC (t :: (k -> Type) -> l -> Type) where + fmapFC :: forall f g. (forall x. f x -> g x) -> + (forall x. t f x -> t g x) + +(#.) :: Coercible b c => (b -> c) -> (a -> b) -> (a -> c) +(#.) _f = coerce + +class FoldableFC (t :: (k -> Type) -> l -> Type) where + foldMapFC :: forall f m. Monoid m => (forall x. f x -> m) -> (forall x. t f x -> m) + foldMapFC f = foldrFC (mappend . f) mempty + + foldrFC :: forall f b. (forall x. f x -> b -> b) -> (forall x. b -> t f x -> b) + foldrFC f z t = appEndo (foldMapFC (Endo #. f) t) z + + toListFC :: forall f a. (forall x. f x -> a) -> (forall x. t f x -> [a]) + toListFC f t = build (\c n -> foldrFC (\e v -> c (f e) v) n t) + +foldMapFCDefault :: (TraversableFC t, Monoid m) => (forall x. f x -> m) -> (forall x. t f x -> m) +foldMapFCDefault = \f -> getConst . traverseFC (Const . f) +{-# INLINE foldMapFCDefault #-} + +class (FunctorFC t, FoldableFC t) => TraversableFC (t :: (k -> Type) -> l -> Type) where + traverseFC :: forall f g m. Applicative m + => (forall x. f x -> m (g x)) + -> (forall x. t f x -> m (t g x)) + +instance FunctorFC Assignment where + fmapFC = \f (Assignment x) -> Assignment (fmap_bin f x) + {-# INLINE fmapFC #-} + +instance FoldableFC Assignment where + foldMapFC = foldMapFCDefault + {-# INLINE foldMapFC #-} + +instance TraversableFC Assignment where + traverseFC = \f (Assignment x) -> Assignment <$> traverse_bin f x + {-# INLINE traverseFC #-} + +data CrucibleType + +data TypeRepr (tp::CrucibleType) where + +type CtxRepr = Assignment TypeRepr + +data CFG (ext :: Type) + (blocks :: Ctx (Ctx CrucibleType)) + (ret :: CrucibleType) + = CFG { cfgBlockMap :: !(BlockMap ext blocks ret) + } + +type BlockMap ext blocks ret = Assignment (Block ext blocks ret) blocks + +data Block ext (blocks :: Ctx (Ctx CrucibleType)) (ret :: CrucibleType) ctx + = Block { blockID :: !(BlockID blocks ctx) + , blockInputs :: !(CtxRepr ctx) + } + +newtype BlockID (blocks :: Ctx (Ctx CrucibleType)) (tp :: Ctx CrucibleType) + = BlockID { blockIDIndex :: Index blocks tp } + +data LLVM + +data AnyCFG ext where + AnyCFG :: CFG ext blocks ret + -> AnyCFG ext + +newtype StateContT s r m a + = StateContT { runStateContT :: (a -> s -> m r) + -> s + -> m r + } + +fmapStateContT :: (a -> b) -> StateContT s r m a -> StateContT s r m b +fmapStateContT = \f m -> StateContT $ \c -> runStateContT m (\v s -> (c $! f v) s) +{-# INLINE fmapStateContT #-} + +applyStateContT :: StateContT s r m (a -> b) -> StateContT s r m a -> StateContT s r m b +applyStateContT = \mf mv -> + StateContT $ \c -> + runStateContT mf (\f -> runStateContT mv (\v s -> (c $! f v) s)) +{-# INLINE applyStateContT #-} + +returnStateContT :: a -> StateContT s r m a +returnStateContT = \v -> seq v $ StateContT $ \c -> c v +{-# INLINE returnStateContT #-} + +bindStateContT :: StateContT s r m a -> (a -> StateContT s r m b) -> StateContT s r m b +bindStateContT = \m n -> StateContT $ \c -> runStateContT m (\a -> runStateContT (n a) c) +{-# INLINE bindStateContT #-} + +instance Functor (StateContT s r m) where + fmap = fmapStateContT + +instance Applicative (StateContT s r m) where + pure = returnStateContT + (<*>) = applyStateContT + +instance Monad (StateContT s r m) where + (>>=) = bindStateContT + +instance MonadFail m => MonadFail (StateContT s r m) where + fail = \msg -> StateContT $ \_ _ -> fail msg + +instance MonadTrans (StateContT s r) where + lift = \m -> StateContT $ \c s -> m >>= \v -> seq v (c v s) + +instance MonadIO m => MonadIO (StateContT s r m) where + liftIO = lift . liftIO + +instance MonadThrow m => MonadThrow (StateContT s r m) where + throwM e = StateContT (\_k _s -> throwM e) + +instance MonadCatch m => MonadCatch (StateContT s r m) where + catch m hdl = + StateContT $ \k s -> + catch + (runStateContT m k s) + (\e -> runStateContT (hdl e) k s) + +data TopLevelRO +data TopLevelRW +data Value + +newtype TopLevel a = + TopLevel_ (ReaderT TopLevelRO (StateContT TopLevelRW (Value, TopLevelRW) IO) a) + deriving (Applicative, Functor, Monad, MonadFail, MonadThrow, MonadCatch) + +instance MonadIO TopLevel where + liftIO = io + +io :: IO a -> TopLevel a +io f = TopLevel_ (liftIO f) `catches` [Handler handleIO] + where + rethrow :: X.Exception ex => ex -> TopLevel a + rethrow ex = throwM (X.SomeException ex) + + handleIO :: X.IOException -> TopLevel a + handleIO = rethrow + +data HeapsterEnv = HeapsterEnv { + heapsterEnvPermEnvRef :: IORef PermEnv + } + +data Hint where + +data PermEnv = PermEnv { + permEnvHints :: [Hint] + } + +permEnvAddHint :: PermEnv -> Hint -> PermEnv +permEnvAddHint env hint = env { permEnvHints = hint : permEnvHints env } + +type family CtxToRList (ctx :: Ctx k) :: RList k where + CtxToRList EmptyCtx = RNil + CtxToRList (ctx' ::> x) = CtxToRList ctx' :> x + +data RList a + = RNil + | (RList a) :> a diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 321cb046b9..08ff9329ec 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -455,3 +455,4 @@ test('T21851_2', [grep_errmsg(r'wwombat') ], multimod_compile, ['T21851_2', '-O test('T22317', [grep_errmsg(r'ANSWER = YES') ], compile, ['-O -dinline-check m -ddebug-output']) test('T22494', [grep_errmsg(r'case') ], compile, ['-O -ddump-simpl -dsuppress-uniques']) +test('T22491', normal, compile, ['-O2']) |