summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Graf <sebastian.graf@kit.edu>2022-05-19 18:37:38 +0200
committerSebastian Graf <sebastian.graf@kit.edu>2022-06-20 09:43:29 +0200
commitb570da84b7aad5ca3f90f2d1c1a690c927e99fe9 (patch)
tree9210e2e9c3f37477db705df57dbf359da9e95baa
parent94f2e92a2510a3338c5201a4dcc69666fa9575f8 (diff)
downloadhaskell-b570da84b7aad5ca3f90f2d1c1a690c927e99fe9.tar.gz
CorePrep: Don't speculatively evaluate recursive calls (#20836)
In #20836 we have optimised a terminating program into an endless loop, because we speculated the self-recursive call of a recursive DFun. Now we track the set of enclosing recursive binders in CorePrep to prevent speculation of such self-recursive calls. See the updates to Note [Speculative evaluation] for details. Fixes #20836.
-rw-r--r--compiler/GHC/Core/Utils.hs64
-rw-r--r--compiler/GHC/CoreToStg/Prep.hs93
-rw-r--r--compiler/GHC/Data/Graph/UnVar.hs8
-rw-r--r--testsuite/tests/count-deps/CountDepsAst.stdout1
-rw-r--r--testsuite/tests/count-deps/CountDepsParser.stdout1
-rw-r--r--testsuite/tests/simplCore/should_run/T20836.hs23
-rw-r--r--testsuite/tests/simplCore/should_run/all.T1
7 files changed, 153 insertions, 38 deletions
diff --git a/compiler/GHC/Core/Utils.hs b/compiler/GHC/Core/Utils.hs
index 87dc238d62..8c727698f3 100644
--- a/compiler/GHC/Core/Utils.hs
+++ b/compiler/GHC/Core/Utils.hs
@@ -26,8 +26,8 @@ module GHC.Core.Utils (
exprIsDupable, exprIsTrivial, getIdFromTrivialExpr, exprIsDeadEnd,
getIdFromTrivialExpr_maybe,
exprIsCheap, exprIsExpandable, exprIsCheapX, CheapAppFun,
- exprIsHNF, exprOkForSpeculation, exprOkForSideEffects, exprIsWorkFree,
- exprIsConLike,
+ exprIsHNF, exprOkForSpeculation, exprOkForSideEffects, exprOkForSpecEval,
+ exprIsWorkFree, exprIsConLike,
isCheapApp, isExpandableApp, isSaturatedConApp,
exprIsTickedString, exprIsTickedString_maybe,
exprIsTopLevelBindable,
@@ -1560,46 +1560,55 @@ it's applied only to dictionaries.
-- side effects, and can't diverge or raise an exception.
exprOkForSpeculation, exprOkForSideEffects :: CoreExpr -> Bool
-exprOkForSpeculation = expr_ok primOpOkForSpeculation
-exprOkForSideEffects = expr_ok primOpOkForSideEffects
-
-expr_ok :: (PrimOp -> Bool) -> CoreExpr -> Bool
-expr_ok _ (Lit _) = True
-expr_ok _ (Type _) = True
-expr_ok _ (Coercion _) = True
-
-expr_ok primop_ok (Var v) = app_ok primop_ok v []
-expr_ok primop_ok (Cast e _) = expr_ok primop_ok e
-expr_ok primop_ok (Lam b e)
- | isTyVar b = expr_ok primop_ok e
+exprOkForSpeculation = expr_ok fun_always_ok primOpOkForSpeculation
+exprOkForSideEffects = expr_ok fun_always_ok primOpOkForSideEffects
+
+fun_always_ok :: Id -> Bool
+fun_always_ok _ = True
+
+-- | A special version of 'exprOkForSpeculation' used during
+-- Note [Speculative evaluation]. When the predicate arg `fun_ok` returns False
+-- for `b`, then `b` is never considered ok-for-spec.
+exprOkForSpecEval :: (Id -> Bool) -> CoreExpr -> Bool
+exprOkForSpecEval fun_ok = expr_ok fun_ok primOpOkForSpeculation
+
+expr_ok :: (Id -> Bool) -> (PrimOp -> Bool) -> CoreExpr -> Bool
+expr_ok _ _ (Lit _) = True
+expr_ok _ _ (Type _) = True
+expr_ok _ _ (Coercion _) = True
+
+expr_ok fun_ok primop_ok (Var v) = app_ok fun_ok primop_ok v []
+expr_ok fun_ok primop_ok (Cast e _) = expr_ok fun_ok primop_ok e
+expr_ok fun_ok primop_ok (Lam b e)
+ | isTyVar b = expr_ok fun_ok primop_ok e
| otherwise = True
-- Tick annotations that *tick* cannot be speculated, because these
-- are meant to identify whether or not (and how often) the particular
-- source expression was evaluated at runtime.
-expr_ok primop_ok (Tick tickish e)
+expr_ok fun_ok primop_ok (Tick tickish e)
| tickishCounts tickish = False
- | otherwise = expr_ok primop_ok e
+ | otherwise = expr_ok fun_ok primop_ok e
-expr_ok _ (Let {}) = False
+expr_ok _ _ (Let {}) = False
-- Lets can be stacked deeply, so just give up.
-- In any case, the argument of exprOkForSpeculation is
-- usually in a strict context, so any lets will have been
-- floated away.
-expr_ok primop_ok (Case scrut bndr _ alts)
+expr_ok fun_ok primop_ok (Case scrut bndr _ alts)
= -- See Note [exprOkForSpeculation: case expressions]
- expr_ok primop_ok scrut
+ expr_ok fun_ok primop_ok scrut
&& isUnliftedType (idType bndr)
-- OK to call isUnliftedType: binders always have a fixed RuntimeRep
- && all (\(Alt _ _ rhs) -> expr_ok primop_ok rhs) alts
+ && all (\(Alt _ _ rhs) -> expr_ok fun_ok primop_ok rhs) alts
&& altsAreExhaustive alts
-expr_ok primop_ok other_expr
+expr_ok fun_ok primop_ok other_expr
| (expr, args) <- collectArgs other_expr
= case stripTicksTopE (not . tickishCounts) expr of
Var f ->
- app_ok primop_ok f args
+ app_ok fun_ok primop_ok f args
-- 'LitRubbish' is the only literal that can occur in the head of an
-- application and will not be matched by the above case (Var /= Lit).
@@ -1613,8 +1622,11 @@ expr_ok primop_ok other_expr
_ -> False
-----------------------------
-app_ok :: (PrimOp -> Bool) -> Id -> [CoreExpr] -> Bool
-app_ok primop_ok fun args
+app_ok :: (Id -> Bool) -> (PrimOp -> Bool) -> Id -> [CoreExpr] -> Bool
+app_ok fun_ok primop_ok fun args
+ | not (fun_ok fun)
+ = False -- This code path is only taken for Note [Speculative evaluation]
+ | otherwise
= case idDetails fun of
DFunId new_type -> not new_type
-- DFuns terminate, unless the dict is implemented
@@ -1628,7 +1640,7 @@ app_ok primop_ok fun args
PrimOpId op _
| primOpIsDiv op
, [arg1, Lit lit] <- args
- -> not (isZeroLit lit) && expr_ok primop_ok arg1
+ -> not (isZeroLit lit) && expr_ok fun_ok primop_ok arg1
-- Special case for dividing operations that fail
-- In general they are NOT ok-for-speculation
-- (which primop_ok will catch), but they ARE OK
@@ -1679,7 +1691,7 @@ app_ok primop_ok fun args
| Just Lifted <- typeLevity_maybe (scaledThing ty)
= True -- See Note [Primops with lifted arguments]
| otherwise
- = expr_ok primop_ok arg
+ = expr_ok fun_ok primop_ok arg
-----------------------------
altsAreExhaustive :: [Alt b] -> Bool
diff --git a/compiler/GHC/CoreToStg/Prep.hs b/compiler/GHC/CoreToStg/Prep.hs
index f4b6f2908d..026b134f94 100644
--- a/compiler/GHC/CoreToStg/Prep.hs
+++ b/compiler/GHC/CoreToStg/Prep.hs
@@ -50,6 +50,7 @@ import GHC.Data.Maybe
import GHC.Data.OrdList
import GHC.Data.FastString
import GHC.Data.Pair
+import GHC.Data.Graph.UnVar
import GHC.Utils.Error
import GHC.Utils.Misc
@@ -603,7 +604,7 @@ cpeBind top_lvl env (NonRec bndr rhs)
| otherwise
= addFloat floats new_float
- new_float = mkFloat dmd is_unlifted bndr1 rhs1
+ new_float = mkFloat env dmd is_unlifted bndr1 rhs1
; return (env2, floats1, Nothing) }
@@ -617,24 +618,27 @@ cpeBind top_lvl env (NonRec bndr rhs)
cpeBind top_lvl env (Rec pairs)
| not (isJoinId (head bndrs))
- = do { (env', bndrs1) <- cpCloneBndrs env bndrs
+ = do { (env, bndrs1) <- cpCloneBndrs env bndrs
+ ; let env' = enterRecGroupRHSs env bndrs1
; stuff <- zipWithM (cpePair top_lvl Recursive topDmd False env')
bndrs1 rhss
; let (floats_s, rhss1) = unzip stuff
all_pairs = foldrOL add_float (bndrs1 `zip` rhss1)
(concatFloats floats_s)
-
+ -- use env below, so that we reset cpe_rec_ids
; return (extendCorePrepEnvList env (bndrs `zip` bndrs1),
unitFloat (FloatLet (Rec all_pairs)),
Nothing) }
| otherwise -- See Note [Join points and floating]
- = do { (env', bndrs1) <- cpCloneBndrs env bndrs
+ = do { (env, bndrs1) <- cpCloneBndrs env bndrs
+ ; let env' = enterRecGroupRHSs env bndrs1
; pairs1 <- zipWithM (cpeJoinPair env') bndrs1 rhss
; let bndrs2 = map fst pairs1
- ; return (extendCorePrepEnvList env' (bndrs `zip` bndrs2),
+ -- use env below, so that we reset cpe_rec_ids
+ ; return (extendCorePrepEnvList env (bndrs `zip` bndrs2),
emptyFloats,
Just (Rec pairs1)) }
where
@@ -666,7 +670,7 @@ cpePair top_lvl is_rec dmd is_unlifted env bndr rhs
else warnPprTrace True "CorePrep: silly extra arguments:" (ppr bndr) $
-- Note [Silly extra arguments]
(do { v <- newVar (idType bndr)
- ; let float = mkFloat topDmd False v rhs2
+ ; let float = mkFloat env topDmd False v rhs2
; return ( addFloat floats2 float
, cpeEtaExpand arity (Var v)) })
@@ -1510,7 +1514,7 @@ cpeArg env dmd arg
; if okCpeArg arg2
then do { v <- newVar arg_ty
; let arg3 = cpeEtaExpand (exprArity arg2) arg2
- arg_float = mkFloat dmd is_unlifted v arg3
+ arg_float = mkFloat env dmd is_unlifted v arg3
; return (addFloat floats2 arg_float, varToCoreExpr v) }
else return (floats2, arg2)
}
@@ -1656,6 +1660,66 @@ to allocate a thunk for it, whose closure must be retained as
long as the callee might evaluate it. And if it is evaluated on
most code paths anyway, we get to turn the unknown eval in the
callee into a known call at the call site.
+
+However, we must be very careful not to speculate recursive calls!
+Doing so might well change termination behavior.
+
+That comes up in practice for DFuns, which are considered ok-for-spec,
+because they always immediately return a constructor.
+Not so if you speculate the recursive call, as #20836 shows:
+
+ class Foo m => Foo m where
+ runFoo :: m a -> m a
+ newtype Trans m a = Trans { runTrans :: m a }
+ instance Monad m => Foo (Trans m) where
+ runFoo = id
+
+(NB: `class Foo m => Foo m` looks weird and needs -XUndecidableSuperClasses. The
+example in #20836 is more compelling, but boils down to the same thing.)
+This program compiles to the following DFun for the `Trans` instance:
+
+ Rec {
+ $fFooTrans
+ = \ @m $dMonad -> C:Foo ($fFooTrans $dMonad) (\ @a -> id)
+ end Rec }
+
+Note that the DFun immediately terminates and produces a dictionary, just
+like DFuns ought to, but it calls itself recursively to produce the `Foo m`
+dictionary. But alas, if we treat `$fFooTrans` as always-terminating, so
+that we can speculate its calls, and hence use call-by-value, we get:
+
+ $fFooTrans
+ = \ @m $dMonad -> case ($fFooTrans $dMonad) of sc ->
+ C:Foo sc (\ @a -> id)
+
+and that's an infinite loop!
+Note that this bad-ness only happens in `$fFooTrans`'s own RHS. In the
+*body* of the letrec, it's absolutely fine to use call-by-value on
+`foo ($fFooTrans d)`.
+
+Our solution is this: we track in cpe_rec_ids the set of enclosing
+recursively-bound Ids, the RHSs of which we are currently transforming and then
+in 'exprOkForSpecEval' (a special entry point to 'exprOkForSpeculation',
+basically) we'll say that any binder in this set is not ok-for-spec.
+
+Note if we have a letrec group `Rec { f1 = rhs1; ...; fn = rhsn }`, and we
+prep up `rhs1`, we have to include not only `f1`, but all binders of the group
+`f1..fn` in this set, otherwise our fix is not robust wrt. mutually recursive
+DFuns.
+
+NB: If at some point we decide to have a termination analysis for general
+functions (#8655, !1866), we need to take similar precautions for (guarded)
+recursive functions:
+
+ repeat x = x : repeat x
+
+Same problem here: As written, repeat evaluates rapidly to WHNF. So `repeat x`
+is a cheap call that we are willing to speculate, but *not* in repeat's RHS.
+Fortunately, cpe_rec_ids already has all the information we need in that case.
+
+The problem is very similar to Note [Eta reduction in recursive RHSs].
+Here as well as there it is *unsound* to change the termination properties
+of the very function whose termination properties we are exploiting.
-}
data FloatingBind
@@ -1702,8 +1766,8 @@ data OkToSpec
-- ok-to-speculate unlifted bindings
| NotOkToSpec -- Some not-ok-to-speculate unlifted bindings
-mkFloat :: Demand -> Bool -> Id -> CpeRhs -> FloatingBind
-mkFloat dmd is_unlifted bndr rhs
+mkFloat :: CorePrepEnv -> Demand -> Bool -> Id -> CpeRhs -> FloatingBind
+mkFloat env dmd is_unlifted bndr rhs
| is_strict || ok_for_spec -- See Note [Speculative evaluation]
, not is_hnf = FloatCase rhs bndr DEFAULT [] ok_for_spec
-- Don't make a case for a HNF binding, even if it's strict
@@ -1730,7 +1794,8 @@ mkFloat dmd is_unlifted bndr rhs
where
is_hnf = exprIsHNF rhs
is_strict = isStrUsedDmd dmd
- ok_for_spec = exprOkForSpeculation rhs
+ ok_for_spec = exprOkForSpecEval (not . is_rec_call) rhs
+ is_rec_call = (`elemUnVarSet` cpe_rec_ids env)
emptyFloats :: Floats
emptyFloats = Floats OkToSpec nilOL
@@ -1941,6 +2006,8 @@ data CorePrepEnv
-- and Note [CorePrep inlines trivial CoreExpr not Id] (#12076)
, cpe_tyco_env :: Maybe CpeTyCoEnv -- See Note [CpeTyCoEnv]
+
+ , cpe_rec_ids :: UnVarSet -- Faster OutIdSet; See Note [Speculative evaluation]
}
mkInitialCorePrepEnv :: CorePrepConfig -> CorePrepEnv
@@ -1948,6 +2015,7 @@ mkInitialCorePrepEnv cfg = CPE
{ cpe_config = cfg
, cpe_env = emptyVarEnv
, cpe_tyco_env = Nothing
+ , cpe_rec_ids = emptyUnVarSet
}
extendCorePrepEnv :: CorePrepEnv -> Id -> Id -> CorePrepEnv
@@ -1969,6 +2037,10 @@ lookupCorePrepEnv cpe id
Nothing -> Var id
Just exp -> exp
+enterRecGroupRHSs :: CorePrepEnv -> [OutId] -> CorePrepEnv
+enterRecGroupRHSs env grp
+ = env { cpe_rec_ids = extendUnVarSetList grp (cpe_rec_ids env) }
+
------------------------------------------------------------------------------
-- CpeTyCoEnv
-- ---------------------------------------------------------------------------
@@ -2270,4 +2342,3 @@ mkConvertNumLiteral platform home_unit lookup_global = do
return convertNumLit
-
diff --git a/compiler/GHC/Data/Graph/UnVar.hs b/compiler/GHC/Data/Graph/UnVar.hs
index 5bfc23eef6..f5a9e1e54a 100644
--- a/compiler/GHC/Data/Graph/UnVar.hs
+++ b/compiler/GHC/Data/Graph/UnVar.hs
@@ -17,7 +17,7 @@ equal to g, but twice as expensive and large.
module GHC.Data.Graph.UnVar
( UnVarSet
, emptyUnVarSet, mkUnVarSet, varEnvDom, unionUnVarSet, unionUnVarSets
- , extendUnVarSet, delUnVarSet
+ , extendUnVarSet, extendUnVarSetList, delUnVarSet, delUnVarSetList
, elemUnVarSet, isEmptyUnVarSet
, UnVarGraph
, emptyUnVarGraph
@@ -63,6 +63,9 @@ isEmptyUnVarSet (UnVarSet s) = S.null s
delUnVarSet :: UnVarSet -> Var -> UnVarSet
delUnVarSet (UnVarSet s) v = UnVarSet $ k v `S.delete` s
+delUnVarSetList :: UnVarSet -> [Var] -> UnVarSet
+delUnVarSetList s vs = s `minusUnVarSet` mkUnVarSet vs
+
minusUnVarSet :: UnVarSet -> UnVarSet -> UnVarSet
minusUnVarSet (UnVarSet s) (UnVarSet s') = UnVarSet $ s `S.difference` s'
@@ -78,6 +81,9 @@ varEnvDom ae = UnVarSet $ ufmToSet_Directly ae
extendUnVarSet :: Var -> UnVarSet -> UnVarSet
extendUnVarSet v (UnVarSet s) = UnVarSet $ S.insert (k v) s
+extendUnVarSetList :: [Var] -> UnVarSet -> UnVarSet
+extendUnVarSetList vs s = s `unionUnVarSet` mkUnVarSet vs
+
unionUnVarSet :: UnVarSet -> UnVarSet -> UnVarSet
unionUnVarSet (UnVarSet set1) (UnVarSet set2) = UnVarSet (set1 `S.union` set2)
diff --git a/testsuite/tests/count-deps/CountDepsAst.stdout b/testsuite/tests/count-deps/CountDepsAst.stdout
index 0c7d753be6..75ef4e13de 100644
--- a/testsuite/tests/count-deps/CountDepsAst.stdout
+++ b/testsuite/tests/count-deps/CountDepsAst.stdout
@@ -76,6 +76,7 @@ GHC.Data.FastString
GHC.Data.FastString.Env
GHC.Data.FiniteMap
GHC.Data.Graph.Directed
+GHC.Data.Graph.UnVar
GHC.Data.IOEnv
GHC.Data.List.SetOps
GHC.Data.Maybe
diff --git a/testsuite/tests/count-deps/CountDepsParser.stdout b/testsuite/tests/count-deps/CountDepsParser.stdout
index 30267860d8..a4a51fbf9a 100644
--- a/testsuite/tests/count-deps/CountDepsParser.stdout
+++ b/testsuite/tests/count-deps/CountDepsParser.stdout
@@ -76,6 +76,7 @@ GHC.Data.FastString
GHC.Data.FastString.Env
GHC.Data.FiniteMap
GHC.Data.Graph.Directed
+GHC.Data.Graph.UnVar
GHC.Data.IOEnv
GHC.Data.List.SetOps
GHC.Data.Maybe
diff --git a/testsuite/tests/simplCore/should_run/T20836.hs b/testsuite/tests/simplCore/should_run/T20836.hs
new file mode 100644
index 0000000000..462fdb3ac7
--- /dev/null
+++ b/testsuite/tests/simplCore/should_run/T20836.hs
@@ -0,0 +1,23 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableSuperClasses #-}
+
+import Data.Kind (Type)
+
+class (Monad m, MonadFoo (FooM m)) => MonadFoo m where
+ type FooM m :: Type -> Type
+ runFoo :: FooM m a -> m a
+
+newtype MyMonad m a = MyMonad { runMyMonad :: m a }
+ deriving (Functor, Applicative, Monad)
+
+instance Monad m => MonadFoo (MyMonad m) where
+ type FooM (MyMonad m) = MyMonad m
+ runFoo = id
+
+main :: IO ()
+main = runMyMonad foo
+
+foo :: MonadFoo m => m ()
+foo = runFoo $ return ()
diff --git a/testsuite/tests/simplCore/should_run/all.T b/testsuite/tests/simplCore/should_run/all.T
index 509ae1ff57..bebd839724 100644
--- a/testsuite/tests/simplCore/should_run/all.T
+++ b/testsuite/tests/simplCore/should_run/all.T
@@ -105,3 +105,4 @@ test('T19313', normal, compile_and_run, [''])
test('UnliftedArgRule', normal, compile_and_run, [''])
test('T21229', normal, compile_and_run, ['-O'])
test('T21575', normal, compile_and_run, ['-O'])
+test('T20836', normal, compile_and_run, ['-O0']) # Should not time out; See #20836