summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simon.peytonjones@gmail.com>2023-02-22 23:17:04 +0000
committerSimon Peyton Jones <simon.peytonjones@gmail.com>2023-02-22 23:20:07 +0000
commit6b718cfcfc23bba795da0fd74a3d8f45b6f3bd7f (patch)
tree0b2faf989fc6603dc9bf1b6122ad856713308aa3
parentf11d9c274d728696bc173c62a2ead62b8288836f (diff)
downloadhaskell-wip/T23024.tar.gz
Account for local rules in specImportswip/T23024
As #23024 showed, in GHC.Core.Opt.Specialise.specImports, we were generating specialisations (a locally-define function) for imported functions; and then generating specialisations for those locally-defined functions. The RULE for the latter should be attached to the local Id, not put in the rules-for-imported-ids set. Fix is easy; similar to what happens in GHC.HsToCore.addExportFlagsAndRules
-rw-r--r--compiler/GHC/Core/Opt/Specialise.hs73
-rw-r--r--compiler/GHC/Core/Rules.hs10
-rw-r--r--compiler/GHC/HsToCore.hs28
-rw-r--r--testsuite/tests/simplCore/should_compile/T23024.hs8
-rw-r--r--testsuite/tests/simplCore/should_compile/T23024a.hs82
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T1
6 files changed, 156 insertions, 46 deletions
diff --git a/compiler/GHC/Core/Opt/Specialise.hs b/compiler/GHC/Core/Opt/Specialise.hs
index 13bff9f170..f028bd428e 100644
--- a/compiler/GHC/Core/Opt/Specialise.hs
+++ b/compiler/GHC/Core/Opt/Specialise.hs
@@ -64,6 +64,7 @@ import GHC.Unit.Module( Module )
import GHC.Unit.Module.ModGuts
import GHC.Core.Unfold
+import Data.List( partition )
import Data.List.NonEmpty ( NonEmpty (..) )
{-
@@ -726,6 +727,33 @@ specialisation (see canSpecImport):
Specialise even INLINE things; it hasn't inlined yet, so perhaps
it never will. Moreover it may have calls inside it that we want
to specialise
+
+Wrinkle (W1): If we specialise an imported Id M.foo, we make a /local/
+binding $sfoo. But specImports may further specialise $sfoo. So we end up
+with RULES for both M.foo (imported) and $sfoo (local). Rules for local
+Ids should be attached to the Ids themselves (see GHC.HsToCore
+Note [Attach rules to local ids]); so we must partition the rules and
+attach the local rules. That is done in specImports, via addRulesToId.
+
+Note [Glom the bindings if imported functions are specialised]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have an imported, *recursive*, INLINABLE function
+ f :: Eq a => a -> a
+ f = /\a \d x. ...(f a d)...
+In the module being compiled we have
+ g x = f (x::Int)
+Now we'll make a specialised function
+ f_spec :: Int -> Int
+ f_spec = \x -> ...(f Int dInt)...
+ {-# RULE f Int _ = f_spec #-}
+ g = \x. f Int dInt x
+Note that f_spec doesn't look recursive
+After rewriting with the RULE, we get
+ f_spec = \x -> ...(f_spec)...
+BUT since f_spec was non-recursive before it'll *stay* non-recursive.
+The occurrence analyser never turns a NonRec into a Rec. So we must
+make sure that f_spec is recursive. Easiest thing is to make all
+the specialisations for imported bindings recursive.
-}
specImports :: SpecEnv
@@ -740,16 +768,24 @@ specImports top_env (MkUD { ud_binds = dict_binds, ud_calls = calls })
= do { let env_w_dict_bndrs = top_env `bringFloatedDictsIntoScope` dict_binds
; (_env, spec_rules, spec_binds) <- spec_imports env_w_dict_bndrs [] dict_binds calls
- -- Don't forget to wrap the specialized bindings with
- -- bindings for the needed dictionaries.
- -- See Note [Wrap bindings returned by specImports]
- -- and Note [Glom the bindings if imported functions are specialised]
- ; let final_binds
+ -- Make a Rec: see Note [Glom the bindings if imported functions are specialised]
+ --
+ -- wrapDictBinds: don't forget to wrap the specialized bindings with
+ -- bindings for the needed dictionaries.
+ -- See Note [Wrap bindings returned by specImports]
+ --
+ -- addRulesToId: see Wrinkle (W1) in Note [Specialising imported functions]
+ -- c.f. GHC.HsToCore.addExportFlagsAndRules
+ ; let (rules_for_locals, rules_for_imps) = partition isLocalRule spec_rules
+ local_rule_base = extendRuleBaseList emptyRuleBase rules_for_locals
+ final_binds
| null spec_binds = wrapDictBinds dict_binds []
- | otherwise = [Rec $ flattenBinds $
- wrapDictBinds dict_binds spec_binds]
+ | otherwise = [Rec $ mapFst (addRulesToId local_rule_base) $
+ flattenBinds $
+ wrapDictBinds dict_binds $
+ spec_binds]
- ; return (spec_rules, final_binds)
+ ; return (rules_for_imps, final_binds)
}
-- | Specialise a set of calls to imported bindings
@@ -1111,27 +1147,6 @@ And if the call is to the same type, one specialisation is enough.
Avoiding this recursive specialisation loop is one reason for the
'callers' stack passed to specImports and specImport.
-Note [Glom the bindings if imported functions are specialised]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Suppose we have an imported, *recursive*, INLINABLE function
- f :: Eq a => a -> a
- f = /\a \d x. ...(f a d)...
-In the module being compiled we have
- g x = f (x::Int)
-Now we'll make a specialised function
- f_spec :: Int -> Int
- f_spec = \x -> ...(f Int dInt)...
- {-# RULE f Int _ = f_spec #-}
- g = \x. f Int dInt x
-Note that f_spec doesn't look recursive
-After rewriting with the RULE, we get
- f_spec = \x -> ...(f_spec)...
-BUT since f_spec was non-recursive before it'll *stay* non-recursive.
-The occurrence analyser never turns a NonRec into a Rec. So we must
-make sure that f_spec is recursive. Easiest thing is to make all
-the specialisations for imported bindings recursive.
-
-
************************************************************************
* *
diff --git a/compiler/GHC/Core/Rules.hs b/compiler/GHC/Core/Rules.hs
index df763835cf..67f47e9d9c 100644
--- a/compiler/GHC/Core/Rules.hs
+++ b/compiler/GHC/Core/Rules.hs
@@ -22,7 +22,7 @@ module GHC.Core.Rules (
-- ** Manipulating 'RuleInfo' rules
extendRuleInfo, addRuleInfo,
- addIdSpecialisations,
+ addIdSpecialisations, addRulesToId,
-- ** RuleBase and RuleEnv
@@ -349,6 +349,14 @@ addIdSpecialisations id rules
= setIdSpecialisation id $
extendRuleInfo (idSpecialisation id) rules
+addRulesToId :: RuleBase -> Id -> Id
+-- Add rules in the RuleBase to the rules in the Id
+addRulesToId rule_base bndr
+ | Just rules <- lookupNameEnv rule_base (idName bndr)
+ = bndr `addIdSpecialisations` rules
+ | otherwise
+ = bndr
+
-- | Gather all the rules for locally bound identifiers from the supplied bindings
rulesOfBinds :: [CoreBind] -> [CoreRule]
rulesOfBinds binds = concatMap (concatMap idCoreRules . bindersOf) binds
diff --git a/compiler/GHC/HsToCore.hs b/compiler/GHC/HsToCore.hs
index 3c6ec71079..5a6bae315d 100644
--- a/compiler/GHC/HsToCore.hs
+++ b/compiler/GHC/HsToCore.hs
@@ -362,32 +362,28 @@ deSugarExpr hsc_env tc_expr = do
addExportFlagsAndRules
:: Backend -> NameSet -> NameSet -> [CoreRule]
-> [(Id, t)] -> [(Id, t)]
-addExportFlagsAndRules bcknd exports keep_alive rules = mapFst add_one
+addExportFlagsAndRules bcknd exports keep_alive rules
+ = mapFst (addRulesToId rule_base . add_export_flag)
+ -- addRulesToId: see Note [Attach rules to local ids]
+ -- NB: the binder might have some existing rules,
+ -- arising from specialisation pragmas
+
where
- add_one bndr = add_rules name (add_export name bndr)
- where
- name = idName bndr
---------- Rules --------
- -- See Note [Attach rules to local ids]
- -- NB: the binder might have some existing rules,
- -- arising from specialisation pragmas
- add_rules name bndr
- | Just rules <- lookupNameEnv rule_base name
- = bndr `addIdSpecialisations` rules
- | otherwise
- = bndr
rule_base = extendRuleBaseList emptyRuleBase rules
---------- Export flag --------
-- See Note [Adding export flags]
- add_export name bndr
- | dont_discard name = setIdExported bndr
+ add_export_flag bndr
+ | dont_discard bndr = setIdExported bndr
| otherwise = bndr
- dont_discard :: Name -> Bool
- dont_discard name = is_exported name
+ dont_discard :: Id -> Bool
+ dont_discard bndr = is_exported name
|| name `elemNameSet` keep_alive
+ where
+ name = idName bndr
-- In interactive mode, we don't want to discard any top-level
-- entities at all (eg. do not inline them away during
diff --git a/testsuite/tests/simplCore/should_compile/T23024.hs b/testsuite/tests/simplCore/should_compile/T23024.hs
new file mode 100644
index 0000000000..9c494e6cee
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T23024.hs
@@ -0,0 +1,8 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings #-}
+{-# LANGUAGE RankNTypes #-}
+module T23024 (testPolyn) where
+
+import T23024a
+
+testPolyn :: (forall r. Tensor r => r) -> Vector Double
+testPolyn f = gradientFromDelta f
diff --git a/testsuite/tests/simplCore/should_compile/T23024a.hs b/testsuite/tests/simplCore/should_compile/T23024a.hs
new file mode 100644
index 0000000000..f204d8fd89
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T23024a.hs
@@ -0,0 +1,82 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings -Wno-missing-methods #-}
+{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances,
+ DataKinds, MultiParamTypeClasses, RankNTypes, MonoLocalBinds #-}
+module T23024a where
+
+import System.IO.Unsafe
+import Control.Monad.ST ( ST, runST )
+import Foreign.ForeignPtr
+import Foreign.Storable
+import GHC.ForeignPtr ( unsafeWithForeignPtr )
+
+class MyNum a where
+ fi :: a
+
+class (MyNum a, Eq a) => MyReal a
+
+class (MyReal a) => MyRealFrac a where
+ fun :: a -> ()
+
+class (MyRealFrac a, MyNum a) => MyRealFloat a
+
+instance MyNum Double
+instance MyReal Double
+instance MyRealFloat Double
+instance MyRealFrac Double
+
+newtype Vector a = Vector (ForeignPtr a)
+
+class GVector v a where
+instance Storable a => GVector Vector a
+
+vunstream :: () -> ST s (v a)
+vunstream () = vunstream ()
+
+empty :: GVector v a => v a
+empty = runST (vunstream ())
+{-# NOINLINE empty #-}
+
+instance (Storable a, Eq a) => Eq (Vector a) where
+ xs == ys = idx xs == idx ys
+
+{-# NOINLINE idx #-}
+idx (Vector fp) = unsafePerformIO
+ $ unsafeWithForeignPtr fp $ \p ->
+ peekElemOff p 0
+
+instance MyNum (Vector Double)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (Vector a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyRealFrac (Vector a)
+instance (MyNum (Vector a), Storable a, MyRealFloat a) => MyRealFloat (Vector a)
+
+newtype ORArray a = A a
+
+instance (Eq a) => Eq (ORArray a) where
+ A x == A y = x == y
+
+instance (MyNum (Vector a)) => MyNum (ORArray a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (ORArray a)
+instance (MyRealFrac (Vector a), Storable a, Eq a) => MyRealFrac (ORArray a)
+instance (MyRealFloat (Vector a), Storable a, Eq a) => MyRealFloat (ORArray a)
+
+newtype Ast r = AstConst (ORArray r)
+
+instance Eq (Ast a) where
+ (==) = undefined
+
+instance MyNum (ORArray a) => MyNum (Ast a) where
+ fi = AstConst fi
+
+instance MyNum (ORArray a) => MyReal (Ast a)
+instance MyRealFrac (ORArray a) => MyRealFrac (Ast a) where
+ {-# INLINE fun #-}
+ fun x = ()
+
+instance MyRealFloat (ORArray a) => MyRealFloat (Ast a)
+
+class (MyRealFloat a) => Tensor a
+instance (MyRealFloat a, MyNum (Vector a), Storable a) => Tensor (Ast a)
+
+gradientFromDelta :: Storable a => Ast a -> Vector a
+gradientFromDelta _ = empty
+{-# NOINLINE gradientFromDelta #-}
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index 12e2ecf042..f38e94edef 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -474,3 +474,4 @@ test('T15205', normal, compile, ['-O -ddump-simpl -dno-typeable-binds -dsuppress
test('T22761', normal, multimod_compile, ['T22761', '-O2 -v0'])
test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques'])
+test('T23024', normal, multimod_compile, ['T23024', '-O -v0'])