summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavol Vargovcik <pavol.vargovcik@gmail.com>2022-05-16 10:04:28 +0200
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-05-16 15:34:04 -0400
commit8dfea0789957278b99bf302dfb24078fff84b6d2 (patch)
tree1cc568b4d0d8e0c8d00466af7337ac462d59a7f4
parent43c018aaaf15ccce215958b7e09b1e29ee7b6d40 (diff)
downloadhaskell-8dfea0789957278b99bf302dfb24078fff84b6d2.tar.gz
TcPlugin: access to irreducible givens + fix passed ev_binds_var
-rw-r--r--compiler/GHC/Tc/Module.hs7
-rw-r--r--compiler/GHC/Tc/Solver/Interact.hs9
-rw-r--r--compiler/GHC/Tc/Solver/Monad.hs10
-rw-r--r--compiler/GHC/Tc/Types.hs5
-rw-r--r--docs/users_guide/extending_ghc.rst13
-rw-r--r--testsuite/tests/tcplugins/Common.hs2
-rw-r--r--testsuite/tests/tcplugins/CtIdPlugin.hs80
-rw-r--r--testsuite/tests/tcplugins/Definitions.hs3
-rw-r--r--testsuite/tests/tcplugins/TcPlugin_CtId.hs13
-rw-r--r--testsuite/tests/tcplugins/TcPlugin_CtId.stderr4
-rw-r--r--testsuite/tests/tcplugins/all.T14
11 files changed, 144 insertions, 16 deletions
diff --git a/compiler/GHC/Tc/Module.hs b/compiler/GHC/Tc/Module.hs
index c11639725e..a6e9891b14 100644
--- a/compiler/GHC/Tc/Module.hs
+++ b/compiler/GHC/Tc/Module.hs
@@ -3115,9 +3115,8 @@ withTcPlugins hsc_env m =
case catMaybes $ mapPlugins (hsc_plugins hsc_env) tcPlugin of
[] -> m -- Common fast case
plugins -> do
- ev_binds_var <- newTcEvBinds
(solvers, rewriters, stops) <-
- unzip3 `fmap` mapM (start_plugin ev_binds_var) plugins
+ unzip3 `fmap` mapM start_plugin plugins
let
rewritersUniqFM :: UniqFM TyCon [TcPluginRewriter]
!rewritersUniqFM = sequenceUFMList rewriters
@@ -3131,9 +3130,9 @@ withTcPlugins hsc_env m =
Left _ -> failM
Right res -> return res
where
- start_plugin ev_binds_var (TcPlugin start solve rewrite stop) =
+ start_plugin (TcPlugin start solve rewrite stop) =
do s <- runTcPluginM start
- return (solve s ev_binds_var, rewrite s, stop s)
+ return (solve s, rewrite s, stop s)
withDefaultingPlugins :: HscEnv -> TcM a -> TcM a
withDefaultingPlugins hsc_env m =
diff --git a/compiler/GHC/Tc/Solver/Interact.hs b/compiler/GHC/Tc/Solver/Interact.hs
index 5adccd835c..bac38d8f0a 100644
--- a/compiler/GHC/Tc/Solver/Interact.hs
+++ b/compiler/GHC/Tc/Solver/Interact.hs
@@ -268,11 +268,12 @@ getTcPluginSolvers
-- the plugin itself should perform this check if necessary.
runTcPluginSolvers :: [TcPluginSolver] -> SplitCts -> TcS TcPluginProgress
runTcPluginSolvers solvers all_cts
- = foldM do_plugin initialProgress solvers
+ = do { ev_binds_var <- getTcEvBindsVar
+ ; foldM (do_plugin ev_binds_var) initialProgress solvers }
where
- do_plugin :: TcPluginProgress -> TcPluginSolver -> TcS TcPluginProgress
- do_plugin p solver = do
- result <- runTcPluginTcS (uncurry solver (pluginInputCts p))
+ do_plugin :: EvBindsVar -> TcPluginProgress -> TcPluginSolver -> TcS TcPluginProgress
+ do_plugin ev_binds_var p solver = do
+ result <- runTcPluginTcS (uncurry (solver ev_binds_var) (pluginInputCts p))
return $ progress p result
progress :: TcPluginProgress -> TcPluginSolveResult -> TcPluginProgress
diff --git a/compiler/GHC/Tc/Solver/Monad.hs b/compiler/GHC/Tc/Solver/Monad.hs
index 764f1eb454..26af2ff689 100644
--- a/compiler/GHC/Tc/Solver/Monad.hs
+++ b/compiler/GHC/Tc/Solver/Monad.hs
@@ -507,7 +507,8 @@ getInertGivens :: TcS [Ct]
-- Returns the Given constraints in the inert set
getInertGivens
= do { inerts <- getInertCans
- ; let all_cts = foldDicts (:) (inert_dicts inerts)
+ ; let all_cts = foldIrreds (:) (inert_irreds inerts)
+ $ foldDicts (:) (inert_dicts inerts)
$ foldFunEqs (++) (inert_funeqs inerts)
$ foldDVarEnv (++) [] (inert_eqs inerts)
; return (filter isGivenCt all_cts) }
@@ -645,10 +646,15 @@ removeInertCt is ct =
CEqCan { cc_lhs = lhs, cc_rhs = rhs } -> delEq is lhs rhs
+ CIrredCan {} -> is { inert_irreds = filterBag (not . eqCt ct) $ inert_irreds is }
+
CQuantCan {} -> panic "removeInertCt: CQuantCan"
- CIrredCan {} -> panic "removeInertCt: CIrredEvCan"
CNonCanonical {} -> panic "removeInertCt: CNonCanonical"
+eqCt :: Ct -> Ct -> Bool
+-- Equality via ctEvId
+eqCt c c' = ctEvId c == ctEvId c'
+
-- | Looks up a family application in the inerts.
lookupFamAppInert :: (CtFlavourRole -> Bool) -- can it rewrite the target?
-> TyCon -> [Type] -> TcS (Maybe (Reduction, CtFlavourRole))
diff --git a/compiler/GHC/Tc/Types.hs b/compiler/GHC/Tc/Types.hs
index 31e5f8ceed..c56cbc1322 100644
--- a/compiler/GHC/Tc/Types.hs
+++ b/compiler/GHC/Tc/Types.hs
@@ -1633,7 +1633,8 @@ Constraint Solver Plugins
-- and Wanted constraints, and should return a 'TcPluginSolveResult'
-- indicating which Wanted constraints it could solve, or whether any are
-- insoluble.
-type TcPluginSolver = [Ct] -- ^ Givens
+type TcPluginSolver = EvBindsVar
+ -> [Ct] -- ^ Givens
-> [Ct] -- ^ Wanteds
-> TcPluginM TcPluginSolveResult
@@ -1663,7 +1664,7 @@ data TcPlugin = forall s. TcPlugin
{ tcPluginInit :: TcPluginM s
-- ^ Initialize plugin, when entering type-checker.
- , tcPluginSolve :: s -> EvBindsVar -> TcPluginSolver
+ , tcPluginSolve :: s -> TcPluginSolver
-- ^ Solve some constraints.
--
-- This function will be invoked at two points in the constraint solving
diff --git a/docs/users_guide/extending_ghc.rst b/docs/users_guide/extending_ghc.rst
index 1cec248364..998ddeeec0 100644
--- a/docs/users_guide/extending_ghc.rst
+++ b/docs/users_guide/extending_ghc.rst
@@ -577,12 +577,12 @@ is defined thus:
data TcPlugin = forall s . TcPlugin
{ tcPluginInit :: TcPluginM s
- , tcPluginSolve :: s -> EvBindsVar -> TcPluginSolver
+ , tcPluginSolve :: s -> TcPluginSolver
, tcPluginRewrite :: s -> UniqFM TyCon TcPluginRewriter
, tcPluginStop :: s -> TcPluginM ()
}
- type TcPluginSolver = [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
+ type TcPluginSolver = EvBindsVar -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
type TcPluginRewriter = RewriteEnv -> [Ct] -> [Type] -> TcPluginM TcPluginRewriteResult
@@ -652,8 +652,8 @@ The key component of a typechecker plugin is a function of type
::
- solve :: [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
- solve givens wanteds = ...
+ solve :: EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
+ solve binds givens wanteds = ...
This function will be invoked in two different ways:
@@ -701,6 +701,11 @@ the plugin to create equality axioms for use in evidence terms, but GHC
does not check their consistency, and inconsistent axiom sets may lead
to segfaults or other runtime misbehaviour.
+Evidence is required also when creating new Given constraints, which are
+usually implied by old ones. It is not uncommon that the evidence of a new
+Given constraint contains a removed constraint: the new one has replaced the
+removed one.
+
.. _type-family-rewriting-with-plugins:
Type family rewriting with plugins
diff --git a/testsuite/tests/tcplugins/Common.hs b/testsuite/tests/tcplugins/Common.hs
index f2f425381d..d5eb1767d3 100644
--- a/testsuite/tests/tcplugins/Common.hs
+++ b/testsuite/tests/tcplugins/Common.hs
@@ -67,6 +67,7 @@ data PluginDefs =
, zero :: !TyCon
, succ :: !TyCon
, add :: !TyCon
+ , ctIdFam :: !TyCon
}
definitionsModule :: TcPluginM Module
@@ -87,6 +88,7 @@ lookupDefs = do
( promoteDataCon -> zero ) <- tcLookupDataCon =<< lookupOrig defs ( mkDataOcc "Zero" )
( promoteDataCon -> succ ) <- tcLookupDataCon =<< lookupOrig defs ( mkDataOcc "Succ" )
add <- tcLookupTyCon =<< lookupOrig defs ( mkTcOcc "Add" )
+ ctIdFam <- tcLookupTyCon =<< lookupOrig defs ( mkTcOcc "CtId" )
pure ( PluginDefs { .. } )
mkPlugin :: ( [String] -> PluginDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult )
diff --git a/testsuite/tests/tcplugins/CtIdPlugin.hs b/testsuite/tests/tcplugins/CtIdPlugin.hs
new file mode 100644
index 0000000000..2511c902d3
--- /dev/null
+++ b/testsuite/tests/tcplugins/CtIdPlugin.hs
@@ -0,0 +1,80 @@
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE BlockArguments #-}
+{-# LANGUAGE ViewPatterns #-}
+
+module CtIdPlugin where
+
+-- base
+import Data.Maybe
+import Data.Traversable
+
+-- ghc
+import GHC.Core.Class
+import GHC.Core.Coercion
+import GHC.Core.DataCon
+import GHC.Core.Make
+import GHC.Core.Predicate
+import GHC.Core.TyCo.Rep
+import GHC.Plugins
+import GHC.Tc.Plugin
+import GHC.Tc.Types
+import GHC.Tc.Types.Constraint
+import GHC.Tc.Types.Evidence
+
+-- common
+import Common
+
+--------------------------------------------------------------------------------
+
+-- This plugin simplifies Given and Wanted 'CtId' constraints.
+-- To do this, we just look through the Givens and Wanteds,
+-- find any irreducible constraint whose TyCon matches that of 'CtId',
+-- in which case we substitute it for its argument:
+-- We create a new Given or Wanted and remove the old one using a cast.
+
+plugin :: Plugin
+plugin = mkPlugin solver don'tRewrite
+
+-- Solve "CtId".
+solver :: [String]
+ -> PluginDefs -> EvBindsVar -> [Ct] -> [Ct]
+ -> TcPluginM TcPluginSolveResult
+solver _args defs ev givens wanteds = do
+ let pluginCo = mkUnivCo (PluginProv "CtIdPlugin") Representational
+ let substEvidence ct ct' =
+ evCast (ctEvExpr $ ctEvidence ct') $ pluginCo (ctPred ct') (ctPred ct)
+
+ if null wanteds
+ then do
+ newGivenPredTypes <- traverse (solveCt defs) givens
+ newGivens <- for (zip newGivenPredTypes givens) \case
+ (Nothing, _) -> return Nothing
+ (Just pred, ct) ->
+ let EvExpr expr =
+ evCast (ctEvExpr $ ctEvidence ct) $ pluginCo (ctPred ct) pred
+ in Just . mkNonCanonical <$> newGiven ev (ctLoc ct) pred expr
+ let removedGivens =
+ [ (substEvidence ct ct', ct)
+ | (Just ct', ct) <- zip newGivens givens
+ ]
+ pure $ TcPluginOk removedGivens (catMaybes newGivens)
+ else do
+ newWantedPredTypes <- traverse (solveCt defs) wanteds
+ newWanteds <- for (zip newWantedPredTypes wanteds) \case
+ (Nothing, _) -> return Nothing
+ (Just pred, ct) -> do
+ evidence <- newWanted (ctLoc ct) pred
+ return $ Just (mkNonCanonical evidence)
+ let removedWanteds =
+ [ (substEvidence ct ct', ct)
+ | (Just ct', ct) <- zip newWanteds wanteds
+ ]
+ pure $ TcPluginOk removedWanteds (catMaybes newWanteds)
+
+solveCt :: PluginDefs -> Ct -> TcPluginM (Maybe PredType)
+solveCt (PluginDefs {..}) ct@(classifyPredType . ctPred -> IrredPred pred)
+ | Just (tyCon, [arg]) <- splitTyConApp_maybe pred
+ , tyCon == ctIdFam
+ = pure $ Just arg
+solveCt _ ct = pure Nothing
diff --git a/testsuite/tests/tcplugins/Definitions.hs b/testsuite/tests/tcplugins/Definitions.hs
index 5a84967c07..70d04b0296 100644
--- a/testsuite/tests/tcplugins/Definitions.hs
+++ b/testsuite/tests/tcplugins/Definitions.hs
@@ -28,6 +28,9 @@ class MyClass a where
type MyTyFam :: Type -> Type -> Type
type family MyTyFam a b where
+type CtId :: Constraint -> Constraint
+type family CtId a where
+
data Nat = Zero | Succ Nat
type Add :: Nat -> Nat -> Nat
diff --git a/testsuite/tests/tcplugins/TcPlugin_CtId.hs b/testsuite/tests/tcplugins/TcPlugin_CtId.hs
new file mode 100644
index 0000000000..14698a9f3f
--- /dev/null
+++ b/testsuite/tests/tcplugins/TcPlugin_CtId.hs
@@ -0,0 +1,13 @@
+{-# OPTIONS_GHC -dcore-lint #-}
+{-# OPTIONS_GHC -fplugin CtIdPlugin #-}
+
+module TcPlugin_CtId where
+
+import Definitions
+ ( CtId )
+
+foo :: CtId (Num a) => a
+foo = 5
+
+bar :: Int
+bar = foo
diff --git a/testsuite/tests/tcplugins/TcPlugin_CtId.stderr b/testsuite/tests/tcplugins/TcPlugin_CtId.stderr
new file mode 100644
index 0000000000..f7e7913f9f
--- /dev/null
+++ b/testsuite/tests/tcplugins/TcPlugin_CtId.stderr
@@ -0,0 +1,4 @@
+[1 of 4] Compiling Common ( Common.hs, Common.o )
+[2 of 4] Compiling CtIdPlugin ( CtIdPlugin.hs, CtIdPlugin.o )
+[3 of 4] Compiling Definitions ( Definitions.hs, Definitions.o )
+[4 of 4] Compiling TcPlugin_CtId ( TcPlugin_CtId.hs, TcPlugin_CtId.o )
diff --git a/testsuite/tests/tcplugins/all.T b/testsuite/tests/tcplugins/all.T
index 52264e83db..c371deaaa8 100644
--- a/testsuite/tests/tcplugins/all.T
+++ b/testsuite/tests/tcplugins/all.T
@@ -84,3 +84,17 @@ test('TcPlugin_EmitWanted'
, [ 'TcPlugin_EmitWanted.hs'
, '-dynamic -package ghc' if have_dynamic() else '-package ghc ' ]
)
+
+# See TcPlugin_CtId.hs for a description of this plugin.
+test('TcPlugin_CtId'
+ , [ extra_files(
+ [ 'Definitions.hs'
+ , 'Common.hs'
+ , 'CtIdPlugin.hs'
+ , 'TcPlugin_CtId.hs'
+ ])
+ ]
+ , multimod_compile
+ , [ 'TcPlugin_CtId.hs'
+ , '-dynamic -package ghc' if have_dynamic() else '-package ghc' ]
+ )