diff options
-rw-r--r-- | compiler/GHC/Core/Opt/Specialise.hs | 73 | ||||
-rw-r--r-- | compiler/GHC/Core/Rules.hs | 10 | ||||
-rw-r--r-- | compiler/GHC/HsToCore.hs | 28 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T23024.hs | 8 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T23024a.hs | 82 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 1 |
6 files changed, 156 insertions, 46 deletions
diff --git a/compiler/GHC/Core/Opt/Specialise.hs b/compiler/GHC/Core/Opt/Specialise.hs index 13bff9f170..f028bd428e 100644 --- a/compiler/GHC/Core/Opt/Specialise.hs +++ b/compiler/GHC/Core/Opt/Specialise.hs @@ -64,6 +64,7 @@ import GHC.Unit.Module( Module ) import GHC.Unit.Module.ModGuts import GHC.Core.Unfold +import Data.List( partition ) import Data.List.NonEmpty ( NonEmpty (..) ) {- @@ -726,6 +727,33 @@ specialisation (see canSpecImport): Specialise even INLINE things; it hasn't inlined yet, so perhaps it never will. Moreover it may have calls inside it that we want to specialise + +Wrinkle (W1): If we specialise an imported Id M.foo, we make a /local/ +binding $sfoo. But specImports may further specialise $sfoo. So we end up +with RULES for both M.foo (imported) and $sfoo (local). Rules for local +Ids should be attached to the Ids themselves (see GHC.HsToCore +Note [Attach rules to local ids]); so we must partition the rules and +attach the local rules. That is done in specImports, via addRulesToId. + +Note [Glom the bindings if imported functions are specialised] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Suppose we have an imported, *recursive*, INLINABLE function + f :: Eq a => a -> a + f = /\a \d x. ...(f a d)... +In the module being compiled we have + g x = f (x::Int) +Now we'll make a specialised function + f_spec :: Int -> Int + f_spec = \x -> ...(f Int dInt)... + {-# RULE f Int _ = f_spec #-} + g = \x. f Int dInt x +Note that f_spec doesn't look recursive +After rewriting with the RULE, we get + f_spec = \x -> ...(f_spec)... +BUT since f_spec was non-recursive before it'll *stay* non-recursive. +The occurrence analyser never turns a NonRec into a Rec. So we must +make sure that f_spec is recursive. Easiest thing is to make all +the specialisations for imported bindings recursive. -} specImports :: SpecEnv @@ -740,16 +768,24 @@ specImports top_env (MkUD { ud_binds = dict_binds, ud_calls = calls }) = do { let env_w_dict_bndrs = top_env `bringFloatedDictsIntoScope` dict_binds ; (_env, spec_rules, spec_binds) <- spec_imports env_w_dict_bndrs [] dict_binds calls - -- Don't forget to wrap the specialized bindings with - -- bindings for the needed dictionaries. - -- See Note [Wrap bindings returned by specImports] - -- and Note [Glom the bindings if imported functions are specialised] - ; let final_binds + -- Make a Rec: see Note [Glom the bindings if imported functions are specialised] + -- + -- wrapDictBinds: don't forget to wrap the specialized bindings with + -- bindings for the needed dictionaries. + -- See Note [Wrap bindings returned by specImports] + -- + -- addRulesToId: see Wrinkle (W1) in Note [Specialising imported functions] + -- c.f. GHC.HsToCore.addExportFlagsAndRules + ; let (rules_for_locals, rules_for_imps) = partition isLocalRule spec_rules + local_rule_base = extendRuleBaseList emptyRuleBase rules_for_locals + final_binds | null spec_binds = wrapDictBinds dict_binds [] - | otherwise = [Rec $ flattenBinds $ - wrapDictBinds dict_binds spec_binds] + | otherwise = [Rec $ mapFst (addRulesToId local_rule_base) $ + flattenBinds $ + wrapDictBinds dict_binds $ + spec_binds] - ; return (spec_rules, final_binds) + ; return (rules_for_imps, final_binds) } -- | Specialise a set of calls to imported bindings @@ -1111,27 +1147,6 @@ And if the call is to the same type, one specialisation is enough. Avoiding this recursive specialisation loop is one reason for the 'callers' stack passed to specImports and specImport. -Note [Glom the bindings if imported functions are specialised] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Suppose we have an imported, *recursive*, INLINABLE function - f :: Eq a => a -> a - f = /\a \d x. ...(f a d)... -In the module being compiled we have - g x = f (x::Int) -Now we'll make a specialised function - f_spec :: Int -> Int - f_spec = \x -> ...(f Int dInt)... - {-# RULE f Int _ = f_spec #-} - g = \x. f Int dInt x -Note that f_spec doesn't look recursive -After rewriting with the RULE, we get - f_spec = \x -> ...(f_spec)... -BUT since f_spec was non-recursive before it'll *stay* non-recursive. -The occurrence analyser never turns a NonRec into a Rec. So we must -make sure that f_spec is recursive. Easiest thing is to make all -the specialisations for imported bindings recursive. - - ************************************************************************ * * diff --git a/compiler/GHC/Core/Rules.hs b/compiler/GHC/Core/Rules.hs index df763835cf..67f47e9d9c 100644 --- a/compiler/GHC/Core/Rules.hs +++ b/compiler/GHC/Core/Rules.hs @@ -22,7 +22,7 @@ module GHC.Core.Rules ( -- ** Manipulating 'RuleInfo' rules extendRuleInfo, addRuleInfo, - addIdSpecialisations, + addIdSpecialisations, addRulesToId, -- ** RuleBase and RuleEnv @@ -349,6 +349,14 @@ addIdSpecialisations id rules = setIdSpecialisation id $ extendRuleInfo (idSpecialisation id) rules +addRulesToId :: RuleBase -> Id -> Id +-- Add rules in the RuleBase to the rules in the Id +addRulesToId rule_base bndr + | Just rules <- lookupNameEnv rule_base (idName bndr) + = bndr `addIdSpecialisations` rules + | otherwise + = bndr + -- | Gather all the rules for locally bound identifiers from the supplied bindings rulesOfBinds :: [CoreBind] -> [CoreRule] rulesOfBinds binds = concatMap (concatMap idCoreRules . bindersOf) binds diff --git a/compiler/GHC/HsToCore.hs b/compiler/GHC/HsToCore.hs index 755fe3e198..57f1b13391 100644 --- a/compiler/GHC/HsToCore.hs +++ b/compiler/GHC/HsToCore.hs @@ -362,32 +362,28 @@ deSugarExpr hsc_env tc_expr = do addExportFlagsAndRules :: Backend -> NameSet -> NameSet -> [CoreRule] -> [(Id, t)] -> [(Id, t)] -addExportFlagsAndRules bcknd exports keep_alive rules = mapFst add_one +addExportFlagsAndRules bcknd exports keep_alive rules + = mapFst (addRulesToId rule_base . add_export_flag) + -- addRulesToId: see Note [Attach rules to local ids] + -- NB: the binder might have some existing rules, + -- arising from specialisation pragmas + where - add_one bndr = add_rules name (add_export name bndr) - where - name = idName bndr ---------- Rules -------- - -- See Note [Attach rules to local ids] - -- NB: the binder might have some existing rules, - -- arising from specialisation pragmas - add_rules name bndr - | Just rules <- lookupNameEnv rule_base name - = bndr `addIdSpecialisations` rules - | otherwise - = bndr rule_base = extendRuleBaseList emptyRuleBase rules ---------- Export flag -------- -- See Note [Adding export flags] - add_export name bndr - | dont_discard name = setIdExported bndr + add_export_flag bndr + | dont_discard bndr = setIdExported bndr | otherwise = bndr - dont_discard :: Name -> Bool - dont_discard name = is_exported name + dont_discard :: Id -> Bool + dont_discard bndr = is_exported name || name `elemNameSet` keep_alive + where + name = idName bndr -- In interactive mode, we don't want to discard any top-level -- entities at all (eg. do not inline them away during diff --git a/testsuite/tests/simplCore/should_compile/T23024.hs b/testsuite/tests/simplCore/should_compile/T23024.hs new file mode 100644 index 0000000000..9c494e6cee --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T23024.hs @@ -0,0 +1,8 @@ +{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings #-} +{-# LANGUAGE RankNTypes #-} +module T23024 (testPolyn) where + +import T23024a + +testPolyn :: (forall r. Tensor r => r) -> Vector Double +testPolyn f = gradientFromDelta f diff --git a/testsuite/tests/simplCore/should_compile/T23024a.hs b/testsuite/tests/simplCore/should_compile/T23024a.hs new file mode 100644 index 0000000000..f204d8fd89 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T23024a.hs @@ -0,0 +1,82 @@ +{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings -Wno-missing-methods #-} +{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, + DataKinds, MultiParamTypeClasses, RankNTypes, MonoLocalBinds #-} +module T23024a where + +import System.IO.Unsafe +import Control.Monad.ST ( ST, runST ) +import Foreign.ForeignPtr +import Foreign.Storable +import GHC.ForeignPtr ( unsafeWithForeignPtr ) + +class MyNum a where + fi :: a + +class (MyNum a, Eq a) => MyReal a + +class (MyReal a) => MyRealFrac a where + fun :: a -> () + +class (MyRealFrac a, MyNum a) => MyRealFloat a + +instance MyNum Double +instance MyReal Double +instance MyRealFloat Double +instance MyRealFrac Double + +newtype Vector a = Vector (ForeignPtr a) + +class GVector v a where +instance Storable a => GVector Vector a + +vunstream :: () -> ST s (v a) +vunstream () = vunstream () + +empty :: GVector v a => v a +empty = runST (vunstream ()) +{-# NOINLINE empty #-} + +instance (Storable a, Eq a) => Eq (Vector a) where + xs == ys = idx xs == idx ys + +{-# NOINLINE idx #-} +idx (Vector fp) = unsafePerformIO + $ unsafeWithForeignPtr fp $ \p -> + peekElemOff p 0 + +instance MyNum (Vector Double) +instance (MyNum (Vector a), Storable a, Eq a) => MyReal (Vector a) +instance (MyNum (Vector a), Storable a, Eq a) => MyRealFrac (Vector a) +instance (MyNum (Vector a), Storable a, MyRealFloat a) => MyRealFloat (Vector a) + +newtype ORArray a = A a + +instance (Eq a) => Eq (ORArray a) where + A x == A y = x == y + +instance (MyNum (Vector a)) => MyNum (ORArray a) +instance (MyNum (Vector a), Storable a, Eq a) => MyReal (ORArray a) +instance (MyRealFrac (Vector a), Storable a, Eq a) => MyRealFrac (ORArray a) +instance (MyRealFloat (Vector a), Storable a, Eq a) => MyRealFloat (ORArray a) + +newtype Ast r = AstConst (ORArray r) + +instance Eq (Ast a) where + (==) = undefined + +instance MyNum (ORArray a) => MyNum (Ast a) where + fi = AstConst fi + +instance MyNum (ORArray a) => MyReal (Ast a) +instance MyRealFrac (ORArray a) => MyRealFrac (Ast a) where + {-# INLINE fun #-} + fun x = () + +instance MyRealFloat (ORArray a) => MyRealFloat (Ast a) + +class (MyRealFloat a) => Tensor a +instance (MyRealFloat a, MyNum (Vector a), Storable a) => Tensor (Ast a) + +gradientFromDelta :: Storable a => Ast a -> Vector a +gradientFromDelta _ = empty +{-# NOINLINE gradientFromDelta #-} diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 3f44a71732..c4e78e0f75 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -475,3 +475,4 @@ test('T22761', normal, multimod_compile, ['T22761', '-O2 -v0']) test('T23012', normal, compile, ['-O']) test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques']) +test('T23024', normal, multimod_compile, ['T23024', '-O -v0']) |