From b5b7d820afd8fca098bf1f4a7380d425ca6be31d Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Mon, 10 Apr 2017 08:51:49 +0100 Subject: Improve demand analysis for join points I realised (Trac #13543) that we can improve demand analysis for join point quite straightforwardly. The idea is explained in Note [Demand analysis for join points] in DmdAnal --- compiler/stranal/DmdAnal.hs | 82 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 20 deletions(-) (limited to 'compiler/stranal') diff --git a/compiler/stranal/DmdAnal.hs b/compiler/stranal/DmdAnal.hs index 304a2becb3..78eefe39a1 100644 --- a/compiler/stranal/DmdAnal.hs +++ b/compiler/stranal/DmdAnal.hs @@ -64,20 +64,20 @@ dmdAnalProgram dflags fam_envs binds dmdAnalTopBind :: AnalEnv -> CoreBind -> (AnalEnv, CoreBind) -dmdAnalTopBind sigs (NonRec id rhs) - = (extendAnalEnv TopLevel sigs id2 (idStrictness id2), NonRec id2 rhs2) +dmdAnalTopBind env (NonRec id rhs) + = (extendAnalEnv TopLevel env id2 (idStrictness id2), NonRec id2 rhs2) where - ( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing sigs id rhs - ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin sigs) id rhs1 + ( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing env cleanEvalDmd id rhs + ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin env) cleanEvalDmd id rhs1 -- Do two passes to improve CPR information -- See Note [CPR for thunks] -- See Note [Optimistic CPR in the "virgin" case] -- See Note [Initial CPR for strict binders] -dmdAnalTopBind sigs (Rec pairs) - = (sigs', Rec pairs') +dmdAnalTopBind env (Rec pairs) + = (env', Rec pairs') where - (sigs', _, pairs') = dmdFix TopLevel sigs pairs + (env', _, pairs') = dmdFix TopLevel env cleanEvalDmd pairs -- We get two iterations automatically -- c.f. 
the NonRec case above @@ -308,7 +308,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body) dmdAnal' env dmd (Let (NonRec id rhs) body) = (body_ty2, Let (NonRec id2 rhs') body') where - (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env id rhs + (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env dmd id rhs env1 = extendAnalEnv NotTopLevel env id1 (idStrictness id1) (body_ty, body') = dmdAnal env1 dmd body (body_ty1, id2) = annotateBndr env body_ty id1 @@ -329,7 +329,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body) dmdAnal' env dmd (Let (Rec pairs) body) = let - (env', lazy_fv, pairs') = dmdFix NotTopLevel env pairs + (env', lazy_fv, pairs') = dmdFix NotTopLevel env dmd pairs (body_ty, body') = dmdAnal env' dmd body body_ty1 = deleteFVs body_ty (map fst pairs) body_ty2 = addLazyFVs body_ty1 lazy_fv -- see Note [Lazy and unleasheable free variables] @@ -509,17 +509,17 @@ dmdTransform env var dmd -- Recursive bindings dmdFix :: TopLevelFlag -> AnalEnv -- Does not include bindings for this binding + -> CleanDemand -> [(Id,CoreExpr)] -> (AnalEnv, DmdEnv, [(Id,CoreExpr)]) -- Binders annotated with stricness info -dmdFix top_lvl env orig_pairs +dmdFix top_lvl env let_dmd orig_pairs = loop 1 initial_pairs where bndrs = map fst orig_pairs -- See Note [Initialising strictness] initial_pairs | ae_virgin env = [(setIdStrictness id botSig, rhs) | (id, rhs) <- orig_pairs ] - | otherwise = orig_pairs -- If fixed-point iteration does not yield a result we use this instead @@ -562,7 +562,7 @@ dmdFix top_lvl env orig_pairs my_downRhs (env, lazy_fv) (id,rhs) = ((env', lazy_fv'), (id', rhs')) where - (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env id rhs + (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env let_dmd id rhs lazy_fv' = plusVarEnv_C bothDmd lazy_fv lazy_fv1 env' = extendAnalEnv top_lvl env id (idStrictness id') @@ -621,18 +621,27 @@ dmdAnalTrivialRhs env id rhs fn -- This is the LetDown rule in the paper 
“Higher-Order Cardinality Analysis”. dmdAnalRhsLetDown :: TopLevelFlag -> Maybe [Id] -- Just bs <=> recursive, Nothing <=> non-recursive - -> AnalEnv -> Id -> CoreExpr + -> AnalEnv -> CleanDemand + -> Id -> CoreExpr -> (DmdEnv, Id, CoreExpr) -- Process the RHS of the binding, add the strictness signature -- to the Id, and augment the environment with the signature as well. -dmdAnalRhsLetDown top_lvl rec_flag env id rhs +dmdAnalRhsLetDown top_lvl rec_flag env let_dmd id rhs | Just fn <- unpackTrivial rhs -- See Note [Demand analysis for trivial right-hand sides] = dmdAnalTrivialRhs env id rhs fn | otherwise = (lazy_fv, id', mkLams bndrs' body') where - (bndrs, body) = collectBinders rhs + (bndrs, body, body_dmd) + = case isJoinId_maybe id of + Just join_arity -- See Note [Demand analysis for join points] + | (bndrs, body) <- collectNBinders join_arity rhs + -> (bndrs, body, let_dmd) + + Nothing | (bndrs, body) <- collectBinders rhs + -> (bndrs, body, mkBodyDmd env body) + env_body = foldl extendSigsWithLam env bndrs (body_ty, body') = dmdAnal env_body body_dmd body body_ty' = removeDmdTyArgs body_ty -- zap possible deep CPR info @@ -642,10 +651,6 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs id' = set_idStrictness env id sig_ty -- See Note [NOINLINE and strictness] - -- See Note [Product demands for function body] - body_dmd = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of - Nothing -> cleanEvalDmd - Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc) -- See Note [Aggregated demand for cardinality] rhs_fv1 = case rec_flag of @@ -667,6 +672,13 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs || not (isStrictDmd (idDemandInfo id) || ae_virgin env) -- See Note [Optimistic CPR in the "virgin" case] +mkBodyDmd :: AnalEnv -> CoreExpr -> CleanDemand +-- See Note [Product demands for function body] +mkBodyDmd env body + = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of + Nothing -> cleanEvalDmd + Just (dc, _, _, _) -> 
cleanEvalProdDmd (dataConRepArity dc) + unpackTrivial :: CoreExpr -> Maybe Id -- Returns (Just v) if the arg is really equal to v, modulo -- casts, type applications etc @@ -691,7 +703,37 @@ useLetUp _ (Lam _ _) = False useLetUp _ _ = True -{- +{- Note [Demand analysis for join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider + g :: (Int,Int) -> Int + g (p,q) = p+q + + f :: T -> Int -> Int + f x p = g (join j y = (p,y) + in case x of + A -> j 3 + B -> j 4 + C -> (p,7)) + +If j was a vanilla function definition, we'd analyse its body with +evalDmd, and think that it was lazy in p. But for join points we can +do better! We know that j's body will (if called at all) be evaluated +with the demand that consumes the entire join-binding, in this case +the argument demand from g. Whizzo! g evaluates both components of +its argument pair, so p will certainly be evaluated if j is called. + +For f to be strict in p, we need /all/ paths to evaluate p; in this +case the C branch does so too, so we are fine. So, as usual, we need +to transport demands on free variables to the call site(s). Compare +Note [Lazy and unleasheable free variables]. + +The implementation is easy. When analysing a join point, we can +analyse its body with the demand from the entire join-binding (written +let_dmd here). + +Another win for join points! Trac #13543. + Note [Demand analysis for trivial right-hand sides] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consider -- cgit v1.2.1