summaryrefslogtreecommitdiff
path: root/testsuite/tests/plugins/defaulting-plugin/DefaultLifted.hs
blob: 62071b9cfbe00b1cf7aacc9bba7ea4e495670b8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
{-# LANGUAGE MultiParamTypeClasses, KindSignatures, FlexibleInstances, DataKinds, PatternSynonyms, StandaloneDeriving, GeneralizedNewtypeDeriving, PolyKinds #-}
{-# OPTIONS -Wno-orphans #-}
module DefaultLifted(DefaultType,plugin) where
import GHC.Plugins
import GHC.Tc.Types.Constraint
import GHC.Tc.Plugin
import GHC.Core.InstEnv
import GHC.Tc.Solver (approximateWC)
import GHC.Unit.Finder (findPluginModule)
import GHC.Driver.Config.Finder (initFinderOpts)
import Data.List
import GHC.Tc.Types
import qualified Data.Map as M
import Control.Monad (liftM2)
import GHC.Tc.Utils.TcType

-- | Declares a defaulting rule: an instance @DefaultType k t@ registers @t@
-- as a candidate default for ambiguous type variables whose kind is @k@.
-- The plugin looks this class up by name in module "DefaultLifted" and reads
-- its visible instances in 'solveDefaultType'.
class DefaultType x (y :: x)

-- Orphan instances (silenced by -Wno-orphans) that let GHC's 'Type' serve as
-- a 'M.Map' key: equality is 'eqType', ordering is the non-deterministic
-- 'nonDetCmpType'.
instance Eq Type where
  lhs == rhs = eqType lhs rhs
instance Ord Type where
  compare lhs rhs = nonDetCmpType lhs rhs

-- Combine two list-producing plugin actions by running the left one, then the
-- right one, and concatenating their results in that order.
instance Semigroup (TcPluginM [a]) where
  left <> right = (++) <$> left <*> right
instance Monoid (TcPluginM [a]) where
  mempty = pure []

-- | Plugin entry point: the default plugin record with only the defaulting
-- hook installed, marked with 'purePlugin' for recompilation checking.
plugin :: Plugin
plugin = defaultPlugin
  { defaultingPlugin = install
  , pluginRecompile  = purePlugin
  }

-- | Build the defaulting plugin from the 'initialize' / 'run' / 'stop'
-- stages. The argument (presumably the plugin's command-line options) is
-- ignored.
install :: p -> Maybe GHC.Tc.Types.DefaultingPlugin
install _ = Just DefaultingPlugin
  { dePluginInit = initialize
  , dePluginRun  = run
  , dePluginStop = stop
  }

-- | Match-only pattern for a successful 'FindResult': binds the 'Module'
-- carried by 'Found' and discards the other field.
pattern FoundModule :: Module -> FindResult
pattern FoundModule found <- Found _ found

-- | Identity function applied to the module bound by 'FoundModule'.
-- NOTE(review): presumably a compatibility shim mirroring a record selector
-- of the same name in other GHC versions — confirm.
fr_mod :: a -> a
fr_mod x = x

-- | Resolve a 'ModuleName' to a 'Module' from inside the plugin.
--
-- Asks 'findPluginModule' first, configured from the session's finder cache,
-- finder options, unit state and home unit. If that does not produce a
-- 'Found' result, it retries with 'findImportedModule' under the @"this"@
-- package qualifier. Panics when neither lookup finds the module.
lookupModule :: ModuleName -- ^ Name of the module
             -> TcPluginM Module
lookupModule mod_nm = do
  hsc_env <- getTopEnv
  viaPluginFinder <- tcPluginIO $
    findPluginModule (hsc_FC hsc_env)
                     (initFinderOpts (hsc_dflags hsc_env))
                     (hsc_units hsc_env)
                     (hsc_home_unit hsc_env)
                     mod_nm
  case viaPluginFinder of
    FoundModule found -> pure (fr_mod found)
    _                 -> viaImports
  where
    -- Second attempt: an ordinary import lookup qualified by "this".
    viaImports = do
      viaImport <- findImportedModule mod_nm (Just (fsLit "this"))
      case viaImport of
        FoundModule found -> pure (fr_mod found)
        _                 -> panicDoc "Unable to resolve module looked up by plugin: "
                                      (ppr mod_nm)

-- | State shared by the plugin stages: the resolved 'Name' of the
-- 'DefaultType' class, filled in by 'lookupDefaultTypes'.
-- A single-field wrapper, so 'newtype' (no extra box, no lazy field).
newtype PluginState = PluginState { defaultClassName :: Name }

-- | Find the original 'Name' defined in a 'Module' for the given 'OccName'.
-- This is simply 'lookupOrig' under a local name.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig

-- | Produce defaulting proposals for the ambiguous type variables that occur
-- in the given wanted constraints.
--
-- Every visible instance @DefaultType k t@ adds @t@ to the candidate list for
-- kind @k@. Each variable that satisfies 'isAmbiguousTyVar' and whose kind
-- has candidates yields one 'DefaultingProposal'. The proposal carries all
-- candidates for that kind and every wanted constraint that mentions the
-- variable. Proposals are listed in ascending 'M.Map' key order of the
-- variables.
--
-- Calls 'error' if an instance of the class does not have exactly two type
-- arguments.
solveDefaultType :: PluginState -> [Ct] -> TcPluginM DefaultingPluginResult
solveDefaultType _     []      = return []
solveDefaultType state wanteds = do
  envs  <- getInstEnvs
  insts <- classInstances envs <$> tcLookupClass (defaultClassName state)
  let defaults = foldl' addDefault M.empty insts
      groups   = foldl' (addWanted defaults) M.empty wanteds
  -- Building proposals is pure. Looking up the kind in a comprehension
  -- replaces the former unreachable 'error' branch: every grouped variable
  -- was admitted only because its kind is a key of 'defaults'.
  pure [ DefaultingProposal var deftys cts
       | (var, cts)  <- M.toList groups
       , Just deftys <- [M.lookup (tyVarKind var) defaults]
       ]
  where
    -- Register the instance's second argument as a candidate default for the
    -- kind given by its first argument.
    addDefault m inst =
      case is_tys inst of
        [matchty, replacety] -> M.insertWith (++) matchty [replacety] m
        _ -> error "Unsupported defaulting type"

    -- File the wanted constraint under every defaultable variable it mentions.
    addWanted defaults m wanted =
      foldl' (\m' var -> M.insertWith (++) var [wanted] m')
             m
             (filter (isVariableDefaultable defaults) (tyCoVarsOfCtList wanted))

    -- A variable can be defaulted when it is ambiguous and its kind has at
    -- least one registered candidate.
    isVariableDefaultable defaults v =
      isAmbiguousTyVar v && M.member (tyVarKind v) defaults

-- | Resolve the 'DefaultType' class declared in module "DefaultLifted" and
-- wrap its 'Name' in a fresh 'PluginState'.
lookupDefaultTypes :: TcPluginM PluginState
lookupDefaultTypes = do
  defModule <- lookupModule (mkModuleName "DefaultLifted")
  PluginState <$> lookupName defModule (mkTcOcc "DefaultType")

-- | Plugin start-up stage: the initial state is exactly what
-- 'lookupDefaultTypes' produces.
initialize :: TcPluginM PluginState
initialize = lookupDefaultTypes

-- | Defaulting stage: flatten the wanted constraints with
-- @approximateWC False@, list them, and delegate to 'solveDefaultType'.
run :: PluginState -> WantedConstraints -> TcPluginM DefaultingPluginResult
run s = solveDefaultType s . ctsElts . approximateWC False

-- | Shutdown stage: ignores its argument and performs no action.
stop :: Monad m => p -> m ()
stop _ = pure ()