diff options
author | Pavol Vargovcik <pavol.vargovcik@gmail.com> | 2022-05-16 10:04:28 +0200 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-05-16 15:34:04 -0400 |
commit | 8dfea0789957278b99bf302dfb24078fff84b6d2 (patch) | |
tree | 1cc568b4d0d8e0c8d00466af7337ac462d59a7f4 | |
parent | 43c018aaaf15ccce215958b7e09b1e29ee7b6d40 (diff) | |
download | haskell-8dfea0789957278b99bf302dfb24078fff84b6d2.tar.gz |
TcPlugin: access to irreducible givens + fix passed ev_binds_var
-rw-r--r-- | compiler/GHC/Tc/Module.hs | 7 | ||||
-rw-r--r-- | compiler/GHC/Tc/Solver/Interact.hs | 9 | ||||
-rw-r--r-- | compiler/GHC/Tc/Solver/Monad.hs | 10 | ||||
-rw-r--r-- | compiler/GHC/Tc/Types.hs | 5 | ||||
-rw-r--r-- | docs/users_guide/extending_ghc.rst | 13 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/Common.hs | 2 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/CtIdPlugin.hs | 80 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/Definitions.hs | 3 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/TcPlugin_CtId.hs | 13 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/TcPlugin_CtId.stderr | 4 | ||||
-rw-r--r-- | testsuite/tests/tcplugins/all.T | 14 |
11 files changed, 144 insertions, 16 deletions
diff --git a/compiler/GHC/Tc/Module.hs b/compiler/GHC/Tc/Module.hs index c11639725e..a6e9891b14 100644 --- a/compiler/GHC/Tc/Module.hs +++ b/compiler/GHC/Tc/Module.hs @@ -3115,9 +3115,8 @@ withTcPlugins hsc_env m = case catMaybes $ mapPlugins (hsc_plugins hsc_env) tcPlugin of [] -> m -- Common fast case plugins -> do - ev_binds_var <- newTcEvBinds (solvers, rewriters, stops) <- - unzip3 `fmap` mapM (start_plugin ev_binds_var) plugins + unzip3 `fmap` mapM start_plugin plugins let rewritersUniqFM :: UniqFM TyCon [TcPluginRewriter] !rewritersUniqFM = sequenceUFMList rewriters @@ -3131,9 +3130,9 @@ withTcPlugins hsc_env m = Left _ -> failM Right res -> return res where - start_plugin ev_binds_var (TcPlugin start solve rewrite stop) = + start_plugin (TcPlugin start solve rewrite stop) = do s <- runTcPluginM start - return (solve s ev_binds_var, rewrite s, stop s) + return (solve s, rewrite s, stop s) withDefaultingPlugins :: HscEnv -> TcM a -> TcM a withDefaultingPlugins hsc_env m = diff --git a/compiler/GHC/Tc/Solver/Interact.hs b/compiler/GHC/Tc/Solver/Interact.hs index 5adccd835c..bac38d8f0a 100644 --- a/compiler/GHC/Tc/Solver/Interact.hs +++ b/compiler/GHC/Tc/Solver/Interact.hs @@ -268,11 +268,12 @@ getTcPluginSolvers -- the plugin itself should perform this check if necessary. runTcPluginSolvers :: [TcPluginSolver] -> SplitCts -> TcS TcPluginProgress runTcPluginSolvers solvers all_cts - = foldM do_plugin initialProgress solvers + = do { ev_binds_var <- getTcEvBindsVar + ; foldM (do_plugin ev_binds_var) initialProgress solvers } where - do_plugin :: TcPluginProgress -> TcPluginSolver -> TcS TcPluginProgress - do_plugin p solver = do - result <- runTcPluginTcS (uncurry solver (pluginInputCts p)) + do_plugin :: EvBindsVar -> TcPluginProgress -> TcPluginSolver -> TcS TcPluginProgress + do_plugin ev_binds_var p solver = do + result <- runTcPluginTcS (uncurry (solver ev_binds_var) (pluginInputCts p)) return $ progress p result progress :: TcPluginProgress -> TcPluginSolveResult -> TcPluginProgress diff --git a/compiler/GHC/Tc/Solver/Monad.hs b/compiler/GHC/Tc/Solver/Monad.hs index 764f1eb454..26af2ff689 100644 --- a/compiler/GHC/Tc/Solver/Monad.hs +++ b/compiler/GHC/Tc/Solver/Monad.hs @@ -507,7 +507,8 @@ getInertGivens :: TcS [Ct] -- Returns the Given constraints in the inert set getInertGivens = do { inerts <- getInertCans - ; let all_cts = foldDicts (:) (inert_dicts inerts) + ; let all_cts = foldIrreds (:) (inert_irreds inerts) + $ foldDicts (:) (inert_dicts inerts) $ foldFunEqs (++) (inert_funeqs inerts) $ foldDVarEnv (++) [] (inert_eqs inerts) ; return (filter isGivenCt all_cts) } @@ -645,10 +646,15 @@ removeInertCt is ct = CEqCan { cc_lhs = lhs, cc_rhs = rhs } -> delEq is lhs rhs + CIrredCan {} -> is { inert_irreds = filterBag (not . eqCt ct) $ inert_irreds is } + CQuantCan {} -> panic "removeInertCt: CQuantCan" - CIrredCan {} -> panic "removeInertCt: CIrredEvCan" CNonCanonical {} -> panic "removeInertCt: CNonCanonical" +eqCt :: Ct -> Ct -> Bool +-- Equality via ctEvId +eqCt c c' = ctEvId c == ctEvId c' + -- | Looks up a family application in the inerts. lookupFamAppInert :: (CtFlavourRole -> Bool) -- can it rewrite the target? -> TyCon -> [Type] -> TcS (Maybe (Reduction, CtFlavourRole)) diff --git a/compiler/GHC/Tc/Types.hs b/compiler/GHC/Tc/Types.hs index 31e5f8ceed..c56cbc1322 100644 --- a/compiler/GHC/Tc/Types.hs +++ b/compiler/GHC/Tc/Types.hs @@ -1633,7 +1633,8 @@ Constraint Solver Plugins -- and Wanted constraints, and should return a 'TcPluginSolveResult' -- indicating which Wanted constraints it could solve, or whether any are -- insoluble. -type TcPluginSolver = [Ct] -- ^ Givens +type TcPluginSolver = EvBindsVar + -> [Ct] -- ^ Givens -> [Ct] -- ^ Wanteds -> TcPluginM TcPluginSolveResult @@ -1663,7 +1664,7 @@ data TcPlugin = forall s. TcPlugin { tcPluginInit :: TcPluginM s -- ^ Initialize plugin, when entering type-checker. - , tcPluginSolve :: s -> EvBindsVar -> TcPluginSolver + , tcPluginSolve :: s -> TcPluginSolver -- ^ Solve some constraints. -- -- This function will be invoked at two points in the constraint solving diff --git a/docs/users_guide/extending_ghc.rst b/docs/users_guide/extending_ghc.rst index 1cec248364..998ddeeec0 100644 --- a/docs/users_guide/extending_ghc.rst +++ b/docs/users_guide/extending_ghc.rst @@ -577,12 +577,12 @@ is defined thus: data TcPlugin = forall s . TcPlugin { tcPluginInit :: TcPluginM s - , tcPluginSolve :: s -> EvBindsVar -> TcPluginSolver + , tcPluginSolve :: s -> TcPluginSolver , tcPluginRewrite :: s -> UniqFM TyCon TcPluginRewriter , tcPluginStop :: s -> TcPluginM () } - type TcPluginSolver = [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult + type TcPluginSolver = EvBindsVar -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult type TcPluginRewriter = RewriteEnv -> [Ct] -> [Type] -> TcPluginM TcPluginRewriteResult @@ -652,8 +652,8 @@ The key component of a typechecker plugin is a function of type :: - solve :: [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult - solve givens wanteds = ... + solve :: EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult + solve binds givens wanteds = ... This function will be invoked in two different ways: @@ -701,6 +701,11 @@ the plugin to create equality axioms for use in evidence terms, but GHC does not check their consistency, and inconsistent axiom sets may lead to segfaults or other runtime misbehaviour. +Evidence is required also when creating new Given constraints, which are +usually implied by old ones. It is not uncommon that the evidence of a new +Given constraint contains a removed constraint: the new one has replaced the +removed one. + .. _type-family-rewriting-with-plugins: Type family rewriting with plugins diff --git a/testsuite/tests/tcplugins/Common.hs b/testsuite/tests/tcplugins/Common.hs index f2f425381d..d5eb1767d3 100644 --- a/testsuite/tests/tcplugins/Common.hs +++ b/testsuite/tests/tcplugins/Common.hs @@ -67,6 +67,7 @@ data PluginDefs = , zero :: !TyCon , succ :: !TyCon , add :: !TyCon + , ctIdFam :: !TyCon } definitionsModule :: TcPluginM Module @@ -87,6 +88,7 @@ lookupDefs = do ( promoteDataCon -> zero ) <- tcLookupDataCon =<< lookupOrig defs ( mkDataOcc "Zero" ) ( promoteDataCon -> succ ) <- tcLookupDataCon =<< lookupOrig defs ( mkDataOcc "Succ" ) add <- tcLookupTyCon =<< lookupOrig defs ( mkTcOcc "Add" ) + ctIdFam <- tcLookupTyCon =<< lookupOrig defs ( mkTcOcc "CtId" ) pure ( PluginDefs { .. } ) mkPlugin :: ( [String] -> PluginDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult ) diff --git a/testsuite/tests/tcplugins/CtIdPlugin.hs b/testsuite/tests/tcplugins/CtIdPlugin.hs new file mode 100644 index 0000000000..2511c902d3 --- /dev/null +++ b/testsuite/tests/tcplugins/CtIdPlugin.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE ViewPatterns #-} + +module CtIdPlugin where + +-- base +import Data.Maybe +import Data.Traversable + +-- ghc +import GHC.Core.Class +import GHC.Core.Coercion +import GHC.Core.DataCon +import GHC.Core.Make +import GHC.Core.Predicate +import GHC.Core.TyCo.Rep +import GHC.Plugins +import GHC.Tc.Plugin +import GHC.Tc.Types +import GHC.Tc.Types.Constraint +import GHC.Tc.Types.Evidence + +-- common +import Common + +-------------------------------------------------------------------------------- + +-- This plugin simplifies Given and Wanted 'CtId' constraints. +-- To do this, we just look through the Givens and Wanteds, +-- find any irreducible constraint whose TyCon matches that of 'CtId', +-- in which case we substitute it for its argument: +-- We create a new Given or Wanted and remove the old one using a cast. + +plugin :: Plugin +plugin = mkPlugin solver don'tRewrite + +-- Solve "CtId". +solver :: [String] + -> PluginDefs -> EvBindsVar -> [Ct] -> [Ct] + -> TcPluginM TcPluginSolveResult +solver _args defs ev givens wanteds = do + let pluginCo = mkUnivCo (PluginProv "CtIdPlugin") Representational + let substEvidence ct ct' = + evCast (ctEvExpr $ ctEvidence ct') $ pluginCo (ctPred ct') (ctPred ct) + + if null wanteds + then do + newGivenPredTypes <- traverse (solveCt defs) givens + newGivens <- for (zip newGivenPredTypes givens) \case + (Nothing, _) -> return Nothing + (Just pred, ct) -> + let EvExpr expr = + evCast (ctEvExpr $ ctEvidence ct) $ pluginCo (ctPred ct) pred + in Just . mkNonCanonical <$> newGiven ev (ctLoc ct) pred expr + let removedGivens = + [ (substEvidence ct ct', ct) + | (Just ct', ct) <- zip newGivens givens + ] + pure $ TcPluginOk removedGivens (catMaybes newGivens) + else do + newWantedPredTypes <- traverse (solveCt defs) wanteds + newWanteds <- for (zip newWantedPredTypes wanteds) \case + (Nothing, _) -> return Nothing + (Just pred, ct) -> do + evidence <- newWanted (ctLoc ct) pred + return $ Just (mkNonCanonical evidence) + let removedWanteds = + [ (substEvidence ct ct', ct) + | (Just ct', ct) <- zip newWanteds wanteds + ] + pure $ TcPluginOk removedWanteds (catMaybes newWanteds) + +solveCt :: PluginDefs -> Ct -> TcPluginM (Maybe PredType) +solveCt (PluginDefs {..}) ct@(classifyPredType . ctPred -> IrredPred pred) + | Just (tyCon, [arg]) <- splitTyConApp_maybe pred + , tyCon == ctIdFam + = pure $ Just arg +solveCt _ ct = pure Nothing diff --git a/testsuite/tests/tcplugins/Definitions.hs b/testsuite/tests/tcplugins/Definitions.hs index 5a84967c07..70d04b0296 100644 --- a/testsuite/tests/tcplugins/Definitions.hs +++ b/testsuite/tests/tcplugins/Definitions.hs @@ -28,6 +28,9 @@ class MyClass a where type MyTyFam :: Type -> Type -> Type type family MyTyFam a b where +type CtId :: Constraint -> Constraint +type family CtId a where + data Nat = Zero | Succ Nat type Add :: Nat -> Nat -> Nat diff --git a/testsuite/tests/tcplugins/TcPlugin_CtId.hs b/testsuite/tests/tcplugins/TcPlugin_CtId.hs new file mode 100644 index 0000000000..14698a9f3f --- /dev/null +++ b/testsuite/tests/tcplugins/TcPlugin_CtId.hs @@ -0,0 +1,13 @@ +{-# OPTIONS_GHC -dcore-lint #-} +{-# OPTIONS_GHC -fplugin CtIdPlugin #-} + +module TcPlugin_CtId where + +import Definitions + ( CtId ) + +foo :: CtId (Num a) => a +foo = 5 + +bar :: Int +bar = foo diff --git a/testsuite/tests/tcplugins/TcPlugin_CtId.stderr b/testsuite/tests/tcplugins/TcPlugin_CtId.stderr new file mode 100644 index 0000000000..f7e7913f9f --- /dev/null +++ b/testsuite/tests/tcplugins/TcPlugin_CtId.stderr @@ -0,0 +1,4 @@ +[1 of 4] Compiling Common ( Common.hs, Common.o ) +[2 of 4] Compiling CtIdPlugin ( CtIdPlugin.hs, CtIdPlugin.o ) +[3 of 4] Compiling Definitions ( Definitions.hs, Definitions.o ) +[4 of 4] Compiling TcPlugin_CtId ( TcPlugin_CtId.hs, TcPlugin_CtId.o ) diff --git a/testsuite/tests/tcplugins/all.T b/testsuite/tests/tcplugins/all.T index 52264e83db..c371deaaa8 100644 --- a/testsuite/tests/tcplugins/all.T +++ b/testsuite/tests/tcplugins/all.T @@ -84,3 +84,17 @@ test('TcPlugin_EmitWanted' , [ 'TcPlugin_EmitWanted.hs' , '-dynamic -package ghc' if have_dynamic() else '-package ghc ' ] ) + +# See TcPlugin_CtId.hs for a description of this plugin. +test('TcPlugin_CtId' + , [ extra_files( + [ 'Definitions.hs' + , 'Common.hs' + , 'CtIdPlugin.hs' + , 'TcPlugin_CtId.hs' + ]) + ] + , multimod_compile + , [ 'TcPlugin_CtId.hs' + , '-dynamic -package ghc' if have_dynamic() else '-package ghc' ] + ) |