summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/GHC/Tc/Solver/Monad.hs24
1 files changed, 16 insertions, 8 deletions
diff --git a/compiler/GHC/Tc/Solver/Monad.hs b/compiler/GHC/Tc/Solver/Monad.hs
index bc9680c233..063f409041 100644
--- a/compiler/GHC/Tc/Solver/Monad.hs
+++ b/compiler/GHC/Tc/Solver/Monad.hs
@@ -186,6 +186,7 @@ import GHC.Data.TrieMap
import Control.Monad
import GHC.Utils.Monad
import Data.IORef
+import GHC.Exts (oneShot)
import Data.List ( partition, mapAccumL )
import Data.List.NonEmpty ( NonEmpty(..), cons, toList, nonEmpty )
import qualified Data.List.NonEmpty as NE
@@ -2897,15 +2898,21 @@ data TcSEnv
---------------
newtype TcS a = TcS { unTcS :: TcSEnv -> TcM a } deriving (Functor)
+-- | Smart constructor for 'TcS', as describe in Note [The one-shot state
+-- monad trick] in "GHC.Utils.Monad".
+mkTcS :: (TcSEnv -> TcM a) -> TcS a
+mkTcS f = TcS (oneShot f)
+
instance Applicative TcS where
- pure x = TcS (\_ -> return x)
+ pure x = mkTcS $ \_ -> return x
(<*>) = ap
instance Monad TcS where
- m >>= k = TcS (\ebs -> unTcS m ebs >>= \r -> unTcS (k r) ebs)
+ m >>= k = mkTcS $ \ebs -> do
+ unTcS m ebs >>= (\r -> unTcS (k r) ebs)
instance MonadFail TcS where
- fail err = TcS (\_ -> fail err)
+ fail err = mkTcS $ \_ -> fail err
instance MonadUnique TcS where
getUniqueSupplyM = wrapTcS getUniqueSupplyM
@@ -2921,7 +2928,7 @@ instance MonadThings TcS where
wrapTcS :: TcM a -> TcS a
-- Do not export wrapTcS, because it promotes an arbitrary TcM to TcS,
-- and TcS is supposed to have limited functionality
-wrapTcS = TcS . const -- a TcM action will not use the TcEvBinds
+wrapTcS action = mkTcS $ \_env -> action -- a TcM action will not use the TcEvBinds
wrapErrTcS :: TcM a -> TcS a
-- The thing wrapped should just fail
@@ -2956,9 +2963,10 @@ getGlobalRdrEnvTcS :: TcS GlobalRdrEnv
getGlobalRdrEnvTcS = wrapTcS TcM.getGlobalRdrEnv
bumpStepCountTcS :: TcS ()
-bumpStepCountTcS = TcS $ \env -> do { let ref = tcs_count env
- ; n <- TcM.readTcRef ref
- ; TcM.writeTcRef ref (n+1) }
+bumpStepCountTcS = mkTcS $ \env ->
+ do { let ref = tcs_count env
+ ; n <- TcM.readTcRef ref
+ ; TcM.writeTcRef ref (n+1) }
csTraceTcS :: SDoc -> TcS ()
csTraceTcS doc
@@ -2968,7 +2976,7 @@ csTraceTcS doc
traceFireTcS :: CtEvidence -> SDoc -> TcS ()
-- Dump a rule-firing trace
traceFireTcS ev doc
- = TcS $ \env -> csTraceTcM $
+ = mkTcS $ \env -> csTraceTcM $
do { n <- TcM.readTcRef (tcs_count env)
; tclvl <- TcM.getTcLevel
; return (hang (text "Step" <+> int n