From 7e77aca97127a65b574417a0fb25d18e8ddc4b1e Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Fri, 1 May 2020 14:48:43 +0100 Subject: Fix specialisation for DFuns When specialising a DFun we must take care to saturate the unfolding. See Note [Specialising DFuns] in Specialise. Fixes #18120 --- compiler/GHC/Core/Opt/Specialise.hs | 70 +++++++++++++++------- compiler/GHC/Core/Unfold.hs | 48 +++++++-------- compiler/GHC/HsToCore/Binds.hs | 7 +-- compiler/GHC/Tc/Gen/Sig.hs | 7 +-- testsuite/tests/simplCore/should_compile/T18120.hs | 34 +++++++++++ testsuite/tests/simplCore/should_compile/all.T | 1 + 6 files changed, 111 insertions(+), 56 deletions(-) create mode 100644 testsuite/tests/simplCore/should_compile/T18120.hs diff --git a/compiler/GHC/Core/Opt/Specialise.hs b/compiler/GHC/Core/Opt/Specialise.hs index 18173e1644..09af3d9d2d 100644 --- a/compiler/GHC/Core/Opt/Specialise.hs +++ b/compiler/GHC/Core/Opt/Specialise.hs @@ -1362,6 +1362,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs inl_prag = idInlinePragma fn inl_act = inlinePragmaActivation inl_prag is_local = isLocalId fn + is_dfun = isDFunId fn -- Figure out whether the function has an INLINE pragma -- See Note [Inline specialisations] @@ -1384,22 +1385,34 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs spec_call :: SpecInfo -- Accumulating parameter -> CallInfo -- Call instance -> SpecM SpecInfo - spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) (CI { ci_key = call_args }) + spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) _ci@(CI { ci_key = call_args }) = -- See Note [Specialising Calls] - do { ( useful, rhs_env2, leftover_bndrs + do { let all_call_args | is_dfun = call_args ++ repeat UnspecArg + | otherwise = call_args + -- See Note [Specialising DFuns] + ; ( useful, rhs_env2, leftover_bndrs , rule_bndrs, rule_lhs_args - , spec_bndrs, dx_binds, spec_args) <- specHeader env rhs_bndrs call_args + , spec_bndrs1, dx_binds, spec_args) <- specHeader env rhs_bndrs all_call_args + +-- ; pprTrace "spec_call" (vcat [ text "call info: " <+> ppr _ci +-- , text "useful: " <+> ppr useful +-- , text "rule_bndrs:" <+> ppr rule_bndrs +-- , text "lhs_args: " <+> ppr rule_lhs_args +-- , text "spec_bndrs:" <+> ppr spec_bndrs1 +-- , text "spec_args: " <+> ppr spec_args +-- , text "dx_binds: " <+> ppr dx_binds +-- , text "rhs_env2: " <+> ppr (se_subst rhs_env2) +-- , ppr dx_binds ]) $ +-- return () ; dflags <- getDynFlags ; if not useful -- No useful specialisation || already_covered dflags rules_acc rule_lhs_args then return spec_acc - else -- pprTrace "spec_call" (vcat [ ppr _call_info, ppr fn, ppr rhs_dict_ids - -- , text "rhs_env2" <+> ppr (se_subst rhs_env2) - -- , ppr dx_binds ]) $ + else do { -- Run the specialiser on the specialised RHS -- The "1" suffix is before we maybe add the void arg - ; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs ++ leftover_bndrs) rhs_body + ; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs1 ++ leftover_bndrs) rhs_body ; let spec_fn_ty1 = exprType spec_rhs1 -- Maybe add a void arg to the specialised function, @@ -1407,14 +1420,13 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs -- See Note [Specialisations Must Be Lifted] -- C.f. GHC.Core.Opt.WorkWrap.Utils.mkWorkerArgs add_void_arg = isUnliftedType spec_fn_ty1 && not (isJoinId fn) - (spec_rhs, spec_fn_ty, rule_rhs_args) - | add_void_arg = ( Lam voidArgId spec_rhs1 - , mkVisFunTy voidPrimTy spec_fn_ty1 - , voidPrimId : spec_bndrs) - | otherwise = (spec_rhs1, spec_fn_ty1, spec_bndrs) - - arity_decr = count isValArg rule_lhs_args - count isId rule_rhs_args - join_arity_decr = length rule_lhs_args - length rule_rhs_args + (spec_bndrs, spec_rhs, spec_fn_ty) + | add_void_arg = ( voidPrimId : spec_bndrs1 + , Lam voidArgId spec_rhs1 + , mkVisFunTy voidPrimTy spec_fn_ty1) + | otherwise = (spec_bndrs1, spec_rhs1, spec_fn_ty1) + + join_arity_decr = length rule_lhs_args - length spec_bndrs spec_join_arity | Just orig_join_arity <- isJoinId_maybe fn = Just (orig_join_arity - join_arity_decr) | otherwise @@ -1449,7 +1461,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs (idName fn) rule_bndrs rule_lhs_args - (mkVarApps (Var spec_fn) rule_rhs_args) + (mkVarApps (Var spec_fn) spec_bndrs) spec_rule = case isJoinId_maybe fn of @@ -1472,15 +1484,15 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs = (inl_prag { inl_inline = NoUserInline }, noUnfolding) | otherwise - = (inl_prag, specUnfolding dflags fn spec_bndrs spec_app arity_decr fn_unf) - - spec_app e = e `mkApps` spec_args + = (inl_prag, specUnfolding dflags spec_bndrs (`mkApps` spec_args) + rule_lhs_args fn_unf) -------------------------------------- -- Adding arity information just propagates it a bit faster -- See Note [Arity decrease] in GHC.Core.Opt.Simplify -- Copy InlinePragma information from the parent Id. -- So if f has INLINE[1] so does spec_fn + arity_decr = count isValArg rule_lhs_args - count isId spec_bndrs spec_f_w_arity = spec_fn `setIdArity` max 0 (fn_arity - arity_decr) `setInlinePragma` spec_inl_prag `setIdUnfolding` spec_unf @@ -1498,8 +1510,19 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs , spec_uds `plusUDs` uds_acc ) } } -{- Note [Specialisation Must Preserve Sharing] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +{- Note [Specialising DFuns] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +DFuns have a special sort of unfolding (DFunUnfolding), and these are +hard to specialise a DFunUnfolding to give another DFunUnfolding +unless the DFun is fully applied (#18120). So, in the case of DFunIds +we simply extend the CallKey with trailing UnspecArgs, so we'll +generate a rule that completely saturates the DFun. + +There is an ASSERT that checks this, in the DFunUnfolding case of +GHC.Core.Unfold.specUnfolding. + +Note [Specialisation Must Preserve Sharing] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consider a function: f :: forall a. Eq a => a -> blah @@ -2089,7 +2112,7 @@ isSpecDict _ = False -- -- Specialised function helpers -- , [c, i, x] -- , [dShow1 = $dfShow dShowT2] --- , [T1, T2, dEqT1, dShow1] +-- , [T1, T2, c, i, dEqT1, dShow1] -- ) specHeader :: SpecEnv @@ -2106,12 +2129,13 @@ specHeader -- RULE helpers , [OutBndr] -- Binders for the RULE - , [CoreArg] -- Args for the LHS of the rule + , [OutExpr] -- Args for the LHS of the rule -- Specialised function helpers , [OutBndr] -- Binders for $sf , [DictBind] -- Auxiliary dictionary bindings , [OutExpr] -- Specialised arguments for unfolding + -- Same length as "args for LHS of rule" ) -- We want to specialise on type 'T1', and so we must construct a substitution diff --git a/compiler/GHC/Core/Unfold.hs b/compiler/GHC/Core/Unfold.hs index f619e36f8a..42a8974b54 100644 --- a/compiler/GHC/Core/Unfold.hs +++ b/compiler/GHC/Core/Unfold.hs @@ -173,47 +173,47 @@ mkInlinableUnfolding dflags expr where expr' = simpleOptExpr dflags expr -specUnfolding :: DynFlags -> Id -> [Var] -> (CoreExpr -> CoreExpr) -> Arity +specUnfolding :: DynFlags + -> [Var] -> (CoreExpr -> CoreExpr) + -> [CoreArg] -- LHS arguments in the RULE -> Unfolding -> Unfolding -- See Note [Specialising unfoldings] --- specUnfolding spec_bndrs spec_app arity_decrease unf --- = \spec_bndrs. spec_app( unf ) +-- specUnfolding spec_bndrs spec_args unf +-- = \spec_bndrs. unf spec_args -- -specUnfolding dflags fn spec_bndrs spec_app arity_decrease +specUnfolding dflags spec_bndrs spec_app rule_lhs_args df@(DFunUnfolding { df_bndrs = old_bndrs, df_con = con, df_args = args }) - = ASSERT2( arity_decrease == count isId old_bndrs - count isId spec_bndrs - , ppr df $$ ppr spec_bndrs $$ ppr (spec_app (Var fn)) $$ ppr arity_decrease ) + = ASSERT2( rule_lhs_args `equalLength` old_bndrs + , ppr df $$ ppr rule_lhs_args ) + -- For this ASSERT see Note [DFunUnfoldings] in GHC.Core.Opt.Specialise mkDFunUnfolding spec_bndrs con (map spec_arg args) - -- There is a hard-to-check assumption here that the spec_app has - -- enough applications to exactly saturate the old_bndrs -- For DFunUnfoldings we transform - -- \old_bndrs. MkD ... + -- \obs. MkD ... -- to - -- \new_bndrs. MkD (spec_app(\old_bndrs. )) ... ditto - -- The ASSERT checks the value part of that + -- \sbs. MkD ((\obs. ) spec_args) ... ditto where - spec_arg arg = simpleOptExpr dflags (spec_app (mkLams old_bndrs arg)) + spec_arg arg = simpleOptExpr dflags $ + spec_app (mkLams old_bndrs arg) -- The beta-redexes created by spec_app will be -- simplified away by simplOptExpr -specUnfolding dflags _ spec_bndrs spec_app arity_decrease +specUnfolding dflags spec_bndrs spec_app rule_lhs_args (CoreUnfolding { uf_src = src, uf_tmpl = tmpl , uf_is_top = top_lvl , uf_guidance = old_guidance }) | isStableSource src -- See Note [Specialising unfoldings] - , UnfWhen { ug_arity = old_arity - , ug_unsat_ok = unsat_ok - , ug_boring_ok = boring_ok } <- old_guidance - = let guidance = UnfWhen { ug_arity = old_arity - arity_decrease - , ug_unsat_ok = unsat_ok - , ug_boring_ok = boring_ok } - new_tmpl = simpleOptExpr dflags (mkLams spec_bndrs (spec_app tmpl)) - -- The beta-redexes created by spec_app will be - -- simplified away by simplOptExpr + , UnfWhen { ug_arity = old_arity } <- old_guidance + = mkCoreUnfolding src top_lvl new_tmpl + (old_guidance { ug_arity = old_arity - arity_decrease }) + where + new_tmpl = simpleOptExpr dflags $ + mkLams spec_bndrs $ + spec_app tmpl -- The beta-redexes created by spec_app + -- will besimplified away by simplOptExpr + arity_decrease = count isValArg rule_lhs_args - count isId spec_bndrs - in mkCoreUnfolding src top_lvl new_tmpl guidance -specUnfolding _ _ _ _ _ _ = noUnfolding +specUnfolding _ _ _ _ _ = noUnfolding {- Note [Specialising unfoldings] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/compiler/GHC/HsToCore/Binds.hs b/compiler/GHC/HsToCore/Binds.hs index 49a8d78215..5a9b16067d 100644 --- a/compiler/GHC/HsToCore/Binds.hs +++ b/compiler/GHC/HsToCore/Binds.hs @@ -694,20 +694,19 @@ dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl)) dflags <- getDynFlags ; case decomposeRuleLhs dflags spec_bndrs ds_lhs of { Left msg -> do { warnDs NoReason msg; return Nothing } ; - Right (rule_bndrs, _fn, args) -> do + Right (rule_bndrs, _fn, rule_lhs_args) -> do { this_mod <- getModule ; let fn_unf = realIdUnfolding poly_id - spec_unf = specUnfolding dflags poly_id spec_bndrs core_app arity_decrease fn_unf + spec_unf = specUnfolding dflags spec_bndrs core_app rule_lhs_args fn_unf spec_id = mkLocalId spec_name spec_ty `setInlinePragma` inl_prag `setIdUnfolding` spec_unf - arity_decrease = count isValArg args - count isId spec_bndrs ; rule <- dsMkUserRule this_mod is_local_id (mkFastString ("SPEC " ++ showPpr dflags poly_name)) rule_act poly_name - rule_bndrs args + rule_bndrs rule_lhs_args (mkVarApps (Var spec_id) spec_bndrs) ; let spec_rhs = mkLams spec_bndrs (core_app poly_rhs) diff --git a/compiler/GHC/Tc/Gen/Sig.hs b/compiler/GHC/Tc/Gen/Sig.hs index 18582c40ed..2c716f1826 100644 --- a/compiler/GHC/Tc/Gen/Sig.hs +++ b/compiler/GHC/Tc/Gen/Sig.hs @@ -634,7 +634,6 @@ to connect the two, something like This wrapper is put in the TcSpecPrag, in the ABExport record of the AbsBinds. - f :: (Eq a, Ix b) => a -> b -> Bool {-# SPECIALISE f :: (Ix p, Ix q) => Int -> (p,q) -> Bool #-} f = @@ -662,8 +661,6 @@ Note that * The RHS of f_spec, has a *copy* of 'binds', so that it can fully specialise it. - - From the TcSpecPrag, in GHC.HsToCore.Binds we generate a binding for f_spec and a RULE: f_spec :: Int -> b -> Int @@ -702,14 +699,14 @@ Some wrinkles So we simply do this: - Generate a constraint to check that the specialised type (after - skolemiseation) is equal to the instantiated function type. + skolemisation) is equal to the instantiated function type. - But *discard* the evidence (coercion) for that constraint, so that we ultimately generate the simpler code f_spec :: Int -> F Int f_spec = Int dNumInt RULE: forall d. f Int d = f_spec - You can see this discarding happening in + You can see this discarding happening in tcSpecPrag 3. Note that the HsWrapper can transform *any* function with the right type prefix diff --git a/testsuite/tests/simplCore/should_compile/T18120.hs b/testsuite/tests/simplCore/should_compile/T18120.hs new file mode 100644 index 0000000000..0a2ea98638 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T18120.hs @@ -0,0 +1,34 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} +module Bug where + +import Data.Kind + +type family + AllF (c :: k -> Constraint) (xs :: [k]) :: Constraint where + AllF _c '[] = () + AllF c (x ': xs) = (c x, All c xs) + +class (AllF c xs, SListI xs) => All (c :: k -> Constraint) (xs :: [k]) where +instance All c '[] where +instance (c x, All c xs) => All c (x ': xs) where + +class Top x +instance Top x + +type SListI = All Top + +class All SListI (Code a) => Generic (a :: Type) where + type Code a :: [[Type]] + +data T = MkT Int +instance Generic T where + type Code T = '[ '[Int] ] diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 71bd450040..c3db4f1b6b 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -318,3 +318,4 @@ test('T17966', test('T17810', normal, multimod_compile, ['T17810', '-fspecialise-aggressively -dcore-lint -O -v0']) test('T18013', normal, multimod_compile, ['T18013', '-v0 -O']) test('T18098', normal, compile, ['-dcore-lint -O2']) +test('T18120', normal, compile, ['-dcore-lint -O']) -- cgit v1.2.1