diff options
-rw-r--r-- | compiler/GHC/Tc/Solver/Monad.hs | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/compiler/GHC/Tc/Solver/Monad.hs b/compiler/GHC/Tc/Solver/Monad.hs index bc9680c233..063f409041 100644 --- a/compiler/GHC/Tc/Solver/Monad.hs +++ b/compiler/GHC/Tc/Solver/Monad.hs @@ -186,6 +186,7 @@ import GHC.Data.TrieMap import Control.Monad import GHC.Utils.Monad import Data.IORef +import GHC.Exts (oneShot) import Data.List ( partition, mapAccumL ) import Data.List.NonEmpty ( NonEmpty(..), cons, toList, nonEmpty ) import qualified Data.List.NonEmpty as NE @@ -2897,15 +2898,21 @@ data TcSEnv --------------- newtype TcS a = TcS { unTcS :: TcSEnv -> TcM a } deriving (Functor) +-- | Smart constructor for 'TcS', as describe in Note [The one-shot state +-- monad trick] in "GHC.Utils.Monad". +mkTcS :: (TcSEnv -> TcM a) -> TcS a +mkTcS f = TcS (oneShot f) + instance Applicative TcS where - pure x = TcS (\_ -> return x) + pure x = mkTcS $ \_ -> return x (<*>) = ap instance Monad TcS where - m >>= k = TcS (\ebs -> unTcS m ebs >>= \r -> unTcS (k r) ebs) + m >>= k = mkTcS $ \ebs -> do + unTcS m ebs >>= (\r -> unTcS (k r) ebs) instance MonadFail TcS where - fail err = TcS (\_ -> fail err) + fail err = mkTcS $ \_ -> fail err instance MonadUnique TcS where getUniqueSupplyM = wrapTcS getUniqueSupplyM @@ -2921,7 +2928,7 @@ instance MonadThings TcS where wrapTcS :: TcM a -> TcS a -- Do not export wrapTcS, because it promotes an arbitrary TcM to TcS, -- and TcS is supposed to have limited functionality -wrapTcS = TcS . const -- a TcM action will not use the TcEvBinds +wrapTcS action = mkTcS $ \_env -> action -- a TcM action will not use the TcEvBinds wrapErrTcS :: TcM a -> TcS a -- The thing wrapped should just fail @@ -2956,9 +2963,10 @@ getGlobalRdrEnvTcS :: TcS GlobalRdrEnv getGlobalRdrEnvTcS = wrapTcS TcM.getGlobalRdrEnv bumpStepCountTcS :: TcS () -bumpStepCountTcS = TcS $ \env -> do { let ref = tcs_count env - ; n <- TcM.readTcRef ref - ; TcM.writeTcRef ref (n+1) } +bumpStepCountTcS = mkTcS $ \env -> + do { let ref = tcs_count env + ; n <- TcM.readTcRef ref + ; TcM.writeTcRef ref (n+1) } csTraceTcS :: SDoc -> TcS () csTraceTcS doc @@ -2968,7 +2976,7 @@ csTraceTcS doc traceFireTcS :: CtEvidence -> SDoc -> TcS () -- Dump a rule-firing trace traceFireTcS ev doc - = TcS $ \env -> csTraceTcM $ + = mkTcS $ \env -> csTraceTcM $ do { n <- TcM.readTcRef (tcs_count env) ; tclvl <- TcM.getTcLevel ; return (hang (text "Step" <+> int n |