diff options
author | nineonine <mail4chemik@gmail.com> | 2021-11-10 00:52:06 -0800 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-03-20 21:16:06 -0400 |
commit | c842611fc72d987519cd9fab1c351135ae93665e (patch) | |
tree | aa4365b4050a0733887a4d4f0291e6e9d52cc801 | |
parent | d45bb70178e044bc8b6e8215da7bc8ed0c95f2cb (diff) | |
download | haskell-c842611fc72d987519cd9fab1c351135ae93665e.tar.gz |
Revamp derived Eq instance code generation (#17240)
This patch improves code generation for derived Eq instances.
The idea is to use 'dataToTag' to evaluate both arguments.
This allows to 'short-circuit' when tags do not match.
Unfortunately, inner evals are still present when we branch
on tags. This is due to the way 'dataToTag#' primop
evaluates its argument in the code generator. #21207 was
created to explore further optimizations.
Metric Decrease:
LargeRecord
-rw-r--r-- | compiler/GHC/Tc/Deriv/Generate.hs | 141 | ||||
-rw-r--r-- | testsuite/tests/deriving/should_compile/T17240.hs | 5 | ||||
-rw-r--r-- | testsuite/tests/deriving/should_compile/T17240.stderr | 42 | ||||
-rw-r--r-- | testsuite/tests/deriving/should_compile/all.T | 1 | ||||
-rw-r--r-- | testsuite/tests/deriving/should_fail/drvfail011.stderr | 4 |
5 files changed, 132 insertions, 61 deletions
diff --git a/compiler/GHC/Tc/Deriv/Generate.hs b/compiler/GHC/Tc/Deriv/Generate.hs index 35375bc5a5..3f8893460b 100644 --- a/compiler/GHC/Tc/Deriv/Generate.hs +++ b/compiler/GHC/Tc/Deriv/Generate.hs @@ -152,6 +152,12 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and data Foo ... = N1 | N2 ... | Nn | O1 a b | O2 Int | O3 Double b b | ... +* We first attempt to compare the constructor tags. If tags don't + match - we immediately bail out. Otherwise, we then generate one + branch per constructor comparing only the fields as we already + know that the tags match. Note that it only makes sense to check + the tag if there is more than one data constructor. + * For the ordinary constructors (if any), we emit clauses to do The Usual Thing, e.g.,: @@ -164,23 +170,29 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and case (a1 `eqFloat#` a2) of r -> r for that particular test. -* For nullary constructors, we emit a - catch-all clause of the form: +* For nullary constructors, we emit a catch-all clause that always + returns True since we already know that the tags match. + +* So, given this data type: + + data T = A | B Int | C Char - (==) a b = case (dataToTag# a) of { a# -> - case (dataToTag# b) of { b# -> - case (a# ==# b#) of { - r -> r }}} + We roughly get: + + (==) a b = + case dataToTag# a /= dataToTag# b of + True -> False + False -> case a of -- Here we already know that tags match + B a1 -> case b of + B b1 -> a1 == b1 -- Only one branch + C a1 -> case b of + C b1 -> a1 == b1 -- Only one branch + _ -> True -- catch-all to match all nullary ctors An older approach preferred regular pattern matches in some cases but with dataToTag# forcing it's argument, and work on improving join points, this seems no longer necessary. -* If there aren't any nullary constructors, we emit a simpler - catch-all: - - (==) a b = False - * For the @(/=)@ method, we normally just use the default method. If the type is an enumeration type, we could/may/should? generate special code that calls @dataToTag#@, much like for @(==)@ shown @@ -202,58 +214,68 @@ gen_Eq_binds loc dit@(DerivInstTys{ dit_rep_tc = tycon return (method_binds, emptyBag) where all_cons = getPossibleDataCons tycon tycon_args - (nullary_cons, non_nullary_cons) = partition isNullarySrcDataCon all_cons - - -- For nullary constructors, use the getTag stuff. - (tag_match_cons, pat_match_cons) = (nullary_cons, non_nullary_cons) - no_tag_match_cons = null tag_match_cons - - -- (LHS patterns, result) - fall_through_eqn :: [([LPat (GhcPass 'Parsed)] , LHsExpr GhcPs)] - fall_through_eqn - | no_tag_match_cons -- All constructors have arguments - = case pat_match_cons of - [] -> [] -- No constructors; no fall-though case - [_] -> [] -- One constructor; no fall-though case - _ -> -- Two or more constructors; add fall-through of - -- (==) _ _ = False - [([nlWildPat, nlWildPat], false_Expr)] - - | otherwise -- One or more tag_match cons; add fall-through of - -- extract tags compare for equality, - -- The case `(C1 x) == (C1 y)` can no longer happen - -- at this point as it's matched earlier. - = [([a_Pat, b_Pat], - untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] - (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))] + non_nullary_cons = filter (not . isNullarySrcDataCon) all_cons + + -- Generate tag check. See #17240 + eq_expr_with_tag_check = nlHsCase + (nlHsPar (untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] + (nlHsOpApp (nlHsVar ah_RDR) neInt_RDR (nlHsVar bh_RDR)))) + [ mkHsCaseAlt (nlLitPat (HsIntPrim NoSourceText 1)) false_Expr + , mkHsCaseAlt nlWildPat ( + nlHsCase + (nlHsVar a_RDR) + -- Only one branch to match all nullary constructors + -- as we already know the tags match but do not emit + -- the branch if there are no nullary constructors + (let non_nullary_pats = map pats_etc non_nullary_cons + in if null non_nullary_cons + then non_nullary_pats + else non_nullary_pats ++ [mkHsCaseAlt nlWildPat true_Expr])) + ] method_binds = unitBag eq_bind - eq_bind - = mkFunBindEC 2 loc eq_RDR (const true_Expr) - (map pats_etc pat_match_cons - ++ fall_through_eqn) + eq_bind = mkFunBindEC 2 loc eq_RDR (const true_Expr) binds + where + binds + | null all_cons = [] + -- Tag checking is redundant when there is only one data constructor + | [data_con] <- all_cons + , (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con + , data_con_RDR <- getRdrName data_con + , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed + , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed + , eq_expr <- nested_eq_expr tys_needed as_needed bs_needed + = [([con1_pat, con2_pat], eq_expr)] + -- This is an enum (all constructors are nullary) - just do a simple tag check + | all isNullarySrcDataCon all_cons + = [([a_Pat, b_Pat], untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] + (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))] + | otherwise + = [([a_Pat, b_Pat], eq_expr_with_tag_check)] ------------------------------------------------------------------ - pats_etc data_con - = let - con1_pat = nlParPat $ nlConVarPat data_con_RDR as_needed - con2_pat = nlParPat $ nlConVarPat data_con_RDR bs_needed - - data_con_RDR = getRdrName data_con - con_arity = length tys_needed - as_needed = take con_arity as_RDRs - bs_needed = take con_arity bs_RDRs - tys_needed = derivDataConInstArgTys data_con dit - in - ([con1_pat, con2_pat], nested_eq_expr tys_needed as_needed bs_needed) + nested_eq_expr [] [] [] = true_Expr + nested_eq_expr tys as bs + = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs) + -- Using 'foldr1' here ensures that the derived code is correctly + -- associated. See #10859. where - nested_eq_expr [] [] [] = true_Expr - nested_eq_expr tys as bs - = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs) - -- Using 'foldr1' here ensures that the derived code is correctly - -- associated. See #10859. - where - nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b)) + nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b)) + + gen_con_fields_and_tys data_con + | tys_needed <- derivDataConInstArgTys data_con dit + , con_arity <- length tys_needed + , as_needed <- take con_arity as_RDRs + , bs_needed <- take con_arity bs_RDRs + = (as_needed, bs_needed, tys_needed) + + pats_etc data_con + | (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con + , data_con_RDR <- getRdrName data_con + , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed + , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed + , fields_eq_expr <- nested_eq_expr tys_needed as_needed bs_needed + = mkHsCaseAlt con1_pat (nlHsCase (nlHsVar b_RDR) [mkHsCaseAlt con2_pat fields_eq_expr]) {- ************************************************************************ @@ -1473,7 +1495,7 @@ gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstrTag_RDR, dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR, constr_RDR, dataType_RDR, eqChar_RDR , ltChar_RDR , geChar_RDR , gtChar_RDR , leChar_RDR , - eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , + eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , neInt_RDR , eqInt8_RDR , ltInt8_RDR , geInt8_RDR , gtInt8_RDR , leInt8_RDR , eqInt16_RDR , ltInt16_RDR , geInt16_RDR , gtInt16_RDR , leInt16_RDR , eqInt32_RDR , ltInt32_RDR , geInt32_RDR , gtInt32_RDR , leInt32_RDR , @@ -1513,6 +1535,7 @@ gtChar_RDR = varQual_RDR gHC_PRIM (fsLit "gtChar#") geChar_RDR = varQual_RDR gHC_PRIM (fsLit "geChar#") eqInt_RDR = varQual_RDR gHC_PRIM (fsLit "==#") +neInt_RDR = varQual_RDR gHC_PRIM (fsLit "/=#") ltInt_RDR = varQual_RDR gHC_PRIM (fsLit "<#" ) leInt_RDR = varQual_RDR gHC_PRIM (fsLit "<=#") gtInt_RDR = varQual_RDR gHC_PRIM (fsLit ">#" ) diff --git a/testsuite/tests/deriving/should_compile/T17240.hs b/testsuite/tests/deriving/should_compile/T17240.hs new file mode 100644 index 0000000000..6b847c578a --- /dev/null +++ b/testsuite/tests/deriving/should_compile/T17240.hs @@ -0,0 +1,5 @@ +module T17240 where + +data T = A | B Int | C Char | D Int deriving Eq + +data Nullary = X | Y | Z deriving Eq diff --git a/testsuite/tests/deriving/should_compile/T17240.stderr b/testsuite/tests/deriving/should_compile/T17240.stderr new file mode 100644 index 0000000000..cce538b59d --- /dev/null +++ b/testsuite/tests/deriving/should_compile/T17240.stderr @@ -0,0 +1,42 @@ + +==================== Derived instances ==================== +Derived class instances: + instance GHC.Classes.Eq T17240.Nullary where + (GHC.Classes.==) a b + = case (GHC.Prim.dataToTag# a) of + a# + -> case (GHC.Prim.dataToTag# b) of + b# -> (GHC.Prim.tagToEnum# (a# GHC.Prim.==# b#)) + + instance GHC.Classes.Eq T17240.T where + (GHC.Classes.==) a b + = case + (case (GHC.Prim.dataToTag# a) of + a# -> case (GHC.Prim.dataToTag# b) of b# -> a# GHC.Prim./=# b#) + of + 1# -> GHC.Types.False + _ -> case a of + (T17240.B a1) + -> case b of (T17240.B b1) -> ((a1 GHC.Classes.== b1)) + (T17240.C a1) + -> case b of (T17240.C b1) -> ((a1 GHC.Classes.== b1)) + (T17240.D a1) + -> case b of (T17240.D b1) -> ((a1 GHC.Classes.== b1)) + _ -> GHC.Types.True + + +Derived type family instances: + + + +==================== Filling in method body ==================== +GHC.Classes.Eq [T17240.Nullary] + GHC.Classes./= = GHC.Classes.$dm/= @(T17240.Nullary) + + + +==================== Filling in method body ==================== +GHC.Classes.Eq [T17240.T] + GHC.Classes./= = GHC.Classes.$dm/= @(T17240.T) + + diff --git a/testsuite/tests/deriving/should_compile/all.T b/testsuite/tests/deriving/should_compile/all.T index a33cb364c3..05f9e87dcb 100644 --- a/testsuite/tests/deriving/should_compile/all.T +++ b/testsuite/tests/deriving/should_compile/all.T @@ -122,6 +122,7 @@ test('T15831', normal, compile, ['']) test('T16179', normal, compile, ['']) test('T16341', normal, compile, ['']) test('T16518', normal, compile, ['']) +test('T17240', normal, compile, ['-ddump-deriv -dsuppress-uniques']) test('T17324', normal, compile, ['']) test('T17339', normal, compile, ['-ddump-simpl -dsuppress-idinfo -dno-typeable-binds']) diff --git a/testsuite/tests/deriving/should_fail/drvfail011.stderr b/testsuite/tests/deriving/should_fail/drvfail011.stderr index d439bd03eb..5b26f5b575 100644 --- a/testsuite/tests/deriving/should_fail/drvfail011.stderr +++ b/testsuite/tests/deriving/should_fail/drvfail011.stderr @@ -3,8 +3,8 @@ drvfail011.hs:8:1: error: • No instance for (Eq a) arising from a use of ‘==’ Possible fix: add (Eq a) to the context of the instance declaration • In the expression: a1 == b1 - In an equation for ‘==’: (==) (T1 a1) (T1 b1) = ((a1 == b1)) + In a case alternative: (T1 b1) -> ((a1 == b1)) + In the expression: case b of (T1 b1) -> ((a1 == b1)) When typechecking the code for ‘==’ in a derived instance for ‘Eq (T a)’: To see the code I am typechecking, use -ddump-deriv - In the instance declaration for ‘Eq (T a)’ |