summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2021-05-04 23:41:48 +0100
committerSimon Peyton Jones <simonpj@microsoft.com>2021-05-06 22:36:09 +0100
commit9a050c7774c67cb3bfb72bbd0d609280564d14ec (patch)
treef8828adc3ef4404440378d4574bb126e46af8ffa
parent30f6923a834ccaca30c3622a0a82421fabcab119 (diff)
downloadhaskell-wip/T19780.tar.gz
Fix strictness and arity info in SpecConstr (wip/T19780)
In GHC.Core.Opt.SpecConstr.spec_one we were giving join-points an incorrect join-arity -- this was fallout from commit c71b220491a6ae46924cc5011b80182bcc773a58 Author: Simon Peyton Jones <simonpj@microsoft.com> Date: Thu Apr 8 23:36:24 2021 +0100 Improvements in SpecConstr * Allow under-saturated calls to specialise See Note [SpecConstr call patterns] This just allows a bit more specialisation to take place. and showed up in #19780. I refactored the code to make the new function calcSpecInfo which treats join points separately. In doing this I discovered two other small bugs: * In the Var case of argToPat we were treating UnkOcc as uninteresting, but (by omission) NoOcc as interesting. As a result we were generating SpecConstr specialisations for functions with unused arguments. But the absence analyser does that much better; doing it here just generates more code. Easily fixed. * The lifted/unlifted test in GHC.Core.Opt.WorkWrap.Utils.mkWorkerArgs was back to front (#19794). Easily fixed. * In the same function, mkWorkerArgs, we were adding an extra argument to nullary join points, which isn't necessary. I added a test for this. That in turn meant I had to remove an ASSERT in CoreToStg.mkStgRhs for nullary join points, which was always bogus but now trips; I added a comment to explain.
-rw-r--r--compiler/GHC/Core/Opt/SpecConstr.hs98
-rw-r--r--compiler/GHC/Core/Opt/WorkWrap/Utils.hs47
-rw-r--r--compiler/GHC/CoreToStg.hs6
-rw-r--r--testsuite/tests/simplCore/should_compile/T18328.stderr30
-rw-r--r--testsuite/tests/simplCore/should_compile/T19780.hs100
-rw-r--r--testsuite/tests/simplCore/should_compile/T19794.hs8
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T3
7 files changed, 215 insertions, 77 deletions
diff --git a/compiler/GHC/Core/Opt/SpecConstr.hs b/compiler/GHC/Core/Opt/SpecConstr.hs
index 89d5e9fd22..7509a4cda3 100644
--- a/compiler/GHC/Core/Opt/SpecConstr.hs
+++ b/compiler/GHC/Core/Opt/SpecConstr.hs
@@ -1719,30 +1719,27 @@ spec_one env fn arg_bndrs body (call_pat, rule_number)
-- And build the results
; let spec_body_ty = exprType spec_body
- spec_lam_args1 = qvars ++ extra_bndrs
- (spec_lam_args, spec_call_args) = mkWorkerArgs False
- spec_lam_args1 spec_body_ty
+
+ (spec_lam_args1, spec_sig, spec_arity, spec_join_arity)
+ = calcSpecInfo fn call_pat extra_bndrs
+ -- Annotate the variables with the strictness information from
+ -- the function (see Note [Strictness information in worker binders])
+
+ (spec_lam_args, spec_call_args) = mkWorkerArgs fn False
+ spec_lam_args1 spec_body_ty
-- mkWorkerArgs: usual w/w hack to avoid generating
-- a spec_rhs of unlifted type and no args
- spec_str = calcSpecStrictness fn spec_lam_args pats
- spec_lam_args_str = handOutStrictnessInformation spec_str spec_lam_args
- -- Annotate the variables with the strictness information from
- -- the function (see Note [Strictness information in worker binders])
-
- spec_join_arity | isJoinId fn = Just (length spec_lam_args)
- | otherwise = Nothing
spec_id = mkLocalId spec_name Many
(mkLamTypes spec_lam_args spec_body_ty)
-- See Note [Transfer strictness]
- `setIdDmdSig` spec_str
- `setIdCprSig` topCprSig
- `setIdArity` count isId spec_lam_args
+ `setIdDmdSig` spec_sig
+ `setIdCprSig` topCprSig
+ `setIdArity` spec_arity
`asJoinId_maybe` spec_join_arity
-
-- Conditionally use result of new worker-wrapper transform
- spec_rhs = mkLams spec_lam_args_str spec_body
+ spec_rhs = mkLams spec_lam_args spec_body
rule_rhs = mkVarApps (Var spec_id) $
dropTail (length extra_bndrs) spec_call_args
inline_act = idInlineActivation fn
@@ -1755,31 +1752,46 @@ spec_one env fn arg_bndrs body (call_pat, rule_number)
, os_rhs = spec_rhs }) }
--- See Note [Strictness information in worker binders]
-handOutStrictnessInformation :: DmdSig -> [Var] -> [Var]
-handOutStrictnessInformation str vs
- = go (fst (splitDmdSig str)) vs
- where
- go _ [] = []
- go [] vs = vs
- go (d:dmds) (v:vs) | isId v = setIdDemandInfo v d : go dmds vs
- go dmds (v:vs) = v : go dmds vs
-
-calcSpecStrictness :: Id -- The original function
- -> [Var] -> [CoreExpr] -- Call pattern
- -> DmdSig -- Strictness of specialised thing
+calcSpecInfo :: Id -- The original function
+ -> CallPat -- Call pattern
+ -> [Var] -- Extra bndrs
+ -> ( [Var] -- Demand-decorated binders
+ , DmdSig -- Strictness of specialised thing
+ , Arity, Maybe JoinArity ) -- Arities of specialised thing
+-- Calculate bits of IdInfo for the specialised function
-- See Note [Transfer strictness]
-calcSpecStrictness fn qvars pats
- = mkClosedDmdSig spec_dmds div
+-- See Note [Strictness information in worker binders]
+calcSpecInfo fn (CP { cp_qvars = qvars, cp_args = pats }) extra_bndrs
+ | isJoinId fn -- Join points have strictness and arity for LHS only
+ = ( bndrs_w_dmds
+ , mkClosedDmdSig qvar_dmds div
+ , count isId qvars
+ , Just (length qvars) )
+ | otherwise
+ = ( bndrs_w_dmds
+ , mkClosedDmdSig (qvar_dmds ++ extra_dmds) div
+ , count isId qvars + count isId extra_bndrs
+ , Nothing )
where
- spec_dmds = [ lookupVarEnv dmd_env qv `orElse` topDmd | qv <- qvars, isId qv ]
- DmdSig (DmdType _ dmds div) = idDmdSig fn
+ DmdSig (DmdType _ fn_dmds div) = idDmdSig fn
+
+ val_pats = filterOut isTypeArg pats
+ qvar_dmds = [ lookupVarEnv dmd_env qv `orElse` topDmd | qv <- qvars, isId qv ]
+ extra_dmds = dropList val_pats fn_dmds
+
+ bndrs_w_dmds = set_dmds qvars qvar_dmds
+ ++ set_dmds extra_bndrs extra_dmds
+
+ set_dmds :: [Var] -> [Demand] -> [Var]
+ set_dmds [] _ = []
+ set_dmds vs [] = vs -- Run out of demands
+ set_dmds (v:vs) ds@(d:ds') | isTyVar v = v : set_dmds vs ds
+ | otherwise = setIdDemandInfo v d : set_dmds vs ds'
- dmd_env = go emptyVarEnv dmds pats
+ dmd_env = go emptyVarEnv fn_dmds val_pats
go :: DmdEnv -> [Demand] -> [CoreExpr] -> DmdEnv
- go env ds (Type {} : pats) = go env ds pats
- go env ds (Coercion {} : pats) = go env ds pats
+ -- We've filtered out all the type patterns already
go env (d:ds) (pat : pats) = go (go_one env d pat) ds pats
go env _ _ = env
@@ -1789,7 +1801,8 @@ calcSpecStrictness fn qvars pats
| (Var _, args) <- collectArgs e
, Just ds <- viewProd (length args) cd
= go env ds args
- go_one env _ _ = env
+ go_one env _ _ = env
+
{-
Note [spec_usg includes rhs_usg]
@@ -1847,13 +1860,13 @@ The function calcSpecStrictness performs the calculation.
Note [Strictness information in worker binders]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
After having calculated the strictness annotation for the worker (see Note
[Transfer strictness] above), we also want to have this information attached to
the worker’s arguments, for the benefit of later passes. The function
handOutStrictnessInformation decomposes the strictness annotation calculated by
calcSpecStrictness and attaches them to the variables.
+
************************************************************************
* *
\subsection{Argument analysis}
@@ -2269,15 +2282,20 @@ argToPat env in_scope val_env arg arg_occ
-- Check if the argument is a variable that
-- (a) is used in an interesting way in the function body
+ --- i.e. ScrutOcc. UnkOcc and NoOcc are not interesting
+ -- (NoOcc means we could drop the argument, but that's the
+ -- business of absence analysis, not SpecConstr.)
-- (b) we know what its value is
-- In that case it counts as "interesting"
argToPat env in_scope val_env (Var v) arg_occ
- | sc_force env || case arg_occ of { UnkOcc -> False; _other -> True }, -- (a)
- is_value, -- (b)
+ | sc_force env || case arg_occ of { ScrutOcc {} -> True
+ ; UnkOcc -> False
+ ; NoOcc -> False } -- (a)
+ , is_value -- (b)
-- Ignoring sc_keen here to avoid gratuitously incurring Note [Reboxing]
-- So sc_keen focused just on f (I# x), where we have freshly-allocated
-- box that we can eliminate in the caller
- not (ignoreType env (varType v))
+ , not (ignoreType env (varType v))
= return (True, Var v)
where
is_value
diff --git a/compiler/GHC/Core/Opt/WorkWrap/Utils.hs b/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
index 4ef35e9b83..5bd7bdf263 100644
--- a/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
+++ b/compiler/GHC/Core/Opt/WorkWrap/Utils.hs
@@ -189,7 +189,8 @@ mkWwBodies opts rhs_fvs fun_id demands cpr_info
; (useful2, wrap_fn_cpr, work_fn_cpr, cpr_res_ty)
<- mkWWcpr_entry opts res_ty cpr_info
- ; let (work_lam_args, work_call_args) = mkWorkerArgs (wo_fun_to_thunk opts) work_args cpr_res_ty
+ ; let (work_lam_args, work_call_args) = mkWorkerArgs fun_id (wo_fun_to_thunk opts)
+ work_args cpr_res_ty
worker_args_dmds = [idDemandInfo v | v <- work_call_args, isId v]
wrapper_body = wrap_fn_args . wrap_fn_cpr . wrap_fn_str . applyToVars work_call_args . Var
worker_body = mkLams work_lam_args. work_fn_str . work_fn_cpr . work_fn_args
@@ -302,31 +303,39 @@ add a void argument. E.g.
We use the state-token type which generates no code.
-}
-mkWorkerArgs :: Bool
+mkWorkerArgs :: Id -- The wrapper Id
+ -> Bool
-> [Var]
-> Type -- Type of body
-> ([Var], -- Lambda bound args
[Var]) -- Args at call site
-mkWorkerArgs fun_to_thunk args res_ty
- | any isId args || not needsAValueLambda
- = (args, args)
- | otherwise
+mkWorkerArgs wrap_id fun_to_thunk args res_ty
+ | not (isJoinId wrap_id) -- Join Ids never need an extra arg
+ , not (any isId args) -- No existing value lambdas
+ , needs_a_value_lambda -- and we need to add one
= (args ++ [voidArgId], args ++ [voidPrimId])
+
+ | otherwise
+ = (args, args)
where
- -- See "Making wrapper args" section above
- needsAValueLambda =
- lifted
- -- We may encounter a levity-polymorphic result, in which case we
- -- conservatively assume that we have laziness that needs preservation.
- -- See #15186.
- || not fun_to_thunk
- -- see Note [Protecting the last value argument]
+ -- If fun_to_thunk is False we always keep at least one value
+ -- argument: see Note [Protecting the last value argument]
+ -- If it is True, we only need to keep a value argument if
+ -- the result type is (or might be) unlifted, in which case
+ -- dropping the last arg would mean we wrongly used call-by-value
+ needs_a_value_lambda
+ = not fun_to_thunk
+ || might_be_unlifted
-- Might the result be lifted?
- lifted =
- case isLiftedType_maybe res_ty of
- Just lifted -> lifted
- Nothing -> True
+ -- False => definitely lifted
+ -- True => might be unlifted
+ -- We may encounter a levity-polymorphic result, in which case we
+ -- conservatively assume that we have laziness that needs
+ -- preservation. See #15186.
+ might_be_unlifted = case isLiftedType_maybe res_ty of
+ Just lifted -> not lifted
+ Nothing -> True
{-
Note [Protecting the last value argument]
@@ -344,7 +353,6 @@ so f can't be inlined *under a lambda*.
Note [Join points and beta-redexes]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
Originally, the worker would invoke the original function by calling it with
arguments, thus producing a beta-redex for the simplifier to munch away:
@@ -375,7 +383,6 @@ worry about hygiene, but luckily wy is freshly generated.)
Note [Join points returning functions]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
It is crucial that the arity of a join point depends on its *callers,* not its
own syntax. What this means is that a join point can have "extra lambdas":
diff --git a/compiler/GHC/CoreToStg.hs b/compiler/GHC/CoreToStg.hs
index 9452015ab4..af8c8ae25b 100644
--- a/compiler/GHC/CoreToStg.hs
+++ b/compiler/GHC/CoreToStg.hs
@@ -771,8 +771,10 @@ mkStgRhs bndr (PreStgRhs bndrs rhs)
-- After this point we know that `bndrs` is empty,
-- so this is not a function binding
- | isJoinId bndr -- must be a nullary join point
- = ASSERT(idJoinArity bndr == 0)
+
+ | isJoinId bndr -- Must be a nullary join point
+ = -- It might have /type/ arguments (T18328),
+ -- so its JoinArity might be >0
StgRhsClosure noExtFieldSilent
currentCCS
ReEntrant -- ignored for LNE
diff --git a/testsuite/tests/simplCore/should_compile/T18328.stderr b/testsuite/tests/simplCore/should_compile/T18328.stderr
index d32f553114..78e3430b88 100644
--- a/testsuite/tests/simplCore/should_compile/T18328.stderr
+++ b/testsuite/tests/simplCore/should_compile/T18328.stderr
@@ -1,38 +1,38 @@
==================== Tidy Core ====================
Result size of Tidy Core
- = {terms: 69, types: 61, coercions: 0, joins: 1/1}
+ = {terms: 65, types: 53, coercions: 0, joins: 1/1}
--- RHS size: {terms: 42, types: 28, coercions: 0, joins: 1/1}
+-- RHS size: {terms: 38, types: 23, coercions: 0, joins: 1/1}
T18328.$wf [InlPrag=[2]]
:: forall {a}. GHC.Prim.Int# -> [a] -> [a] -> [a]
[GblId,
Arity=3,
- Str=<SU><U><U>,
+ Str=<SL><SL><ML>,
Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
- WorkFree=True, Expandable=True, Guidance=IF_ARGS [182 0 0] 312 0}]
+ WorkFree=True, Expandable=True, Guidance=IF_ARGS [176 0 0] 306 0}]
T18328.$wf
= \ (@a) (ww :: GHC.Prim.Int#) (w :: [a]) (w1 :: [a]) ->
join {
- $wj [InlPrag=NOINLINE, Dmd=1C1(U)] :: forall {p}. (# #) -> [a]
- [LclId[JoinId(2)], Arity=1, Str=<A>, Unf=OtherCon []]
- $wj (@p) _ [Occ=Dead, OS=OneShot]
+ $wj [InlPrag=NOINLINE, Dmd=ML] :: forall {p}. [a]
+ [LclId[JoinId(1)]]
+ $wj (@p)
= case ww of {
__DEFAULT -> ++ @a w (++ @a w (++ @a w w1));
3# -> ++ @a w (++ @a w (++ @a w (++ @a w w1)))
} } in
case ww of {
__DEFAULT -> ++ @a w w1;
- 1# -> jump $wj @Integer GHC.Prim.(##);
- 2# -> jump $wj @Integer GHC.Prim.(##);
- 3# -> jump $wj @Integer GHC.Prim.(##)
+ 1# -> jump $wj @Integer;
+ 2# -> jump $wj @Integer;
+ 3# -> jump $wj @Integer
}
--- RHS size: {terms: 11, types: 10, coercions: 0, joins: 0/0}
+-- RHS size: {terms: 11, types: 9, coercions: 0, joins: 0/0}
f [InlPrag=[2]] :: forall a. Int -> [a] -> [a] -> [a]
[GblId,
Arity=3,
- Str=<S(SU)><U><U>,
+ Str=<1P(SL)><SL><ML>,
Unf=Unf{Src=InlineStable, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True,
Guidance=ALWAYS_IF(arity=3,unsat_ok=True,boring_ok=False)
@@ -40,11 +40,11 @@ f [InlPrag=[2]] :: forall a. Int -> [a] -> [a] -> [a]
(w [Occ=Once1!] :: Int)
(w1 [Occ=Once1] :: [a])
(w2 [Occ=Once1] :: [a]) ->
- case w of { GHC.Types.I# ww1 [Occ=Once1] ->
- T18328.$wf @a ww1 w1 w2
+ case w of { GHC.Types.I# ww [Occ=Once1] ->
+ T18328.$wf @a ww w1 w2
}}]
f = \ (@a) (w :: Int) (w1 :: [a]) (w2 :: [a]) ->
- case w of { GHC.Types.I# ww1 -> T18328.$wf @a ww1 w1 w2 }
+ case w of { GHC.Types.I# ww -> T18328.$wf @a ww w1 w2 }
-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
T18328.$trModule4 :: GHC.Prim.Addr#
diff --git a/testsuite/tests/simplCore/should_compile/T19780.hs b/testsuite/tests/simplCore/should_compile/T19780.hs
new file mode 100644
index 0000000000..5acc896f60
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T19780.hs
@@ -0,0 +1,100 @@
+{-# LANGUAGE BangPatterns #-}
+module Data.ByteString.Search.DFA (strictSearcher) where
+
+import qualified Data.ByteString as S
+import Data.ByteString.Unsafe (unsafeIndex)
+
+import Control.Monad (when)
+import Data.Array.Base (unsafeRead, unsafeWrite, unsafeAt)
+import Data.Array.ST (newArray, newArray_, runSTUArray)
+import Data.Array.Unboxed (UArray)
+import Data.Bits (Bits(..))
+import Data.Word (Word8)
+
+------------------------------------------------------------------------------
+-- Searching Function --
+------------------------------------------------------------------------------
+
+strictSearcher :: Bool -> S.ByteString -> S.ByteString -> [Int]
+strictSearcher _ !pat
+ | S.null pat = enumFromTo 0 . S.length
+ | S.length pat == 1 = let !w = S.head pat in S.elemIndices w
+strictSearcher !overlap pat = search
+ where
+ !patLen = S.length pat
+ !auto = automaton pat
+ !p0 = unsafeIndex pat 0
+ !ams = if overlap then patLen else 0
+ search str = match 0 0
+ where
+ !strLen = S.length str
+ {-# INLINE strAt #-}
+ strAt :: Int -> Int
+ strAt !i = fromIntegral (unsafeIndex str i)
+ match 0 idx
+ | idx == strLen = []
+ | unsafeIndex str idx == p0 = match 1 (idx + 1)
+ | otherwise = match 0 (idx + 1)
+ match state idx
+ | idx == strLen = []
+ | otherwise =
+ let !nstate = unsafeAt auto ((state `shiftL` 8) + strAt idx)
+ !nxtIdx = idx + 1
+ in if nstate == patLen
+ then (nxtIdx - patLen) : match ams nxtIdx
+ else match nstate nxtIdx
+
+------------------------------------------------------------------------------
+-- Preprocessing --
+------------------------------------------------------------------------------
+
+{-# INLINE automaton #-}
+automaton :: S.ByteString -> UArray Int Int
+automaton !pat = runSTUArray (do
+ let !patLen = S.length pat
+ {-# INLINE patAt #-}
+ patAt !i = fromIntegral (unsafeIndex pat i)
+ !bord = kmpBorders pat
+ aut <- newArray (0, (patLen + 1)*256 - 1) 0
+ unsafeWrite aut (patAt 0) 1
+ let loop !state = do
+ let !base = state `shiftL` 8
+ inner j
+ | j < 0 = if state == patLen
+ then return aut
+ else loop (state+1)
+ | otherwise = do
+ let !i = base + patAt j
+ s <- unsafeRead aut i
+ when (s == 0) (unsafeWrite aut i (j+1))
+ inner (unsafeAt bord j)
+ if state == patLen
+ then inner (unsafeAt bord state)
+ else inner state
+ loop 1)
+
+-- kmpBorders calculates the width of the widest borders of the prefixes
+-- of the pattern which are not extensible to borders of the next
+-- longer prefix. Most entries will be 0.
+{-# INLINE kmpBorders #-}
+kmpBorders :: S.ByteString -> UArray Int Int
+kmpBorders pat = runSTUArray (do
+ let !patLen = S.length pat
+ {-# INLINE patAt #-}
+ patAt :: Int -> Word8
+ patAt i = unsafeIndex pat i
+ ar <- newArray_ (0, patLen)
+ unsafeWrite ar 0 (-1)
+ let dec w j
+ | j < 0 || w == patAt j = return $! j+1
+ | otherwise = unsafeRead ar j >>= dec w
+ bordLoop !i !j
+ | patLen < i = return ar
+ | otherwise = do
+ let !w = patAt (i-1)
+ j' <- dec w j
+ if i < patLen && patAt j' == patAt i
+ then unsafeRead ar j' >>= unsafeWrite ar i
+ else unsafeWrite ar i j'
+ bordLoop (i+1) j'
+ bordLoop 1 (-1))
diff --git a/testsuite/tests/simplCore/should_compile/T19794.hs b/testsuite/tests/simplCore/should_compile/T19794.hs
new file mode 100644
index 0000000000..c8f6897468
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T19794.hs
@@ -0,0 +1,8 @@
+{-# LANGUAGE MagicHash #-}
+{-# OPTIONS_GHC -ffun-to-thunk #-} -- This is essential for the test
+
+module Foo where
+import GHC.Exts
+
+f :: Int -> Int#
+f x = f (x+1)
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index dba67fa80b..6a857e1115 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -358,3 +358,6 @@ test('T13873', [ grep_errmsg(r'SPEC') ], compile, ['-O -ddump-rules'])
# Look for a specialisation rule for wimwam
test('T19672', normal, compile, ['-O2 -ddump-rules'])
+
+test('T19780', normal, compile, ['-O2'])
+test('T19794', normal, compile, ['-O'])