diff options
author | Andreas Klebinger <klebinger.andreas@gmx.at> | 2022-08-02 17:57:55 +0200 |
---|---|---|
committer | Andreas Klebinger <klebinger.andreas@gmx.at> | 2022-09-15 10:12:41 +0200 |
commit | d6ea8356b721ea4c3b871a796c1b2b13f94fd471 (patch) | |
tree | a72cb96f4a7a0d8992eb996ba39c66e916caa255 | |
parent | df04d6ec6a543d8bf1b953cf27c26e63ec6aab25 (diff) | |
download | haskell-d6ea8356b721ea4c3b871a796c1b2b13f94fd471.tar.gz |
Tag inference: Fix #21954 by retaining tagsigs of vars in function position.
For an expression like:
    case x of y
      Con z -> z
If we also retain the tag sig for z, we can generate code that returns it
immediately rather than calling out to stg_ap_0_fast.
-rw-r--r-- | compiler/GHC/Stg/InferTags/Rewrite.hs | 55 | ||||
-rw-r--r-- | testsuite/tests/simplStg/should_compile/all.T | 1 | ||||
-rw-r--r-- | testsuite/tests/simplStg/should_compile/inferTags002.hs | 7 | ||||
-rw-r--r-- | testsuite/tests/simplStg/should_compile/inferTags002.stderr | 171 |
4 files changed, 201 insertions, 33 deletions
diff --git a/compiler/GHC/Stg/InferTags/Rewrite.hs b/compiler/GHC/Stg/InferTags/Rewrite.hs index ba2bbf2449..253763cc5b 100644 --- a/compiler/GHC/Stg/InferTags/Rewrite.hs +++ b/compiler/GHC/Stg/InferTags/Rewrite.hs @@ -336,7 +336,7 @@ rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args) = {-# SCC rewriteRhs rewriteRhs _binding (StgRhsClosure fvs ccs flag args body) = do withBinders NotTopLevel args $ withClosureLcls fvs $ - StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr False body + StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr body -- return (closure) fvArgs :: [StgArg] -> RM DVarSet @@ -345,40 +345,36 @@ fvArgs args = do -- pprTraceM "fvArgs" (text "args:" <> ppr args $$ text "lcls:" <> pprVarSet (fv_lcls) (braces . fsep . map ppr) ) return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls] -type IsScrut = Bool - rewriteArgs :: [StgArg] -> RM [StgArg] rewriteArgs = mapM rewriteArg rewriteArg :: StgArg -> RM StgArg rewriteArg (StgVarArg v) = StgVarArg <$!> rewriteId v rewriteArg (lit@StgLitArg{}) = return lit --- Attach a tagSig if it's tagged rewriteId :: Id -> RM Id rewriteId v = do is_tagged <- isTagged v if is_tagged then return $! setIdTagSig v (TagSig TagProper) else return v -rewriteExpr :: IsScrut -> InferStgExpr -> RM TgStgExpr -rewriteExpr _ (e@StgCase {}) = rewriteCase e -rewriteExpr _ (e@StgLet {}) = rewriteLet e -rewriteExpr _ (e@StgLetNoEscape {}) = rewriteLetNoEscape e -rewriteExpr isScrut (StgTick t e) = StgTick t <$!> rewriteExpr isScrut e -rewriteExpr _ e@(StgConApp {}) = rewriteConApp e - -rewriteExpr isScrut e@(StgApp {}) = rewriteApp isScrut e -rewriteExpr _ (StgLit lit) = return $! 
(StgLit lit) -rewriteExpr _ (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do +rewriteExpr :: InferStgExpr -> RM TgStgExpr +rewriteExpr (e@StgCase {}) = rewriteCase e +rewriteExpr (e@StgLet {}) = rewriteLet e +rewriteExpr (e@StgLetNoEscape {}) = rewriteLetNoEscape e +rewriteExpr (StgTick t e) = StgTick t <$!> rewriteExpr e +rewriteExpr e@(StgConApp {}) = rewriteConApp e +rewriteExpr e@(StgApp {}) = rewriteApp e +rewriteExpr (StgLit lit) = return $! (StgLit lit) +rewriteExpr (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do (StgOpApp op) <$!> rewriteArgs args <*> pure res_ty -rewriteExpr _ (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty) +rewriteExpr (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty) rewriteCase :: InferStgExpr -> RM TgStgExpr rewriteCase (StgCase scrut bndr alt_type alts) = withBinder NotTopLevel bndr $ pure StgCase <*> - rewriteExpr True scrut <*> + rewriteExpr scrut <*> pure (fst bndr) <*> pure alt_type <*> mapM rewriteAlt alts @@ -388,7 +384,7 @@ rewriteCase _ = panic "Impossible: nodeCase" rewriteAlt :: InferStgAlt -> RM TgStgAlt rewriteAlt alt@GenStgAlt{alt_con=_, alt_bndrs=bndrs, alt_rhs=rhs} = withBinders NotTopLevel bndrs $ do - !rhs' <- rewriteExpr False rhs + !rhs' <- rewriteExpr rhs return $! alt {alt_bndrs = map fst bndrs, alt_rhs = rhs'} rewriteLet :: InferStgExpr -> RM TgStgExpr @@ -396,7 +392,7 @@ rewriteLet (StgLet xt bind expr) = do (!bind') <- rewriteBinds NotTopLevel bind withBind NotTopLevel bind $ do -- pprTraceM "withBindLet" (ppr $ bindersOfX bind) - !expr' <- rewriteExpr False expr + !expr' <- rewriteExpr expr return $! (StgLet xt bind' expr') rewriteLet _ = panic "Impossible" @@ -404,7 +400,7 @@ rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do (!bind') <- rewriteBinds NotTopLevel bind withBind NotTopLevel bind $ do - !expr' <- rewriteExpr False expr + !expr' <- rewriteExpr expr return $! 
(StgLetNoEscape xt bind' expr') rewriteLetNoEscape _ = panic "Impossible" @@ -424,19 +420,12 @@ rewriteConApp (StgConApp con cn args tys) = do rewriteConApp _ = panic "Impossible" --- Special case: Expressions like `case x of { ... }` -rewriteApp :: IsScrut -> InferStgExpr -> RM TgStgExpr -rewriteApp True (StgApp f []) = do - -- pprTraceM "rewriteAppScrut" (ppr f) - f_tagged <- isTagged f - -- isTagged looks at more than the result of our analysis. - -- So always update here if useful. - let f' = if f_tagged - -- TODO: We might consisder using a subst env instead of setting the sig only for select places. - then setIdTagSig f (TagSig TagProper) - else f +-- Special case: Atomic binders, usually in a case context like `case f of ...`. +rewriteApp :: InferStgExpr -> RM TgStgExpr +rewriteApp (StgApp f []) = do + f' <- rewriteId f return $! StgApp f' [] -rewriteApp _ (StgApp f args) +rewriteApp (StgApp f args) -- pprTrace "rewriteAppOther" (ppr f <+> ppr args) False -- = undefined | Just marks <- idCbvMarks_maybe f @@ -457,8 +446,8 @@ rewriteApp _ (StgApp f args) cbvArgIds = [x | StgVarArg x <- map fstOf3 cbvArgInfo] :: [Id] mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args) -rewriteApp _ (StgApp f args) = return $ StgApp f args -rewriteApp _ _ = panic "Impossible" +rewriteApp (StgApp f args) = return $ StgApp f args +rewriteApp _ = panic "Impossible" -- `mkSeq` x x' e generates `case x of x' -> e` -- We could also substitute x' for x in e but that's so rarely beneficial diff --git a/testsuite/tests/simplStg/should_compile/all.T b/testsuite/tests/simplStg/should_compile/all.T index 8cc4c49922..c5f9162579 100644 --- a/testsuite/tests/simplStg/should_compile/all.T +++ b/testsuite/tests/simplStg/should_compile/all.T @@ -11,3 +11,4 @@ setTestOpts(f) test('T13588', [ grep_errmsg('case') ] , compile, ['-dverbose-stg2stg -fno-worker-wrapper']) test('T19717', normal, compile, ['-ddump-stg-final -dsuppress-uniques -dno-typeable-binds']) +test('inferTags002', [ 
only_ways(['optasm']), grep_errmsg('(call stg\_ap\_0)', [1])], compile, ['-ddump-cmm -dsuppress-uniques -dno-typeable-binds -O']) diff --git a/testsuite/tests/simplStg/should_compile/inferTags002.hs b/testsuite/tests/simplStg/should_compile/inferTags002.hs new file mode 100644 index 0000000000..69145acb7a --- /dev/null +++ b/testsuite/tests/simplStg/should_compile/inferTags002.hs @@ -0,0 +1,7 @@ +module M where + +data T a = MkT !Bool !a + +-- The rhs of the case alternative should not result in a call std_ap_0_fast. +f x = case x of + MkT y z -> z diff --git a/testsuite/tests/simplStg/should_compile/inferTags002.stderr b/testsuite/tests/simplStg/should_compile/inferTags002.stderr new file mode 100644 index 0000000000..ef6979932b --- /dev/null +++ b/testsuite/tests/simplStg/should_compile/inferTags002.stderr @@ -0,0 +1,171 @@ + +==================== Output Cmm ==================== +[M.$WMkT_entry() { // [R3, R2] + { info_tbls: [(cym, + label: block_cym_info + rep: StackRep [False] + srt: Nothing), + (cyp, + label: M.$WMkT_info + rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} } + srt: Nothing), + (cys, + label: block_cys_info + rep: StackRep [False] + srt: Nothing)] + stack_info: arg_space: 8 + } + {offset + cyp: // global + if ((Sp + -16) < SpLim) (likely: False) goto cyv; else goto cyw; + cyv: // global + R1 = M.$WMkT_closure; + call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8; + cyw: // global + I64[Sp - 16] = cym; + R1 = R2; + P64[Sp - 8] = R3; + Sp = Sp - 16; + if (R1 & 7 != 0) goto cym; else goto cyn; + cyn: // global + call (I64[R1])(R1) returns to cym, args: 8, res: 8, upd: 8; + cym: // global + I64[Sp] = cys; + _sy8::P64 = R1; + R1 = P64[Sp + 8]; + P64[Sp + 8] = _sy8::P64; + call stg_ap_0_fast(R1) returns to cys, args: 8, res: 8, upd: 8; + cys: // global + Hp = Hp + 24; + if (Hp > HpLim) (likely: False) goto cyA; else goto cyz; + cyA: // global + HpAlloc = 24; + call stg_gc_unpt_r1(R1) returns to cys, args: 8, res: 8, upd: 8; + cyz: // 
global + I64[Hp - 16] = M.MkT_con_info; + P64[Hp - 8] = P64[Sp + 8]; + P64[Hp] = R1; + R1 = Hp - 15; + Sp = Sp + 16; + call (P64[Sp])(R1) args: 8, res: 0, upd: 8; + } + }, + section ""data" . M.$WMkT_closure" { + M.$WMkT_closure: + const M.$WMkT_info; + }] + + + +==================== Output Cmm ==================== +[M.f_entry() { // [R2] + { info_tbls: [(cyK, + label: block_cyK_info + rep: StackRep [] + srt: Nothing), + (cyN, + label: M.f_info + rep: HeapRep static { Fun {arity: 1 fun_type: ArgSpec 5} } + srt: Nothing)] + stack_info: arg_space: 8 + } + {offset + cyN: // global + if ((Sp + -8) < SpLim) (likely: False) goto cyO; else goto cyP; + cyO: // global + R1 = M.f_closure; + call (stg_gc_fun)(R2, R1) args: 8, res: 0, upd: 8; + cyP: // global + I64[Sp - 8] = cyK; + R1 = R2; + Sp = Sp - 8; + if (R1 & 7 != 0) goto cyK; else goto cyL; + cyL: // global + call (I64[R1])(R1) returns to cyK, args: 8, res: 8, upd: 8; + cyK: // global + R1 = P64[R1 + 15]; + Sp = Sp + 8; + call (P64[Sp])(R1) args: 8, res: 0, upd: 8; + } + }, + section ""data" . 
M.f_closure" { + M.f_closure: + const M.f_info; + }] + + + +==================== Output Cmm ==================== +[M.MkT_entry() { // [R3, R2] + { info_tbls: [(cz1, + label: block_cz1_info + rep: StackRep [False] + srt: Nothing), + (cz4, + label: M.MkT_info + rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} } + srt: Nothing), + (cz7, + label: block_cz7_info + rep: StackRep [False] + srt: Nothing)] + stack_info: arg_space: 8 + } + {offset + cz4: // global + if ((Sp + -16) < SpLim) (likely: False) goto cza; else goto czb; + cza: // global + R1 = M.MkT_closure; + call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8; + czb: // global + I64[Sp - 16] = cz1; + R1 = R2; + P64[Sp - 8] = R3; + Sp = Sp - 16; + if (R1 & 7 != 0) goto cz1; else goto cz2; + cz2: // global + call (I64[R1])(R1) returns to cz1, args: 8, res: 8, upd: 8; + cz1: // global + I64[Sp] = cz7; + _tyf::P64 = R1; + R1 = P64[Sp + 8]; + P64[Sp + 8] = _tyf::P64; + call stg_ap_0_fast(R1) returns to cz7, args: 8, res: 8, upd: 8; + cz7: // global + Hp = Hp + 24; + if (Hp > HpLim) (likely: False) goto czf; else goto cze; + czf: // global + HpAlloc = 24; + call stg_gc_unpt_r1(R1) returns to cz7, args: 8, res: 8, upd: 8; + cze: // global + I64[Hp - 16] = M.MkT_con_info; + P64[Hp - 8] = P64[Sp + 8]; + P64[Hp] = R1; + R1 = Hp - 15; + Sp = Sp + 16; + call (P64[Sp])(R1) args: 8, res: 0, upd: 8; + } + }, + section ""data" . M.MkT_closure" { + M.MkT_closure: + const M.MkT_info; + }] + + + +==================== Output Cmm ==================== +[M.MkT_con_entry() { // [] + { info_tbls: [(czl, + label: M.MkT_con_info + rep: HeapRep 2 ptrs { Con {tag: 0 descr:"main:M.MkT"} } + srt: Nothing)] + stack_info: arg_space: 8 + } + {offset + czl: // global + R1 = R1 + 1; + call (P64[Sp])(R1) args: 8, res: 0, upd: 8; + } + }] + + |