summaryrefslogtreecommitdiff
path: root/testsuite/tests/tcplugins/Common.hs
blob: d5eb1767d3ea8ae4752250dfd4f6045a868d4cfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

module Common
  ( PluginDefs(..)
  , mkPlugin
  , don'tSolve, don'tRewrite
  )
  where

-- ghc
import GHC.Core.Class
  ( Class )
import GHC.Core.DataCon
  ( promoteDataCon )
import GHC.Core.TyCon
  ( TyCon )
import GHC.Core.Type
  ( Type
  , mkTyConApp
  )
import GHC.Plugins
  ( Plugin(..)
  , defaultPlugin, purePlugin
  )
import GHC.Tc.Plugin
  ( TcPluginM
  , findImportedModule, lookupOrig
  , tcLookupClass, tcLookupDataCon, tcLookupTyCon
  )
import GHC.Tc.Types
  ( TcPlugin(..), TcPluginSolveResult(..), TcPluginRewriteResult(..)
  , TcPluginRewriter
  )
import GHC.Tc.Types.Constraint
  ( Ct )
import GHC.Tc.Types.Evidence
  ( EvBindsVar )
import GHC.Types.Name.Occurrence
  ( mkClsOcc, mkDataOcc, mkTcOcc )
import GHC.Types.Unique.FM
  ( UniqFM, emptyUFM )
import GHC.Types.PkgQual
import GHC.Unit.Finder
  ( FindResult(..) )
import GHC.Unit.Module
  ( Module
  , mkModuleName
  )

--------------------------------------------------------------------------------

-- This module defines some common operations so that each individual plugin
-- doesn't have to do the same work over again:
--
--   - look up the names of things the plugins will use
--     (the definitions are shared between most type-checking plugin tests)
--   - create a type-checking plugin from a solver, taking care of passing
--     the relevant data to the solver stage.

-- | The entities from the @Definitions@ module that the test plugins work
-- with. A value of this type is produced by 'lookupDefs' and is passed to the
-- solver and rewriter functions given to 'mkPlugin'.
--
-- All fields are strict.
data PluginDefs = PluginDefs
  { nullary :: !Class  -- the class looked up as "Nullary"
  , myClass :: !Class  -- the class looked up as "MyClass"
  , myTyFam :: !TyCon  -- the type constructor looked up as "MyTyFam"
  , nat     :: !Type   -- the "Nat" type constructor applied to no arguments
  , zero    :: !TyCon  -- the promoted "Zero" data constructor
  , succ    :: !TyCon  -- the promoted "Succ" data constructor
  , add     :: !TyCon  -- the type constructor looked up as "Add"
  , ctIdFam :: !TyCon  -- the type constructor looked up as "CtId"
  }

-- | Find the module called @Definitions@ using 'findImportedModule'
-- (with no package qualifier).
--
-- This is test-suite support code, so failure is fatal: 'error' is called if
-- the search reports several candidate modules, and likewise for any other
-- result that is not a successful find.
definitionsModule :: TcPluginM Module
definitionsModule = do
  findResult <- findImportedModule ( mkModuleName "Definitions" ) NoPkgQual
  case findResult of
    Found _ res     -> pure res
    FoundMultiple _ -> error "TcPlugin test: found multiple modules named 'Definitions'."
    -- Every remaining 'FindResult' constructor is treated as "not found".
    _               -> error "TcPlugin test: could not find any module named 'Definitions'."

-- | Build the 'PluginDefs' record by looking up each required entity, by
-- name, in the module returned by 'definitionsModule'.
--
-- Each lookup first resolves the occurrence name with 'lookupOrig' and then
-- fetches the corresponding class, type constructor or data constructor.
lookupDefs :: TcPluginM PluginDefs
lookupDefs = do
  defsMod <- definitionsModule
  let
    -- One helper per namespace: class, type constructor, data constructor.
    classNamed   str = lookupOrig defsMod ( mkClsOcc  str ) >>= tcLookupClass
    tyConNamed   str = lookupOrig defsMod ( mkTcOcc   str ) >>= tcLookupTyCon
    dataConNamed str = lookupOrig defsMod ( mkDataOcc str ) >>= tcLookupDataCon
  nullary <- classNamed "Nullary"
  myClass <- classNamed "MyClass"
  myTyFam <- tyConNamed "MyTyFam"
  -- 'nat' is stored as a type: the "Nat" type constructor with no arguments.
  nat     <- ( `mkTyConApp` [] ) <$> tyConNamed "Nat"
  -- 'zero' and 'succ' are stored as promoted data constructors.
  zero    <- promoteDataCon <$> dataConNamed "Zero"
  succ    <- promoteDataCon <$> dataConNamed "Succ"
  add     <- tyConNamed "Add"
  ctIdFam <- tyConNamed "CtId"
  -- RecordWildCards: the local bindings above fill the fields of the same name.
  pure ( PluginDefs { .. } )

-- | Assemble a 'Plugin' from a solver function and a rewriter function.
--
-- Both functions take, as their first argument, the argument list that is
-- supplied to the 'tcPlugin' field; the remaining arguments are those of
-- 'mkTcPlugin'. All other plugin fields come from 'defaultPlugin', except
-- that 'pluginRecompile' is set to 'purePlugin'.
mkPlugin :: ( [String] -> PluginDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult )
         -> ( [String] -> PluginDefs -> UniqFM TyCon TcPluginRewriter )
         -> Plugin
mkPlugin solve rewrite = defaultPlugin { tcPlugin = pluginForArgs, pluginRecompile = purePlugin }
  where
    -- Specialise both callbacks to the given arguments; always yields a plugin.
    pluginForArgs args = Just ( mkTcPlugin ( solve args ) ( rewrite args ) )

-- | Wrap a solver and a rewriter into a 'TcPlugin' whose state is the
-- 'PluginDefs' record computed by 'lookupDefs' at initialisation.
--
-- The stop action ignores its argument and does nothing.
mkTcPlugin :: ( PluginDefs -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult )
           -> ( PluginDefs -> UniqFM TyCon TcPluginRewriter )
           -> TcPlugin
mkTcPlugin solver rewriter = TcPlugin
  { tcPluginInit    = lookupDefs
  , tcPluginSolve   = solver
  , tcPluginRewrite = rewriter
  , tcPluginStop    = const ( pure () )
  }

-- | A solver that does nothing: it ignores all five of its arguments and
-- returns 'TcPluginOk' with two empty lists. It has the shape expected for
-- the first argument of 'mkPlugin'.
don'tSolve :: [String] -> s -> EvBindsVar -> [Ct] -> [Ct] -> TcPluginM TcPluginSolveResult
don'tSolve _args _state _evBinds _cts1 _cts2 = pure ( TcPluginOk [] [] )

-- | A rewriter table that is always empty: both arguments are ignored and
-- 'emptyUFM' is returned. It has the shape expected for the second argument
-- of 'mkPlugin'.
don'tRewrite :: [String] -> s -> UniqFM TyCon TcPluginRewriter
don'tRewrite _args _state = emptyUFM