summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornineonine <mail4chemik@gmail.com>2021-11-10 00:52:06 -0800
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-03-20 21:16:06 -0400
commitc842611fc72d987519cd9fab1c351135ae93665e (patch)
treeaa4365b4050a0733887a4d4f0291e6e9d52cc801
parentd45bb70178e044bc8b6e8215da7bc8ed0c95f2cb (diff)
downloadhaskell-c842611fc72d987519cd9fab1c351135ae93665e.tar.gz
Revamp derived Eq instance code generation (#17240)
This patch improves code generation for derived Eq instances. The idea is to use 'dataToTag' to evaluate both arguments. This allows to 'short-circuit' when tags do not match. Unfortunately, inner evals are still present when we branch on tags. This is due to the way 'dataToTag#' primop evaluates its argument in the code generator. #21207 was created to explore further optimizations. Metric Decrease: LargeRecord
-rw-r--r--compiler/GHC/Tc/Deriv/Generate.hs141
-rw-r--r--testsuite/tests/deriving/should_compile/T17240.hs5
-rw-r--r--testsuite/tests/deriving/should_compile/T17240.stderr42
-rw-r--r--testsuite/tests/deriving/should_compile/all.T1
-rw-r--r--testsuite/tests/deriving/should_fail/drvfail011.stderr4
5 files changed, 132 insertions, 61 deletions
diff --git a/compiler/GHC/Tc/Deriv/Generate.hs b/compiler/GHC/Tc/Deriv/Generate.hs
index 35375bc5a5..3f8893460b 100644
--- a/compiler/GHC/Tc/Deriv/Generate.hs
+++ b/compiler/GHC/Tc/Deriv/Generate.hs
@@ -152,6 +152,12 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
data Foo ... = N1 | N2 ... | Nn | O1 a b | O2 Int | O3 Double b b | ...
+* We first attempt to compare the constructor tags. If tags don't
+ match - we immediately bail out. Otherwise, we then generate one
+ branch per constructor comparing only the fields as we already
+ know that the tags match. Note that it only makes sense to check
+ the tag if there is more than one data constructor.
+
* For the ordinary constructors (if any), we emit clauses to do The
Usual Thing, e.g.,:
@@ -164,23 +170,29 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
case (a1 `eqFloat#` a2) of r -> r
for that particular test.
-* For nullary constructors, we emit a
- catch-all clause of the form:
+* For nullary constructors, we emit a catch-all clause that always
+ returns True since we already know that the tags match.
+
+* So, given this data type:
+
+ data T = A | B Int | C Char
- (==) a b = case (dataToTag# a) of { a# ->
- case (dataToTag# b) of { b# ->
- case (a# ==# b#) of {
- r -> r }}}
+ We roughly get:
+
+ (==) a b =
+ case dataToTag# a /= dataToTag# b of
+ True -> False
+ False -> case a of -- Here we already know that tags match
+ B a1 -> case b of
+ B b1 -> a1 == b1 -- Only one branch
+ C a1 -> case b of
+ C b1 -> a1 == b1 -- Only one branch
+ _ -> True -- catch-all to match all nullary ctors
An older approach preferred regular pattern matches in some cases
but with dataToTag# forcing it's argument, and work on improving
join points, this seems no longer necessary.
-* If there aren't any nullary constructors, we emit a simpler
- catch-all:
-
- (==) a b = False
-
* For the @(/=)@ method, we normally just use the default method.
If the type is an enumeration type, we could/may/should? generate
special code that calls @dataToTag#@, much like for @(==)@ shown
@@ -202,58 +214,68 @@ gen_Eq_binds loc dit@(DerivInstTys{ dit_rep_tc = tycon
return (method_binds, emptyBag)
where
all_cons = getPossibleDataCons tycon tycon_args
- (nullary_cons, non_nullary_cons) = partition isNullarySrcDataCon all_cons
-
- -- For nullary constructors, use the getTag stuff.
- (tag_match_cons, pat_match_cons) = (nullary_cons, non_nullary_cons)
- no_tag_match_cons = null tag_match_cons
-
- -- (LHS patterns, result)
- fall_through_eqn :: [([LPat (GhcPass 'Parsed)] , LHsExpr GhcPs)]
- fall_through_eqn
- | no_tag_match_cons -- All constructors have arguments
- = case pat_match_cons of
- [] -> [] -- No constructors; no fall-though case
- [_] -> [] -- One constructor; no fall-though case
- _ -> -- Two or more constructors; add fall-through of
- -- (==) _ _ = False
- [([nlWildPat, nlWildPat], false_Expr)]
-
- | otherwise -- One or more tag_match cons; add fall-through of
- -- extract tags compare for equality,
- -- The case `(C1 x) == (C1 y)` can no longer happen
- -- at this point as it's matched earlier.
- = [([a_Pat, b_Pat],
- untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
- (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
+ non_nullary_cons = filter (not . isNullarySrcDataCon) all_cons
+
+ -- Generate tag check. See #17240
+ eq_expr_with_tag_check = nlHsCase
+ (nlHsPar (untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
+ (nlHsOpApp (nlHsVar ah_RDR) neInt_RDR (nlHsVar bh_RDR))))
+ [ mkHsCaseAlt (nlLitPat (HsIntPrim NoSourceText 1)) false_Expr
+ , mkHsCaseAlt nlWildPat (
+ nlHsCase
+ (nlHsVar a_RDR)
+ -- Only one branch to match all nullary constructors
+ -- as we already know the tags match but do not emit
+ -- the branch if there are no nullary constructors
+ (let non_nullary_pats = map pats_etc non_nullary_cons
+ in if null non_nullary_cons
+ then non_nullary_pats
+ else non_nullary_pats ++ [mkHsCaseAlt nlWildPat true_Expr]))
+ ]
method_binds = unitBag eq_bind
- eq_bind
- = mkFunBindEC 2 loc eq_RDR (const true_Expr)
- (map pats_etc pat_match_cons
- ++ fall_through_eqn)
+ eq_bind = mkFunBindEC 2 loc eq_RDR (const true_Expr) binds
+ where
+ binds
+ | null all_cons = []
+ -- Tag checking is redundant when there is only one data constructor
+ | [data_con] <- all_cons
+ , (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
+ , data_con_RDR <- getRdrName data_con
+ , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
+ , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
+ , eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
+ = [([con1_pat, con2_pat], eq_expr)]
+ -- This is an enum (all constructors are nullary) - just do a simple tag check
+ | all isNullarySrcDataCon all_cons
+ = [([a_Pat, b_Pat], untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
+ (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
+ | otherwise
+ = [([a_Pat, b_Pat], eq_expr_with_tag_check)]
------------------------------------------------------------------
- pats_etc data_con
- = let
- con1_pat = nlParPat $ nlConVarPat data_con_RDR as_needed
- con2_pat = nlParPat $ nlConVarPat data_con_RDR bs_needed
-
- data_con_RDR = getRdrName data_con
- con_arity = length tys_needed
- as_needed = take con_arity as_RDRs
- bs_needed = take con_arity bs_RDRs
- tys_needed = derivDataConInstArgTys data_con dit
- in
- ([con1_pat, con2_pat], nested_eq_expr tys_needed as_needed bs_needed)
+ nested_eq_expr [] [] [] = true_Expr
+ nested_eq_expr tys as bs
+ = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
+ -- Using 'foldr1' here ensures that the derived code is correctly
+ -- associated. See #10859.
where
- nested_eq_expr [] [] [] = true_Expr
- nested_eq_expr tys as bs
- = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
- -- Using 'foldr1' here ensures that the derived code is correctly
- -- associated. See #10859.
- where
- nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
+ nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
+
+ gen_con_fields_and_tys data_con
+ | tys_needed <- derivDataConInstArgTys data_con dit
+ , con_arity <- length tys_needed
+ , as_needed <- take con_arity as_RDRs
+ , bs_needed <- take con_arity bs_RDRs
+ = (as_needed, bs_needed, tys_needed)
+
+ pats_etc data_con
+ | (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
+ , data_con_RDR <- getRdrName data_con
+ , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
+ , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
+ , fields_eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
+ = mkHsCaseAlt con1_pat (nlHsCase (nlHsVar b_RDR) [mkHsCaseAlt con2_pat fields_eq_expr])
{-
************************************************************************
@@ -1473,7 +1495,7 @@ gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstrTag_RDR,
dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR,
constr_RDR, dataType_RDR,
eqChar_RDR , ltChar_RDR , geChar_RDR , gtChar_RDR , leChar_RDR ,
- eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR ,
+ eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , neInt_RDR ,
eqInt8_RDR , ltInt8_RDR , geInt8_RDR , gtInt8_RDR , leInt8_RDR ,
eqInt16_RDR , ltInt16_RDR , geInt16_RDR , gtInt16_RDR , leInt16_RDR ,
eqInt32_RDR , ltInt32_RDR , geInt32_RDR , gtInt32_RDR , leInt32_RDR ,
@@ -1513,6 +1535,7 @@ gtChar_RDR = varQual_RDR gHC_PRIM (fsLit "gtChar#")
geChar_RDR = varQual_RDR gHC_PRIM (fsLit "geChar#")
eqInt_RDR = varQual_RDR gHC_PRIM (fsLit "==#")
+neInt_RDR = varQual_RDR gHC_PRIM (fsLit "/=#")
ltInt_RDR = varQual_RDR gHC_PRIM (fsLit "<#" )
leInt_RDR = varQual_RDR gHC_PRIM (fsLit "<=#")
gtInt_RDR = varQual_RDR gHC_PRIM (fsLit ">#" )
diff --git a/testsuite/tests/deriving/should_compile/T17240.hs b/testsuite/tests/deriving/should_compile/T17240.hs
new file mode 100644
index 0000000000..6b847c578a
--- /dev/null
+++ b/testsuite/tests/deriving/should_compile/T17240.hs
@@ -0,0 +1,5 @@
+module T17240 where
+
+data T = A | B Int | C Char | D Int deriving Eq
+
+data Nullary = X | Y | Z deriving Eq
diff --git a/testsuite/tests/deriving/should_compile/T17240.stderr b/testsuite/tests/deriving/should_compile/T17240.stderr
new file mode 100644
index 0000000000..cce538b59d
--- /dev/null
+++ b/testsuite/tests/deriving/should_compile/T17240.stderr
@@ -0,0 +1,42 @@
+
+==================== Derived instances ====================
+Derived class instances:
+ instance GHC.Classes.Eq T17240.Nullary where
+ (GHC.Classes.==) a b
+ = case (GHC.Prim.dataToTag# a) of
+ a#
+ -> case (GHC.Prim.dataToTag# b) of
+ b# -> (GHC.Prim.tagToEnum# (a# GHC.Prim.==# b#))
+
+ instance GHC.Classes.Eq T17240.T where
+ (GHC.Classes.==) a b
+ = case
+ (case (GHC.Prim.dataToTag# a) of
+ a# -> case (GHC.Prim.dataToTag# b) of b# -> a# GHC.Prim./=# b#)
+ of
+ 1# -> GHC.Types.False
+ _ -> case a of
+ (T17240.B a1)
+ -> case b of (T17240.B b1) -> ((a1 GHC.Classes.== b1))
+ (T17240.C a1)
+ -> case b of (T17240.C b1) -> ((a1 GHC.Classes.== b1))
+ (T17240.D a1)
+ -> case b of (T17240.D b1) -> ((a1 GHC.Classes.== b1))
+ _ -> GHC.Types.True
+
+
+Derived type family instances:
+
+
+
+==================== Filling in method body ====================
+GHC.Classes.Eq [T17240.Nullary]
+ GHC.Classes./= = GHC.Classes.$dm/= @(T17240.Nullary)
+
+
+
+==================== Filling in method body ====================
+GHC.Classes.Eq [T17240.T]
+ GHC.Classes./= = GHC.Classes.$dm/= @(T17240.T)
+
+
diff --git a/testsuite/tests/deriving/should_compile/all.T b/testsuite/tests/deriving/should_compile/all.T
index a33cb364c3..05f9e87dcb 100644
--- a/testsuite/tests/deriving/should_compile/all.T
+++ b/testsuite/tests/deriving/should_compile/all.T
@@ -122,6 +122,7 @@ test('T15831', normal, compile, [''])
test('T16179', normal, compile, [''])
test('T16341', normal, compile, [''])
test('T16518', normal, compile, [''])
+test('T17240', normal, compile, ['-ddump-deriv -dsuppress-uniques'])
test('T17324', normal, compile, [''])
test('T17339', normal, compile,
['-ddump-simpl -dsuppress-idinfo -dno-typeable-binds'])
diff --git a/testsuite/tests/deriving/should_fail/drvfail011.stderr b/testsuite/tests/deriving/should_fail/drvfail011.stderr
index d439bd03eb..5b26f5b575 100644
--- a/testsuite/tests/deriving/should_fail/drvfail011.stderr
+++ b/testsuite/tests/deriving/should_fail/drvfail011.stderr
@@ -3,8 +3,8 @@ drvfail011.hs:8:1: error:
• No instance for (Eq a) arising from a use of ‘==’
Possible fix: add (Eq a) to the context of the instance declaration
• In the expression: a1 == b1
- In an equation for ‘==’: (==) (T1 a1) (T1 b1) = ((a1 == b1))
+ In a case alternative: (T1 b1) -> ((a1 == b1))
+ In the expression: case b of (T1 b1) -> ((a1 == b1))
When typechecking the code for ‘==’
in a derived instance for ‘Eq (T a)’:
To see the code I am typechecking, use -ddump-deriv
- In the instance declaration for ‘Eq (T a)’