summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Andreas Klebinger <klebinger.andreas@gmx.at> 2022-08-02 17:57:55 +0200
committer: Andreas Klebinger <klebinger.andreas@gmx.at> 2022-09-15 10:12:41 +0200
commit: d6ea8356b721ea4c3b871a796c1b2b13f94fd471 (patch)
tree: a72cb96f4a7a0d8992eb996ba39c66e916caa255
parent: df04d6ec6a543d8bf1b953cf27c26e63ec6aab25 (diff)
download: haskell-d6ea8356b721ea4c3b871a796c1b2b13f94fd471.tar.gz
Tag inference: Fix #21954 by retaining tagsigs of vars in function position.
For an expression like `case x of y { Con z -> z }`: if we also retain the tag sig for z, we can generate code to immediately return it rather than calling out to stg_ap_0_fast.
-rw-r--r-- compiler/GHC/Stg/InferTags/Rewrite.hs 55
-rw-r--r-- testsuite/tests/simplStg/should_compile/all.T 1
-rw-r--r-- testsuite/tests/simplStg/should_compile/inferTags002.hs 7
-rw-r--r-- testsuite/tests/simplStg/should_compile/inferTags002.stderr 171
4 files changed, 201 insertions, 33 deletions
diff --git a/compiler/GHC/Stg/InferTags/Rewrite.hs b/compiler/GHC/Stg/InferTags/Rewrite.hs
index ba2bbf2449..253763cc5b 100644
--- a/compiler/GHC/Stg/InferTags/Rewrite.hs
+++ b/compiler/GHC/Stg/InferTags/Rewrite.hs
@@ -336,7 +336,7 @@ rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args) = {-# SCC rewriteRhs
rewriteRhs _binding (StgRhsClosure fvs ccs flag args body) = do
withBinders NotTopLevel args $
withClosureLcls fvs $
- StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr False body
+ StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr body
-- return (closure)
fvArgs :: [StgArg] -> RM DVarSet
@@ -345,40 +345,36 @@ fvArgs args = do
-- pprTraceM "fvArgs" (text "args:" <> ppr args $$ text "lcls:" <> pprVarSet (fv_lcls) (braces . fsep . map ppr) )
return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls]
-type IsScrut = Bool
-
rewriteArgs :: [StgArg] -> RM [StgArg]
rewriteArgs = mapM rewriteArg
rewriteArg :: StgArg -> RM StgArg
rewriteArg (StgVarArg v) = StgVarArg <$!> rewriteId v
rewriteArg (lit@StgLitArg{}) = return lit
--- Attach a tagSig if it's tagged
rewriteId :: Id -> RM Id
rewriteId v = do
is_tagged <- isTagged v
if is_tagged then return $! setIdTagSig v (TagSig TagProper)
else return v
-rewriteExpr :: IsScrut -> InferStgExpr -> RM TgStgExpr
-rewriteExpr _ (e@StgCase {}) = rewriteCase e
-rewriteExpr _ (e@StgLet {}) = rewriteLet e
-rewriteExpr _ (e@StgLetNoEscape {}) = rewriteLetNoEscape e
-rewriteExpr isScrut (StgTick t e) = StgTick t <$!> rewriteExpr isScrut e
-rewriteExpr _ e@(StgConApp {}) = rewriteConApp e
-
-rewriteExpr isScrut e@(StgApp {}) = rewriteApp isScrut e
-rewriteExpr _ (StgLit lit) = return $! (StgLit lit)
-rewriteExpr _ (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do
+rewriteExpr :: InferStgExpr -> RM TgStgExpr
+rewriteExpr (e@StgCase {}) = rewriteCase e
+rewriteExpr (e@StgLet {}) = rewriteLet e
+rewriteExpr (e@StgLetNoEscape {}) = rewriteLetNoEscape e
+rewriteExpr (StgTick t e) = StgTick t <$!> rewriteExpr e
+rewriteExpr e@(StgConApp {}) = rewriteConApp e
+rewriteExpr e@(StgApp {}) = rewriteApp e
+rewriteExpr (StgLit lit) = return $! (StgLit lit)
+rewriteExpr (StgOpApp op@(StgPrimOp DataToTagOp) args res_ty) = do
(StgOpApp op) <$!> rewriteArgs args <*> pure res_ty
-rewriteExpr _ (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty)
+rewriteExpr (StgOpApp op args res_ty) = return $! (StgOpApp op args res_ty)
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase (StgCase scrut bndr alt_type alts) =
withBinder NotTopLevel bndr $
pure StgCase <*>
- rewriteExpr True scrut <*>
+ rewriteExpr scrut <*>
pure (fst bndr) <*>
pure alt_type <*>
mapM rewriteAlt alts
@@ -388,7 +384,7 @@ rewriteCase _ = panic "Impossible: nodeCase"
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt alt@GenStgAlt{alt_con=_, alt_bndrs=bndrs, alt_rhs=rhs} =
withBinders NotTopLevel bndrs $ do
- !rhs' <- rewriteExpr False rhs
+ !rhs' <- rewriteExpr rhs
return $! alt {alt_bndrs = map fst bndrs, alt_rhs = rhs'}
rewriteLet :: InferStgExpr -> RM TgStgExpr
@@ -396,7 +392,7 @@ rewriteLet (StgLet xt bind expr) = do
(!bind') <- rewriteBinds NotTopLevel bind
withBind NotTopLevel bind $ do
-- pprTraceM "withBindLet" (ppr $ bindersOfX bind)
- !expr' <- rewriteExpr False expr
+ !expr' <- rewriteExpr expr
return $! (StgLet xt bind' expr')
rewriteLet _ = panic "Impossible"
@@ -404,7 +400,7 @@ rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do
(!bind') <- rewriteBinds NotTopLevel bind
withBind NotTopLevel bind $ do
- !expr' <- rewriteExpr False expr
+ !expr' <- rewriteExpr expr
return $! (StgLetNoEscape xt bind' expr')
rewriteLetNoEscape _ = panic "Impossible"
@@ -424,19 +420,12 @@ rewriteConApp (StgConApp con cn args tys) = do
rewriteConApp _ = panic "Impossible"
--- Special case: Expressions like `case x of { ... }`
-rewriteApp :: IsScrut -> InferStgExpr -> RM TgStgExpr
-rewriteApp True (StgApp f []) = do
- -- pprTraceM "rewriteAppScrut" (ppr f)
- f_tagged <- isTagged f
- -- isTagged looks at more than the result of our analysis.
- -- So always update here if useful.
- let f' = if f_tagged
- -- TODO: We might consisder using a subst env instead of setting the sig only for select places.
- then setIdTagSig f (TagSig TagProper)
- else f
+-- Special case: Atomic binders, usually in a case context like `case f of ...`.
+rewriteApp :: InferStgExpr -> RM TgStgExpr
+rewriteApp (StgApp f []) = do
+ f' <- rewriteId f
return $! StgApp f' []
-rewriteApp _ (StgApp f args)
+rewriteApp (StgApp f args)
-- pprTrace "rewriteAppOther" (ppr f <+> ppr args) False
-- = undefined
| Just marks <- idCbvMarks_maybe f
@@ -457,8 +446,8 @@ rewriteApp _ (StgApp f args)
cbvArgIds = [x | StgVarArg x <- map fstOf3 cbvArgInfo] :: [Id]
mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args)
-rewriteApp _ (StgApp f args) = return $ StgApp f args
-rewriteApp _ _ = panic "Impossible"
+rewriteApp (StgApp f args) = return $ StgApp f args
+rewriteApp _ = panic "Impossible"
-- `mkSeq` x x' e generates `case x of x' -> e`
-- We could also substitute x' for x in e but that's so rarely beneficial
diff --git a/testsuite/tests/simplStg/should_compile/all.T b/testsuite/tests/simplStg/should_compile/all.T
index 8cc4c49922..c5f9162579 100644
--- a/testsuite/tests/simplStg/should_compile/all.T
+++ b/testsuite/tests/simplStg/should_compile/all.T
@@ -11,3 +11,4 @@ setTestOpts(f)
test('T13588', [ grep_errmsg('case') ] , compile, ['-dverbose-stg2stg -fno-worker-wrapper'])
test('T19717', normal, compile, ['-ddump-stg-final -dsuppress-uniques -dno-typeable-binds'])
+test('inferTags002', [ only_ways(['optasm']), grep_errmsg('(call stg\_ap\_0)', [1])], compile, ['-ddump-cmm -dsuppress-uniques -dno-typeable-binds -O'])
diff --git a/testsuite/tests/simplStg/should_compile/inferTags002.hs b/testsuite/tests/simplStg/should_compile/inferTags002.hs
new file mode 100644
index 0000000000..69145acb7a
--- /dev/null
+++ b/testsuite/tests/simplStg/should_compile/inferTags002.hs
@@ -0,0 +1,7 @@
+module M where
+
+data T a = MkT !Bool !a
+
+-- The rhs of the case alternative should not result in a call to stg_ap_0_fast.
+f x = case x of
+ MkT y z -> z
diff --git a/testsuite/tests/simplStg/should_compile/inferTags002.stderr b/testsuite/tests/simplStg/should_compile/inferTags002.stderr
new file mode 100644
index 0000000000..ef6979932b
--- /dev/null
+++ b/testsuite/tests/simplStg/should_compile/inferTags002.stderr
@@ -0,0 +1,171 @@
+
+==================== Output Cmm ====================
+[M.$WMkT_entry() { // [R3, R2]
+ { info_tbls: [(cym,
+ label: block_cym_info
+ rep: StackRep [False]
+ srt: Nothing),
+ (cyp,
+ label: M.$WMkT_info
+ rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} }
+ srt: Nothing),
+ (cys,
+ label: block_cys_info
+ rep: StackRep [False]
+ srt: Nothing)]
+ stack_info: arg_space: 8
+ }
+ {offset
+ cyp: // global
+ if ((Sp + -16) < SpLim) (likely: False) goto cyv; else goto cyw;
+ cyv: // global
+ R1 = M.$WMkT_closure;
+ call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8;
+ cyw: // global
+ I64[Sp - 16] = cym;
+ R1 = R2;
+ P64[Sp - 8] = R3;
+ Sp = Sp - 16;
+ if (R1 & 7 != 0) goto cym; else goto cyn;
+ cyn: // global
+ call (I64[R1])(R1) returns to cym, args: 8, res: 8, upd: 8;
+ cym: // global
+ I64[Sp] = cys;
+ _sy8::P64 = R1;
+ R1 = P64[Sp + 8];
+ P64[Sp + 8] = _sy8::P64;
+ call stg_ap_0_fast(R1) returns to cys, args: 8, res: 8, upd: 8;
+ cys: // global
+ Hp = Hp + 24;
+ if (Hp > HpLim) (likely: False) goto cyA; else goto cyz;
+ cyA: // global
+ HpAlloc = 24;
+ call stg_gc_unpt_r1(R1) returns to cys, args: 8, res: 8, upd: 8;
+ cyz: // global
+ I64[Hp - 16] = M.MkT_con_info;
+ P64[Hp - 8] = P64[Sp + 8];
+ P64[Hp] = R1;
+ R1 = Hp - 15;
+ Sp = Sp + 16;
+ call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+ }
+ },
+ section ""data" . M.$WMkT_closure" {
+ M.$WMkT_closure:
+ const M.$WMkT_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.f_entry() { // [R2]
+ { info_tbls: [(cyK,
+ label: block_cyK_info
+ rep: StackRep []
+ srt: Nothing),
+ (cyN,
+ label: M.f_info
+ rep: HeapRep static { Fun {arity: 1 fun_type: ArgSpec 5} }
+ srt: Nothing)]
+ stack_info: arg_space: 8
+ }
+ {offset
+ cyN: // global
+ if ((Sp + -8) < SpLim) (likely: False) goto cyO; else goto cyP;
+ cyO: // global
+ R1 = M.f_closure;
+ call (stg_gc_fun)(R2, R1) args: 8, res: 0, upd: 8;
+ cyP: // global
+ I64[Sp - 8] = cyK;
+ R1 = R2;
+ Sp = Sp - 8;
+ if (R1 & 7 != 0) goto cyK; else goto cyL;
+ cyL: // global
+ call (I64[R1])(R1) returns to cyK, args: 8, res: 8, upd: 8;
+ cyK: // global
+ R1 = P64[R1 + 15];
+ Sp = Sp + 8;
+ call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+ }
+ },
+ section ""data" . M.f_closure" {
+ M.f_closure:
+ const M.f_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.MkT_entry() { // [R3, R2]
+ { info_tbls: [(cz1,
+ label: block_cz1_info
+ rep: StackRep [False]
+ srt: Nothing),
+ (cz4,
+ label: M.MkT_info
+ rep: HeapRep static { Fun {arity: 2 fun_type: ArgSpec 15} }
+ srt: Nothing),
+ (cz7,
+ label: block_cz7_info
+ rep: StackRep [False]
+ srt: Nothing)]
+ stack_info: arg_space: 8
+ }
+ {offset
+ cz4: // global
+ if ((Sp + -16) < SpLim) (likely: False) goto cza; else goto czb;
+ cza: // global
+ R1 = M.MkT_closure;
+ call (stg_gc_fun)(R3, R2, R1) args: 8, res: 0, upd: 8;
+ czb: // global
+ I64[Sp - 16] = cz1;
+ R1 = R2;
+ P64[Sp - 8] = R3;
+ Sp = Sp - 16;
+ if (R1 & 7 != 0) goto cz1; else goto cz2;
+ cz2: // global
+ call (I64[R1])(R1) returns to cz1, args: 8, res: 8, upd: 8;
+ cz1: // global
+ I64[Sp] = cz7;
+ _tyf::P64 = R1;
+ R1 = P64[Sp + 8];
+ P64[Sp + 8] = _tyf::P64;
+ call stg_ap_0_fast(R1) returns to cz7, args: 8, res: 8, upd: 8;
+ cz7: // global
+ Hp = Hp + 24;
+ if (Hp > HpLim) (likely: False) goto czf; else goto cze;
+ czf: // global
+ HpAlloc = 24;
+ call stg_gc_unpt_r1(R1) returns to cz7, args: 8, res: 8, upd: 8;
+ cze: // global
+ I64[Hp - 16] = M.MkT_con_info;
+ P64[Hp - 8] = P64[Sp + 8];
+ P64[Hp] = R1;
+ R1 = Hp - 15;
+ Sp = Sp + 16;
+ call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+ }
+ },
+ section ""data" . M.MkT_closure" {
+ M.MkT_closure:
+ const M.MkT_info;
+ }]
+
+
+
+==================== Output Cmm ====================
+[M.MkT_con_entry() { // []
+ { info_tbls: [(czl,
+ label: M.MkT_con_info
+ rep: HeapRep 2 ptrs { Con {tag: 0 descr:"main:M.MkT"} }
+ srt: Nothing)]
+ stack_info: arg_space: 8
+ }
+ {offset
+ czl: // global
+ R1 = R1 + 1;
+ call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
+ }
+ }]
+
+