diff options
author | Simon Peyton Jones <simon.peytonjones@gmail.com> | 2023-02-22 23:17:04 +0000 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2023-02-28 11:10:31 -0500 |
commit | 0c200ab78c814cb5d1efaf426f0d3d91ceab9f4d (patch) | |
tree | 6788727ca79522790c651d4b451ab0a37e5b38cb /testsuite | |
parent | 9fa545722f9151781344446dd5501db38cb90dd1 (diff) | |
download | haskell-0c200ab78c814cb5d1efaf426f0d3d91ceab9f4d.tar.gz |
Account for local rules in specImports
As #23024 showed, in GHC.Core.Opt.Specialise.specImports, we were
generating specialisations (a locally-define function) for imported
functions; and then generating specialisations for those
locally-defined functions. The RULE for the latter should be
attached to the local Id, not put in the rules-for-imported-ids
set.
Fix is easy; similar to what happens in GHC.HsToCore.addExportFlagsAndRules
Diffstat (limited to 'testsuite')
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T23024.hs | 8 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T23024a.hs | 82 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 1 |
3 files changed, 91 insertions, 0 deletions
diff --git a/testsuite/tests/simplCore/should_compile/T23024.hs b/testsuite/tests/simplCore/should_compile/T23024.hs new file mode 100644 index 0000000000..9c494e6cee --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T23024.hs @@ -0,0 +1,8 @@ +{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings #-} +{-# LANGUAGE RankNTypes #-} +module T23024 (testPolyn) where + +import T23024a + +testPolyn :: (forall r. Tensor r => r) -> Vector Double +testPolyn f = gradientFromDelta f diff --git a/testsuite/tests/simplCore/should_compile/T23024a.hs b/testsuite/tests/simplCore/should_compile/T23024a.hs new file mode 100644 index 0000000000..f204d8fd89 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T23024a.hs @@ -0,0 +1,82 @@ +{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings -Wno-missing-methods #-} +{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, + DataKinds, MultiParamTypeClasses, RankNTypes, MonoLocalBinds #-} +module T23024a where + +import System.IO.Unsafe +import Control.Monad.ST ( ST, runST ) +import Foreign.ForeignPtr +import Foreign.Storable +import GHC.ForeignPtr ( unsafeWithForeignPtr ) + +class MyNum a where + fi :: a + +class (MyNum a, Eq a) => MyReal a + +class (MyReal a) => MyRealFrac a where + fun :: a -> () + +class (MyRealFrac a, MyNum a) => MyRealFloat a + +instance MyNum Double +instance MyReal Double +instance MyRealFloat Double +instance MyRealFrac Double + +newtype Vector a = Vector (ForeignPtr a) + +class GVector v a where +instance Storable a => GVector Vector a + +vunstream :: () -> ST s (v a) +vunstream () = vunstream () + +empty :: GVector v a => v a +empty = runST (vunstream ()) +{-# NOINLINE empty #-} + +instance (Storable a, Eq a) => Eq (Vector a) where + xs == ys = idx xs == idx ys + +{-# NOINLINE idx #-} +idx (Vector fp) = unsafePerformIO + $ unsafeWithForeignPtr fp $ \p -> + peekElemOff p 0 + +instance MyNum (Vector Double) +instance (MyNum (Vector a), Storable a, Eq a) => MyReal (Vector a) +instance (MyNum (Vector a), Storable a, Eq a) => MyRealFrac (Vector a) +instance (MyNum (Vector a), Storable a, MyRealFloat a) => MyRealFloat (Vector a) + +newtype ORArray a = A a + +instance (Eq a) => Eq (ORArray a) where + A x == A y = x == y + +instance (MyNum (Vector a)) => MyNum (ORArray a) +instance (MyNum (Vector a), Storable a, Eq a) => MyReal (ORArray a) +instance (MyRealFrac (Vector a), Storable a, Eq a) => MyRealFrac (ORArray a) +instance (MyRealFloat (Vector a), Storable a, Eq a) => MyRealFloat (ORArray a) + +newtype Ast r = AstConst (ORArray r) + +instance Eq (Ast a) where + (==) = undefined + +instance MyNum (ORArray a) => MyNum (Ast a) where + fi = AstConst fi + +instance MyNum (ORArray a) => MyReal (Ast a) +instance MyRealFrac (ORArray a) => MyRealFrac (Ast a) where + {-# INLINE fun #-} + fun x = () + +instance MyRealFloat (ORArray a) => MyRealFloat (Ast a) + +class (MyRealFloat a) => Tensor a +instance (MyRealFloat a, MyNum (Vector a), Storable a) => Tensor (Ast a) + +gradientFromDelta :: Storable a => Ast a -> Vector a +gradientFromDelta _ = empty +{-# NOINLINE gradientFromDelta #-} diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 3f44a71732..c4e78e0f75 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -475,3 +475,4 @@ test('T22761', normal, multimod_compile, ['T22761', '-O2 -v0']) test('T23012', normal, compile, ['-O']) test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques']) +test('T23024', normal, multimod_compile, ['T23024', '-O -v0']) |