summaryrefslogtreecommitdiff
path: root/testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs
diff options
context:
space:
mode:
Diffstat (limited to 'testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs')
-rw-r--r--testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs112
1 files changed, 112 insertions, 0 deletions
diff --git a/testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs b/testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs
new file mode 100644
index 0000000000..62071b9cfb
--- /dev/null
+++ b/testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs
@@ -0,0 +1,112 @@
+{-# LANGUAGE MultiParamTypeClasses, KindSignatures, FlexibleInstances, DataKinds, PatternSynonyms, StandaloneDeriving, GeneralizedNewtypeDeriving, PolyKinds #-}
+{-# OPTIONS -Wno-orphans #-}
+module DefaultLifted(DefaultType,plugin) where
+import GHC.Plugins
+import GHC.Tc.Types.Constraint
+import GHC.Tc.Plugin
+import GHC.Core.InstEnv
+import GHC.Tc.Solver (approximateWC)
+import GHC.Unit.Finder (findPluginModule)
+import GHC.Driver.Config.Finder (initFinderOpts)
+import Data.List
+import GHC.Tc.Types
+import qualified Data.Map as M
+import Control.Monad (liftM2)
+import GHC.Tc.Utils.TcType
+
+class DefaultType x (y :: x)
+
+instance Eq Type where
+ (==) = eqType
+instance Ord Type where
+ compare = nonDetCmpType
+instance Semigroup (TcPluginM [a]) where
+ (<>) = liftM2 (++)
+instance Monoid (TcPluginM [a]) where
+ mempty = pure mempty
+
+plugin :: Plugin
+plugin = defaultPlugin {
+ defaultingPlugin = install,
+ pluginRecompile = purePlugin
+ }
+
+install :: p -> Maybe GHC.Tc.Types.DefaultingPlugin
+install _ = Just $ DefaultingPlugin { dePluginInit = initialize
+ , dePluginRun = run
+ , dePluginStop = stop
+ }
+
+pattern FoundModule :: Module -> FindResult
+pattern FoundModule a <- Found _ a
+fr_mod :: a -> a
+fr_mod = id
+
+lookupModule :: ModuleName -- ^ Name of the module
+ -> TcPluginM Module
+lookupModule mod_nm = do
+ hsc_env <- getTopEnv
+ let dflags = hsc_dflags hsc_env
+ let fopts = initFinderOpts dflags
+ let fc = hsc_FC hsc_env
+ let units = hsc_units hsc_env
+ let home_unit = hsc_home_unit hsc_env
+ -- found_module <- findPluginModule fc fopts units home_unit mod_name
+ found_module <- tcPluginIO $ findPluginModule fc fopts units home_unit mod_nm
+ case found_module of
+ FoundModule h -> return (fr_mod h)
+ _ -> do
+ found_module' <- findImportedModule mod_nm $ Just $ fsLit "this"
+ case found_module' of
+ FoundModule h -> return (fr_mod h)
+ _ -> panicDoc "Unable to resolve module looked up by plugin: "
+ (ppr mod_nm)
+
+data PluginState = PluginState { defaultClassName :: Name }
+
+-- | Find a 'Name' in a 'Module' given an 'OccName'
+lookupName :: Module -> OccName -> TcPluginM Name
+lookupName md occ = lookupOrig md occ
+
+solveDefaultType :: PluginState -> [Ct] -> TcPluginM DefaultingPluginResult
+solveDefaultType _ [] = return []
+solveDefaultType state wanteds = do
+ envs <- getInstEnvs
+ insts <- classInstances envs <$> tcLookupClass (defaultClassName state)
+ let defaults =
+ foldl' (\m inst ->
+ case is_tys inst of
+ [matchty, replacety] -> M.insertWith (++) matchty [replacety] m
+ _ -> error "Unsupported defaulting type")
+ M.empty insts
+ let groups =
+ foldl' (\m wanted ->
+ foldl' (\m' var -> M.insertWith (++) var [wanted] m')
+ m
+ (filter (isVariableDefaultable defaults) $ tyCoVarsOfCtList wanted))
+ M.empty wanteds
+ M.foldMapWithKey (\var cts ->
+ case M.lookup (tyVarKind var) defaults of
+ Nothing -> error "Bug, we already checked that this variable has a default"
+ Just deftys -> do
+ pure [DefaultingProposal var deftys cts])
+ groups
+ where isVariableDefaultable defaults v = isAmbiguousTyVar v && M.member (tyVarKind v) defaults
+
+lookupDefaultTypes :: TcPluginM PluginState
+lookupDefaultTypes = do
+ md <- lookupModule (mkModuleName "DefaultLifted")
+ name <- lookupName md (mkTcOcc "DefaultType")
+ pure $ PluginState { defaultClassName = name }
+
+initialize :: TcPluginM PluginState
+initialize = do
+ lookupDefaultTypes
+
+run :: PluginState -> WantedConstraints -> TcPluginM DefaultingPluginResult
+run s ws = do
+ solveDefaultType s (ctsElts $ approximateWC False ws)
+
+stop :: Monad m => p -> m ()
+stop _ = do
+ return ()