diff options
author | Zach Sullivan <zachsully@gmail.com> | 2019-02-04 15:02:08 -0800 |
---|---|---|
committer | Sebastian Graf <sebastian.graf@kit.edu> | 2019-11-04 10:49:34 +0100 |
commit | a80ff66e2dd93122c4bce1b2073104cff340e7f4 (patch) | |
tree | 3537faf95284281e1a448248a044f88582f25c09 | |
parent | 9920d39886eaf4a2381fa0086bcf438d6fecb0b7 (diff) | |
download | haskell-a80ff66e2dd93122c4bce1b2073104cff340e7f4.tar.gz |
Add arity rules for extensional function types to CoreLint. Coalesced
transformation into a single file in simplCore.
-rw-r--r-- | compiler/basicTypes/Id.hs | 6 | ||||
-rw-r--r-- | compiler/basicTypes/OccName.hs | 5 | ||||
-rw-r--r-- | compiler/coreSyn/CoreArity.hs | 5 | ||||
-rw-r--r-- | compiler/coreSyn/CoreLint.hs | 25 | ||||
-rw-r--r-- | compiler/coreSyn/MkCore.hs | 9 | ||||
-rw-r--r-- | compiler/ghc.cabal.in | 3 | ||||
-rw-r--r-- | compiler/simplCore/EtaArityWW.hs (renamed from compiler/coreSyn/CoreEta.hs) | 182 | ||||
-rw-r--r-- | compiler/simplCore/EtaWorkerWrapper.hs | 20 | ||||
-rw-r--r-- | compiler/simplCore/FloatIn.hs | 9 | ||||
-rw-r--r-- | compiler/simplCore/SimplCore.hs | 10 | ||||
-rw-r--r-- | compiler/simplCore/Simplify.hs | 2 | ||||
-rw-r--r-- | compiler/simplStg/RepType.hs | 2 | ||||
-rw-r--r-- | compiler/types/Type.hs | 23 |
13 files changed, 172 insertions, 129 deletions
diff --git a/compiler/basicTypes/Id.hs b/compiler/basicTypes/Id.hs index 8c62cc9944..a8fcf33ece 100644 --- a/compiler/basicTypes/Id.hs +++ b/compiler/basicTypes/Id.hs @@ -40,7 +40,7 @@ module Id ( mkSysLocal, mkSysLocalM, mkSysLocalOrCoVar, mkSysLocalOrCoVarM, mkUserLocal, mkUserLocalOrCoVar, mkTemplateLocals, mkTemplateLocalsNum, mkTemplateLocal, - mkWorkerId, + mkWorkerId, mkEtaWorkerId, -- ** Taking an Id apart idName, idType, idUnique, idInfo, idDetails, @@ -347,6 +347,10 @@ mkWorkerId :: Unique -> Id -> Type -> Id mkWorkerId uniq unwrkr ty = mkLocalIdOrCoVar (mkDerivedInternalName mkWorkerOcc uniq (getName unwrkr)) ty +mkEtaWorkerId :: Unique -> Id -> Type -> Id +mkEtaWorkerId uniq unwrkr ty + = mkLocalIdOrCoVar (mkDerivedInternalName mkEtaWorkerOcc uniq (getName unwrkr)) ty + -- | Create a /template local/: a family of system local 'Id's in bijection with @Int@s, typically used in unfoldings mkTemplateLocal :: Int -> Type -> Id mkTemplateLocal i ty = mkSysLocalOrCoVar (fsLit "v") (mkBuiltinUnique i) ty diff --git a/compiler/basicTypes/OccName.hs b/compiler/basicTypes/OccName.hs index bbd40f85a5..dcfac4bc08 100644 --- a/compiler/basicTypes/OccName.hs +++ b/compiler/basicTypes/OccName.hs @@ -55,7 +55,7 @@ module OccName ( -- ** Derived 'OccName's isDerivedOccName, - mkDataConWrapperOcc, mkWorkerOcc, + mkDataConWrapperOcc, mkWorkerOcc, mkEtaWorkerOcc, mkMatcherOcc, mkBuilderOcc, mkDefaultMethodOcc, isDefaultMethodOcc, isTypeableBindOcc, mkNewTyCoOcc, mkClassOpAuxOcc, @@ -606,7 +606,7 @@ isTypeableBindOcc occ = '$':'t':'r':_ -> True -- Module binding _ -> False -mkDataConWrapperOcc, mkWorkerOcc, +mkDataConWrapperOcc, mkWorkerOcc, mkEtaWorkerOcc, mkMatcherOcc, mkBuilderOcc, mkDefaultMethodOcc, mkClassDataConOcc, mkDictOcc, @@ -621,6 +621,7 @@ mkDataConWrapperOcc, mkWorkerOcc, -- These derived variables have a prefix that no Haskell value could have mkDataConWrapperOcc = mk_simple_deriv varName "$W" mkWorkerOcc = mk_simple_deriv varName "$w" +mkEtaWorkerOcc = mk_simple_deriv varName "$etaW_" mkMatcherOcc = mk_simple_deriv varName "$m" mkBuilderOcc = mk_simple_deriv varName "$b" mkDefaultMethodOcc = mk_simple_deriv varName "$dm" diff --git a/compiler/coreSyn/CoreArity.hs b/compiler/coreSyn/CoreArity.hs index c2ce0a8a20..b7f539c109 100644 --- a/compiler/coreSyn/CoreArity.hs +++ b/compiler/coreSyn/CoreArity.hs @@ -1092,6 +1092,11 @@ mkEtaWW orig_n orig_expr in_scope orig_ty -- Avoid free vars of the original expression = go (n-1) subst' res_ty (EtaVar eta_id' : eis) + | Just (arg_ty, res_ty) <- splitFunTildeTy_maybe ty + , not (isTypeLevPoly arg_ty) + , let (subst', eta_id') = freshEtaId n subst arg_ty + = go (n-1) subst' res_ty (EtaVar eta_id' : eis) + ----------- Newtypes -- Given this: -- newtype T = MkT ([T] -> Int) diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs index a664b825b2..8f2c8cdc31 100644 --- a/compiler/coreSyn/CoreLint.hs +++ b/compiler/coreSyn/CoreLint.hs @@ -840,6 +840,7 @@ lintVarOcc var nargs checkL (idName var /= makeStaticName) $ text "Found makeStatic nested in an expression" + ; checkFunTildeTySat var ty nargs ; checkDeadIdOcc var ; checkJoinOcc var nargs @@ -866,6 +867,23 @@ lintCoreFun expr nargs lintCoreExpr expr ------------------ +checkFunTildeTySat :: Id -> Type -> Arity -> LintM () +checkFunTildeTySat id ty app_arity + | isFunTildeTy ty + = let ty_arity = length (typeArity ty) in + do { checkL (not (ty_arity > app_arity)) (err_msg True ty_arity) + ; checkL (not (ty_arity < app_arity)) (err_msg False ty_arity) + ; return () } + | otherwise + = return () + where applied_msg True = text "Under-applied extensional function:" + applied_msg False = text "Over-applied extensional function:" + + err_msg b ty_arity = applied_msg b <+> ppr id + $$ text "Expected Args:" <+> ppr ty_arity + $$ text "Actual Args:" <+> ppr app_arity + +------------------ checkDeadIdOcc :: Id -> LintM () -- Occurrences of an Id should never be dead.... -- except when we are checking a case pattern @@ -1038,15 +1056,14 @@ lintTyApp fun_ty arg_ty = failWithL (mkTyAppMsg fun_ty arg_ty) ----------------- --- Tilde Types must appear fully applied -lintValApps :: [CoreExpr] -> OutType -> [OutType] -> LintM OutType -lintValApps = undefined - lintValApp :: CoreExpr -> OutType -> OutType -> LintM OutType lintValApp arg fun_ty arg_ty | Just (arg,res) <- splitFunTy_maybe fun_ty = do { ensureEqTys arg arg_ty err1 ; return res } + | Just (arg,res) <- splitFunTildeTy_maybe fun_ty + = do { ensureEqTys arg arg_ty err1 + ; return res } | otherwise = failWithL err2 where diff --git a/compiler/coreSyn/MkCore.hs b/compiler/coreSyn/MkCore.hs index c9665ec8d7..377d878d5a 100644 --- a/compiler/coreSyn/MkCore.hs +++ b/compiler/coreSyn/MkCore.hs @@ -161,10 +161,15 @@ mkCoreAppTyped _ (fun, fun_ty) (Type ty) mkCoreAppTyped _ (fun, fun_ty) (Coercion co) = (App fun (Coercion co), funResultTy fun_ty) mkCoreAppTyped d (fun, fun_ty) arg - = ASSERT2( isFunTy fun_ty, ppr fun $$ ppr arg $$ d ) - (mkValApp fun arg arg_ty res_ty, res_ty) + | isFunTy fun_ty + = (mkValApp fun arg arg_ty res_ty, res_ty) where (arg_ty, res_ty) = splitFunTy fun_ty +mkCoreAppTyped d (fun, fun_ty) arg + = ASSERT2( isFunTildeTy fun_ty, ppr fun $$ ppr arg $$ d ) + (mkValApp fun arg arg_ty res_ty, res_ty) + where + (arg_ty, res_ty) = splitTildeFunTy fun_ty mkValApp :: CoreExpr -> CoreExpr -> Type -> Type -> CoreExpr -- Build an application (e1 e2), diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in index 635234f303..181ecd1fbc 100644 --- a/compiler/ghc.cabal.in +++ b/compiler/ghc.cabal.in @@ -309,7 +309,6 @@ Library GHC.StgToCmm.ExtCode SMRep CoreArity - CoreEta CoreFVs CoreLint CorePrep @@ -473,7 +472,7 @@ Library StgSyn StgFVs CallArity - EtaWorkerWrapper + EtaArityWW DmdAnal Exitify WorkWrap diff --git a/compiler/coreSyn/CoreEta.hs b/compiler/simplCore/EtaArityWW.hs index 5fdaf0c0d0..6ce08dcfeb 100644 --- a/compiler/coreSyn/CoreEta.hs +++ b/compiler/simplCore/EtaArityWW.hs @@ -1,23 +1,20 @@ -{- -Worker/wrapper transformation for type directed etaExpansion. - -To be done as part of tidying just before translation to STG. --} - -module CoreEta - ( arityWorkerWrapper,etaTypeArity - ) where +module EtaArityWW (etaArityWW) where import GhcPrelude +import DynFlags import BasicTypes import CoreSyn import CoreSubst import CoreArity +import CoreFVs import Id +import IdInfo import TyCoRep import UniqSupply +import VarEnv import Outputable +import MonadUtils import qualified Data.Map as F @@ -27,53 +24,74 @@ import qualified Data.Map as F Call Arity in the Types * * ************************************************************************ + +Goal: + +Expose more arity information at code generation by tracking the arity of top +level (though let-bound terms should be included too) terms in the types. -} +etaArityWW + :: DynFlags -> UniqSupply -> CoreProgram -> CoreProgram +etaArityWW dflags us binds = initUs_ us $ concatMapM arityWorkerWrapper binds + -- ^ Given a top level entity, produce the WorkerWrapper transformed -- version. This transformation may or may not produce new top level entities -- depending on its arity. arityWorkerWrapper :: CoreBind -> UniqSM [CoreBind] arityWorkerWrapper (NonRec name expr) - = arityWorkerWrapper' name expr >>= \e_ww -> - case e_ww of - Left (worker,wrapper) -> return (map (uncurry NonRec) [worker,wrapper]) - Right (n,e) -> return [NonRec n e] -arityWorkerWrapper (Rec binds) = - do { out <- mapM (uncurry arityWorkerWrapper') binds - ; let (recs,nonrecs) = collectRecs out - ; return ([Rec recs] ++ map (uncurry NonRec) nonrecs) } - where collectRecs = foldr (\x (recs,nonrecs) -> - case x of - Left (worker,wrapper) -> - (worker:recs,wrapper:nonrecs) - Right cb -> (cb:recs,nonrecs) - ) - ([],[]) + = map (uncurry NonRec) <$> arityWorkerWrapper' name expr +arityWorkerWrapper (Rec binds) + = (return . Rec) <$> concatMapM (uncurry arityWorkerWrapper') binds -- ^ Change a function binding into a call to its wrapper and the production of -- a wrapper. The worker/wrapper transformation *only* makes sense for Id's or -- binders to code. arityWorkerWrapper' - :: CoreBndr + :: Id -> CoreExpr - -> UniqSM (Either ((CoreBndr,CoreExpr),(CoreBndr,CoreExpr)) - (CoreBndr,CoreExpr)) + -> UniqSM [(Id,CoreExpr)] -- the first component are recursive binds and the second are non-recursive -- binds (the wrappers are non-recursive) -arityWorkerWrapper' name expr - = let arity = manifestArity expr in - case arity >= 1 && isId name of - True -> - let fm = calledArityMap expr - ty = exprArityType arity (idType name) expr fm in - do { uniq <- getUniqueM - ; let wname = mkWorkerId uniq name ty - ; let worker = mkArityWorker name wname expr - -- ; panic $ showSDocUnsafe (ppr (F.toList fm)) - -- ; panic $ showSDocUnsafe (ppr expr) - ; wrapper <- mkArityWrapper fm name wname expr arity - ; return (Left (worker,wrapper)) } - False -> return (Right (name,expr)) +arityWorkerWrapper' fn_id rhs + | arity >= 1 && isId fn_id + = let fm = calledArityMap rhs + work_ty = exprArityType arity (idType fn_id) rhs fm + fn_info = idInfo fn_id + fn_inl_prag = inlinePragInfo fn_info + fn_inline_spec = inl_inline fn_inl_prag + fn_act = inl_act fn_inl_prag + rule_match_info = inlinePragmaRuleMatchInfo fn_inl_prag + work_prag = InlinePragma { inl_src = SourceText "{-# INLINE" + , inl_inline = NoInline + , inl_sat = Nothing + , inl_act = ActiveAfter NoSourceText 0 + , inl_rule = FunLike } + wrap_act = case fn_act of + ActiveAfter {} -> fn_act + NeverActive -> ActiveAfter NoSourceText 0 + _ -> ActiveAfter NoSourceText 2 + wrap_prag = InlinePragma { inl_src = SourceText "{-# INLINE" + , inl_inline = NoUserInline + , inl_sat = Nothing + , inl_act = wrap_act + , inl_rule = rule_match_info } + in + do { uniq <- getUniqueM + ; let work_id = mkEtaWorkerId uniq fn_id work_ty + `setIdOccInfo` occInfo fn_info + `setIdArity` arity + `setInlinePragma` work_prag + wrap_id = fn_id `setIdOccInfo` noOccInfo + `setInlinePragma` wrap_prag + work_rhs = mkArityWorkerRhs fn_id work_id rhs + ; wrap_rhs <- mkArityWrapperRhs fm work_id rhs arity + ; return [(work_id,work_rhs),(wrap_id,wrap_rhs)] } + + | otherwise + = return [(fn_id,rhs)] + + where arity = manifestArity rhs {- | exprArityType creates the new type for an extensional function given the arity. We also need to consider higher-order functions. The type of a @@ -149,14 +167,15 @@ fooWrapper f = let f' = \x1 x2 -> f x1 x2 in fooWorker f' -The wrapper eta-expands all functions. +The wrapper eta-expands all functions so that the worker can assume that its +arguments are the most extensional functions types. -} {- -calledArityMap takes a core expression (meant to be the RHS of a top level +@calledArityMap@ takes a core expression (meant to be the RHS of a top level binding) and returns a Map of binders to an arity. This map will be used for determining how much to etaExpand the higher-order functions used in -mkArityWrapper. +@mkArityWrapper@. foo :: (Int -> Int -> Int) -> Int foo f = @@ -164,7 +183,7 @@ foo f = y = f 2 in x + y 3 -} -calledArityMap :: CoreExpr -> F.Map CoreBndr Arity +calledArityMap :: CoreExpr -> F.Map Id Arity calledArityMap e = case e of Var x -> F.singleton x 0 @@ -174,20 +193,20 @@ calledArityMap e = expr@(App _ _) -> case collectArgs expr of (Var x,args) -> - let fm = F.unionsWith retGreater (map calledArityMap args) + let fm = F.unionsWith max (map calledArityMap args) a = length args in - F.unionWith retGreater (F.singleton x a) fm - (_,args) -> F.unionsWith retGreater (map calledArityMap args) + F.unionWith max (F.singleton x a) fm + (_,args) -> F.unionsWith max (map calledArityMap args) Lam _ expr -> calledArityMap expr Let bnds expr -> - let fm = F.unionsWith retGreater (map calledArityMap (rhssOfBind bnds)) in - F.unionWith retGreater fm (calledArityMap expr) + let fm = F.unionsWith max (map calledArityMap (rhssOfBind bnds)) in + F.unionWith max fm (calledArityMap expr) Case expr _ _ alts -> - let fm = F.unionsWith retGreater (map calledArityMap (rhssOfAlts alts)) in - F.unionWith retGreater fm (calledArityMap expr) + let fm = F.unionsWith max (map calledArityMap (rhssOfAlts alts)) in + F.unionWith max fm (calledArityMap expr) Cast expr _ -> calledArityMap expr @@ -196,13 +215,9 @@ calledArityMap e = Type _ -> F.empty Coercion _ -> F.empty - where retGreater x y = - case x > y of - True -> x - False -> y {- -exprArityType creates the new type for an extensional function given the +@exprArityType@ creates the new type for an extensional function given the arity. We also need to consider higher-order functions. The type of a function argument can change based on the usage of the type in the body of the function. For example, consider the zipWith function. @@ -225,7 +240,7 @@ forall a b c. (a ~> b ~> c) ~> [a] ~> [b] ~> [c] because the function is only applied to two arguments in the body of the function. -} -exprArityType :: Arity -> Type -> CoreExpr -> F.Map CoreBndr Arity -> Type +exprArityType :: Arity -> Type -> CoreExpr -> F.Map Id Arity -> Type exprArityType n (ForAllTy tv body_ty) (Lam _ expr) fm = ForAllTy tv (exprArityType n body_ty expr fm) exprArityType 0 (FunTy arg res) (Lam bndr expr) fm @@ -256,43 +271,28 @@ etaTypeArity _ = 0 -- ^ Given an expression and it's name, generate a new expression with a -- tilde-lambda type. This is the exact same code, but we have encoded the arity -- in the type. -mkArityWorker - :: CoreBndr - -> CoreBndr +mkArityWorkerRhs + :: Id + -> Id + -> CoreExpr -> CoreExpr - -> (CoreBndr,CoreExpr) -mkArityWorker name wname expr - = ( wname , substExpr (text "eta-worker-subst") substitution expr ) - where substitution = extendIdSubst emptySubst name (Var wname) +mkArityWorkerRhs fn_id work_id rhs + = substExprSC (text "eta-worker-subst") subst rhs + where init_subst = mkEmptySubst . mkInScopeSet . exprFreeVars $ rhs + subst = extendSubstWithVar init_subst fn_id work_id -- ^ The wrapper does not change the type and will call the newly created worker -- function. -mkArityWrapper - :: F.Map CoreBndr Arity - -> CoreBndr - -> CoreBndr - -> CoreExpr - -> Arity - -> UniqSM (CoreBndr,CoreExpr) -mkArityWrapper fm name wname expr arity - = mkArityWrapper' fm expr arity wname [] >>= \expr' -> - let name' = setInlinePragma name alwaysInlinePragma in - -- let name' = name in - -- We will always inline the wrapper for call fusion - return ( name' , expr' ) - -mkArityWrapper' - :: F.Map CoreBndr Arity +mkArityWrapperRhs + :: F.Map Id Arity + -> Id -> CoreExpr -> Arity - -> CoreBndr - -> [CoreExpr] -> UniqSM CoreExpr -mkArityWrapper' fm (Lam b e) a w l = - case isId b of - True -> - let expr = etaExpand (F.findWithDefault 0 b fm) (Var b) in - Lam b <$> mkArityWrapper' fm e (a-1) w (expr : l) - False -> - Lam b <$> mkArityWrapper' fm e a w (Type (TyVarTy b) : l) -mkArityWrapper' _ _ _ w l = return $ mkApps (Var w) (reverse l) +-- mkArityWrapperRhs _ work_id _ _ = return (Var work_id) +mkArityWrapperRhs fm wname expr arity = go fm expr arity wname [] + where go fm (Lam b e) a w l + | isId b = let expr = etaExpand (F.findWithDefault 0 b fm) (Var b) in + Lam b <$> go fm e (a-1) w (expr : l) + | otherwise = Lam b <$> go fm e a w (Type (TyVarTy b) : l) + go _ _ _ w l = return $ mkApps (Var w) (reverse l) diff --git a/compiler/simplCore/EtaWorkerWrapper.hs b/compiler/simplCore/EtaWorkerWrapper.hs deleted file mode 100644 index e57567927c..0000000000 --- a/compiler/simplCore/EtaWorkerWrapper.hs +++ /dev/null @@ -1,20 +0,0 @@ -module EtaWorkerWrapper (etaArityWorkerWrapperProgram) where - -import GhcPrelude - -import CallArity -import CoreEta -import CoreSyn -import DynFlags ( DynFlags ) -import UniqSupply -import Outputable -import PprCore - -etaArityWorkerWrapperProgram - :: DynFlags -> UniqSupply -> CoreProgram -> CoreProgram -etaArityWorkerWrapperProgram _dflags us binds - = let binds' = callArityAnalProgram _dflags binds - -- ^ arityWorkerWrapper depends on Call Arity analysis - in - initUs_ us $ concat <$> mapM arityWorkerWrapper binds' - -- = panic (showSDocUnsafe (pprCoreBindingsWithSize (initUs_ us $ concat <$> mapM arityWorkerWrapper binds))) diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs index 216e848889..50b8f1881a 100644 --- a/compiler/simplCore/FloatIn.hs +++ b/compiler/simplCore/FloatIn.hs @@ -201,7 +201,14 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {}) | otherwise = (res_ty, extra_fvs) where - (arg_ty, res_ty) = splitFunTy fun_ty + (arg_ty, res_ty) = + case splitFunTy_maybe fun_ty of + Just x -> x + Nothing -> + case splitFunTildeTy_maybe fun_ty of + Just x -> x + Nothing -> pprPanic "fiExpr.splitFunTy" (ppr fun_ty) + {- Note [Dead bindings] ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/compiler/simplCore/SimplCore.hs b/compiler/simplCore/SimplCore.hs index 2f6a48993f..bb0c75342f 100644 --- a/compiler/simplCore/SimplCore.hs +++ b/compiler/simplCore/SimplCore.hs @@ -46,7 +46,7 @@ import Specialise ( specProgram) import SpecConstr ( specConstrProgram) import DmdAnal ( dmdAnalProgram ) import CallArity ( callArityAnalProgram ) -import EtaWorkerWrapper ( etaArityWorkerWrapperProgram ) +import EtaArityWW ( etaArityWW ) import Exitify ( exitifyProgram ) import WorkWrap ( wwTopBinds ) import SrcLoc @@ -202,7 +202,6 @@ getCoreToDo dflags ] core_todo = - [runWhen eta_arity $ CoreDoPasses [ CoreDoCallArity, CoreDoEtaArity ]] ++ if opt_level == 0 then [ static_ptrs_float_outwards, CoreDoSimplify max_iter @@ -303,6 +302,11 @@ getCoreToDo dflags -- succeed in commoning up things floated out by full laziness. -- CSE used to rely on the no-shadowing invariant, but it doesn't any more + runWhen eta_arity $ CoreDoPasses + [ CoreDoEtaArity + , simpl_phase 0 ["post-eta-arity"] max_iter + ], + runWhen do_float_in CoreDoFloatInwards, maybe_rule_check (Phase 0), @@ -445,7 +449,7 @@ doCorePass CoreDoCallArity = {-# SCC "CallArity" #-} doPassD callArityAnalProgram doCorePass CoreDoEtaArity = {-# SCC "EtaArity" #-} - doPassDU etaArityWorkerWrapperProgram + doPassDU etaArityWW doCorePass CoreDoExitify = {-# SCC "Exitify" #-} doPass exitifyProgram diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs index 795b0f5654..7f1f0cbe6b 100644 --- a/compiler/simplCore/Simplify.hs +++ b/compiler/simplCore/Simplify.hs @@ -1888,7 +1888,7 @@ rebuildCall env info@(ArgInfo { ai_encl = encl_rules, ai_type = fun_ty ; rebuildCall env (addValArgTo info' arg') cont } where info' = info { ai_strs = strs, ai_discs = discs } - arg_ty = funArgTy fun_ty + arg_ty = argTy fun_ty -- Use this for lazy arguments cci_lazy | encl_rules = RuleArgCtxt diff --git a/compiler/simplStg/RepType.hs b/compiler/simplStg/RepType.hs index 75fde79d87..0597ae83db 100644 --- a/compiler/simplStg/RepType.hs +++ b/compiler/simplStg/RepType.hs @@ -110,6 +110,8 @@ countFunRepArgs 0 _ countFunRepArgs n ty | FunTy _ arg res <- unwrapType ty = length (typePrimRepArgs arg) + countFunRepArgs (n - 1) res + | FunTildeTy arg res <- unwrapType ty + = length (typePrimRepArgs arg) + countFunRepArgs (n - 1) res | otherwise = pprPanic "countFunRepArgs: arity greater than type can handle" (ppr (n, ty, typePrimRep ty)) diff --git a/compiler/types/Type.hs b/compiler/types/Type.hs index dba92c4d93..375297c2f4 100644 --- a/compiler/types/Type.hs +++ b/compiler/types/Type.hs @@ -28,7 +28,7 @@ module Type ( mkVisFunTy, mkInvisFunTy, mkVisFunTys, mkInvisFunTys, splitFunTy, splitFunTy_maybe, - splitFunTys, funResultTy, funArgTy, + splitFunTys, funResultTy, funArgTy, argTy, splitFunTildeTy, splitFunTildeTy_maybe, funTildeArgTy, funTildeResultTy, mkTyConApp, mkTyConTy, @@ -1019,6 +1019,13 @@ funTildeArgTy ty | Just ty' <- coreView ty = funTildeArgTy ty' funTildeArgTy (FunTildeTy arg _res) = arg funTildeArgTy ty = pprPanic "funTildeArgTy" (ppr ty) +argTy :: Type -> Type +argTy ty | Just ty' <- coreView ty = argTy ty' +argTy (FunTy arg _res) = arg +argTy (FunTildeTy arg _res) = arg +argTy ty = pprPanic "argTy" (ppr ty) + + piResultTy :: HasDebugCallStack => Type -> Type -> Type piResultTy ty arg = case piResultTy_maybe ty arg of Just res -> res @@ -1181,6 +1188,7 @@ mkTyConApp tycon tys tyConAppTyConPicky_maybe :: Type -> Maybe TyCon tyConAppTyConPicky_maybe (TyConApp tc _) = Just tc tyConAppTyConPicky_maybe (FunTy {}) = Just funTyCon +tyConAppTyConPicky_maybe (FunTildeTy {}) = Just funTildeTyCon tyConAppTyConPicky_maybe _ = Nothing @@ -1189,6 +1197,7 @@ tyConAppTyCon_maybe :: Type -> Maybe TyCon tyConAppTyCon_maybe ty | Just ty' <- coreView ty = tyConAppTyCon_maybe ty' tyConAppTyCon_maybe (TyConApp tc _) = Just tc tyConAppTyCon_maybe (FunTy {}) = Just funTyCon +tyConAppTyCon_maybe (FunTildeTy {}) = Just funTildeTyCon tyConAppTyCon_maybe _ = Nothing tyConAppTyCon :: Type -> TyCon @@ -1202,6 +1211,10 @@ tyConAppArgs_maybe (FunTy _ arg res) | Just rep1 <- getRuntimeRep_maybe arg , Just rep2 <- getRuntimeRep_maybe res = Just [rep1, rep2, arg, res] +tyConAppArgs_maybe (FunTildeTy arg res) + | Just rep1 <- getRuntimeRep_maybe arg + , Just rep2 <- getRuntimeRep_maybe res + = Just [rep1, rep2, arg, res] tyConAppArgs_maybe _ = Nothing tyConAppArgs :: Type -> [Type] @@ -1256,6 +1269,10 @@ repSplitTyConApp_maybe (FunTy _ arg res) | Just arg_rep <- getRuntimeRep_maybe arg , Just res_rep <- getRuntimeRep_maybe res = Just (funTyCon, [arg_rep, res_rep, arg, res]) +repSplitTyConApp_maybe (FunTildeTy arg res) + | Just arg_rep <- getRuntimeRep_maybe arg + , Just res_rep <- getRuntimeRep_maybe res + = Just (funTildeTyCon, [arg_rep, res_rep, arg, res]) repSplitTyConApp_maybe _ = Nothing ------------------- @@ -2313,8 +2330,10 @@ nonDetCmpTypeX env orig_t1 orig_t2 = get_rank (LitTy {}) = 4 get_rank (TyConApp {}) = 5 get_rank (FunTy {}) = 6 + get_rank (FunTildeTy {}) = 6 + -- If we want to distinguish extensional functions from normal + -- functions, then we need to use @isFunTildeTy@. get_rank (ForAllTy {}) = 7 - get_rank (FunTildeTy {}) = 8 gos :: RnEnv2 -> [Type] -> [Type] -> TypeOrdering gos _ [] [] = TEQ |