From 66ff794fedf6e81e727dc8f651e63afe6f2a874b Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Tue, 9 Jan 2018 13:53:09 +0000 Subject: Fix join-point decision This patch moves the "ok_unfolding" test from CoreOpt.joinPointBinding_maybe to OccurAnal.decideJoinPointHood Previously the occurrence analyser was deciding to make something a join point, but the simplifier was reversing that decision, which made the decision about /other/ bindings invalid. Fixes Trac #14650. --- compiler/coreSyn/CoreOpt.hs | 44 +------------ compiler/simplCore/OccurAnal.hs | 68 +++++++++++++++---- testsuite/tests/simplCore/should_compile/T14650.hs | 76 ++++++++++++++++++++++ testsuite/tests/simplCore/should_compile/all.T | 1 + 4 files changed, 136 insertions(+), 53 deletions(-) create mode 100644 testsuite/tests/simplCore/should_compile/T14650.hs diff --git a/compiler/coreSyn/CoreOpt.hs b/compiler/coreSyn/CoreOpt.hs index 4240647d58..0f35e8f3ac 100644 --- a/compiler/coreSyn/CoreOpt.hs +++ b/compiler/coreSyn/CoreOpt.hs @@ -22,7 +22,7 @@ module CoreOpt ( import GhcPrelude -import CoreArity( joinRhsArity, etaExpandToJoinPoint ) +import CoreArity( etaExpandToJoinPoint ) import CoreSyn import CoreSubst @@ -646,58 +646,18 @@ joinPointBinding_maybe bndr rhs = Just (bndr, rhs) | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) - , not (bad_unfolding join_arity (idUnfolding bndr)) , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs = Just (bndr `asJoinId` join_arity, mkLams bndrs body) | otherwise = Nothing - where - -- bad_unfolding returns True if we should /not/ convert a non-join-id - -- into a join-id, even though it is AlwaysTailCalled - -- See Note [Join points and INLINE pragmas] - bad_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs }) - = isStableSource src && join_arity > joinRhsArity rhs - bad_unfolding _ (DFunUnfolding {}) - = True - bad_unfolding _ _ - = False - joinPointBindings_maybe :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)] 
joinPointBindings_maybe bndrs = mapM (uncurry joinPointBinding_maybe) bndrs -{- Note [Join points and INLINE pragmas] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Consider - f x = let g = \x. not -- Arity 1 - {-# INLINE g #-} - in case x of - A -> g True True - B -> g True False - C -> blah2 - -Here 'g' is always tail-called applied to 2 args, but the stable -unfolding captured by the INLINE pragma has arity 1. If we try to -convert g to be a join point, its unfolding will still have arity 1 -(since it is stable, and we don't meddle with stable unfoldings), and -Lint will complain (see Note [Invariants on join points], (2a), in -CoreSyn. Trac #13413. - -Moreover, since g is going to be inlined anyway, there is no benefit -from making it a join point. - -If it is recursive, and uselessly marked INLINE, this will stop us -making it a join point, which is annoying. But occasionally -(notably in class methods; see Note [Instances and loop breakers] in -TcInstDcls) we mark recursive things as INLINE but the recursion -unravels; so ignoring INLINE pragmas on recursive things isn't good -either. - - -************************************************************************ +{- ********************************************************************* * * exprIsConApp_maybe * * diff --git a/compiler/simplCore/OccurAnal.hs b/compiler/simplCore/OccurAnal.hs index bcc84100a1..b0987d5da0 100644 --- a/compiler/simplCore/OccurAnal.hs +++ b/compiler/simplCore/OccurAnal.hs @@ -25,6 +25,7 @@ import CoreSyn import CoreFVs import CoreUtils ( exprIsTrivial, isDefaultAlt, isExpandableApp, stripTicksTopE, mkTicks ) +import CoreArity ( joinRhsArity ) import Id import IdInfo import Name( localiseName ) @@ -2664,9 +2665,8 @@ tagRecBinders lvl body_uds triples , AlwaysTailCalled arity <- tailCallInfo occ = Just arity | otherwise - = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if we're - -- making join points! 
- Nothing + = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if + Nothing -- we are making join points! -- 3. Compute final usage details from adjusted RHS details adj_uds = body_uds +++ combineUsageDetailsList rhs_udss' @@ -2694,10 +2694,15 @@ setBinderOcc occ_info bndr -- | Decide whether some bindings should be made into join points or not. -- Returns `False` if they can't be join points. Note that it's an --- all-or-nothing decision, as if multiple binders are given, they're assumed to --- be mutually recursive. +-- all-or-nothing decision, as if multiple binders are given, they're +-- assumed to be mutually recursive. -- --- See Note [Invariants for join points] in CoreSyn. +-- It must, however, be a final decision. If we say "True" for 'f', +-- and then subsequently decide /not/ to make 'f' into a join point, then +-- the decision about another binding 'g' might be invalidated if (say) +-- 'f' tail-calls 'g'. +-- +-- See Note [Invariants on join points] in CoreSyn. decideJoinPointHood :: TopLevelFlag -> UsageDetails -> [CoreBndr] -> Bool @@ -2721,6 +2726,9 @@ decideJoinPointHood NotTopLevel usage bndrs AlwaysTailCalled arity <- tailCallInfo (lookupDetails usage bndr) , -- Invariant 1 as applied to LHSes of rules all (ok_rule arity) (idCoreRules bndr) + -- Invariant 2a: stable unfoldings + -- See Note [Join points and INLINE pragmas] + , ok_unfolding arity (realIdUnfolding bndr) -- Invariant 4: Satisfies polymorphism rule , isValidJoinPointType arity (idType bndr) = True @@ -2732,14 +2740,52 @@ decideJoinPointHood NotTopLevel usage bndrs = args `lengthIs` join_arity -- Invariant 1 as applied to LHSes of rules + -- ok_unfolding returns False if we should /not/ convert a non-join-id + -- into a join-id, even though it is AlwaysTailCalled + ok_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs }) + = not (isStableSource src && join_arity > joinRhsArity rhs) + ok_unfolding _ (DFunUnfolding {}) + = False + ok_unfolding _ _ + = True + 
willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity willBeJoinId_maybe bndr - | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr) - = Just arity - | otherwise - = isJoinId_maybe bndr + = case tailCallInfo (idOccInfo bndr) of + AlwaysTailCalled arity -> Just arity + _ -> isJoinId_maybe bndr + + +{- Note [Join points and INLINE pragmas] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider + f x = let g = \x. not -- Arity 1 + {-# INLINE g #-} + in case x of + A -> g True True + B -> g True False + C -> blah2 + +Here 'g' is always tail-called applied to 2 args, but the stable +unfolding captured by the INLINE pragma has arity 1. If we try to +convert g to be a join point, its unfolding will still have arity 1 +(since it is stable, and we don't meddle with stable unfoldings), and +Lint will complain (see Note [Invariants on join points], (2a), in +CoreSyn). Trac #13413. + +Moreover, since g is going to be inlined anyway, there is no benefit +from making it a join point. + +If it is recursive, and uselessly marked INLINE, this will stop us +making it a join point, which is annoying. But occasionally +(notably in class methods; see Note [Instances and loop breakers] in +TcInstDcls) we mark recursive things as INLINE but the recursion +unravels; so ignoring INLINE pragmas on recursive things isn't good +either. 
+ +See Invariant 2a of Note [Invariants on join points] in CoreSyn + -{- ************************************************************************ * * \subsection{Operations over OccInfo} diff --git a/testsuite/tests/simplCore/should_compile/T14650.hs b/testsuite/tests/simplCore/should_compile/T14650.hs new file mode 100644 index 0000000000..b9eac20021 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T14650.hs @@ -0,0 +1,76 @@ +module MergeSort ( + msortBy + ) where + +infixl 7 :% +infixr 6 :& + +data LenList a = LL {-# UNPACK #-} !Int Bool [a] + +data LenListAnd a b = {-# UNPACK #-} !(LenList a) :% b + +data Stack a + = End + | {-# UNPACK #-} !(LenList a) :& (Stack a) + +msortBy :: (a -> a -> Ordering) -> [a] -> [a] +msortBy cmp = mergeSplit End where + splitAsc n _ _ _ | n `seq` False = undefined + splitAsc n as _ [] = LL n True as :% [] + splitAsc n as a bs@(b:bs') = case cmp a b of + GT -> LL n False as :% bs + _ -> splitAsc (n + 1) as b bs' + + splitDesc n _ _ _ | n `seq` False = undefined + splitDesc n rs a [] = LL n True (a:rs) :% [] + splitDesc n rs a bs@(b:bs') = case cmp a b of + GT -> splitDesc (n + 1) (a:rs) b bs' + _ -> LL n True (a:rs) :% bs + + mergeLL (LL na fa as) (LL nb fb bs) = LL (na + nb) True $ mergeLs na as nb bs where + mergeLs nx _ ny _ | nx `seq` ny `seq` False = undefined + mergeLs 0 _ ny ys = if fb then ys else take ny ys + mergeLs _ [] ny ys = if fb then ys else take ny ys + mergeLs nx xs 0 _ = if fa then xs else take nx xs + mergeLs nx xs _ [] = if fa then xs else take nx xs + mergeLs nx xs@(x:xs') ny ys@(y:ys') = case cmp x y of + GT -> y:mergeLs nx xs (ny - 1) ys' + _ -> x:mergeLs (nx - 1) xs' ny ys + + push ssx px@(LL nx _ _) = case ssx of + End -> px :% ssx + py@(LL ny _ _) :& ssy -> case ssy of + End + | nx >= ny -> mergeLL py px :% ssy + pz@(LL nz _ _) :& ssz + | nx >= ny || nx + ny >= nz -> case nx > nz of + False -> push ssy $ mergeLL py px + _ -> case push ssz $ mergeLL pz py of + pz' :% ssz' -> push (pz' :& ssz') 
px + _ -> px :% ssx + + mergeAll _ px | px `seq` False = undefined + mergeAll ssx px@(LL nx _ xs) = case ssx of + End -> xs + py@(LL _ _ _) :& ssy -> case ssy of + End -> case mergeLL py px of + LL _ _ xys -> xys + pz@(LL nz _ _) :& ssz -> case nx > nz of + False -> mergeAll ssy $ mergeLL py px + _ -> case push ssz $ mergeLL pz py of + pz' :% ssz' -> mergeAll (pz' :& ssz') px + + mergeSplit ss _ | ss `seq` False = undefined + mergeSplit ss [] = case ss of + End -> [] + px :& ss' -> mergeAll ss' px + mergeSplit ss as@(a:as') = case as' of + [] -> mergeAll ss $ LL 1 True as + b:bs -> case cmp a b of + GT -> case splitDesc 2 [a] b bs of + px :% rs -> case push ss px of + px' :% ss' -> mergeSplit (px' :& ss') rs + _ -> case splitAsc 2 as b bs of + px :% rs -> case push ss px of + px' :% ss' -> mergeSplit (px' :& ss') rs + {-# INLINABLE mergeSplit #-} diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index e51e8f7db4..e681ca7363 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -289,3 +289,4 @@ test('T14152a', [extra_files(['T14152.hs']), pre_cmd('cp T14152.hs T14152a.hs'), only_ways(['optasm']), check_errmsg(r'dead code') ], compile, ['-fno-exitification -ddump-simpl']) test('T13990', normal, compile, ['-dcore-lint -O']) +test('T14650', normal, compile, ['-O2']) -- cgit v1.2.1