summaryrefslogtreecommitdiff
path: root/testsuite
diff options
context:
space:
mode:
authorSimon Peyton Jones <simon.peytonjones@gmail.com>2023-02-22 23:17:04 +0000
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-02-28 11:10:31 -0500
commit0c200ab78c814cb5d1efaf426f0d3d91ceab9f4d (patch)
tree6788727ca79522790c651d4b451ab0a37e5b38cb /testsuite
parent9fa545722f9151781344446dd5501db38cb90dd1 (diff)
downloadhaskell-0c200ab78c814cb5d1efaf426f0d3d91ceab9f4d.tar.gz
Account for local rules in specImports
As #23024 showed, in GHC.Core.Opt.Specialise.specImports, we were generating specialisations (a locally-define function) for imported functions; and then generating specialisations for those locally-defined functions. The RULE for the latter should be attached to the local Id, not put in the rules-for-imported-ids set. Fix is easy; similar to what happens in GHC.HsToCore.addExportFlagsAndRules
Diffstat (limited to 'testsuite')
-rw-r--r--testsuite/tests/simplCore/should_compile/T23024.hs8
-rw-r--r--testsuite/tests/simplCore/should_compile/T23024a.hs82
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T1
3 files changed, 91 insertions, 0 deletions
diff --git a/testsuite/tests/simplCore/should_compile/T23024.hs b/testsuite/tests/simplCore/should_compile/T23024.hs
new file mode 100644
index 0000000000..9c494e6cee
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T23024.hs
@@ -0,0 +1,8 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings #-}
+{-# LANGUAGE RankNTypes #-}
+module T23024 (testPolyn) where
+
+import T23024a
+
+testPolyn :: (forall r. Tensor r => r) -> Vector Double
+testPolyn f = gradientFromDelta f
diff --git a/testsuite/tests/simplCore/should_compile/T23024a.hs b/testsuite/tests/simplCore/should_compile/T23024a.hs
new file mode 100644
index 0000000000..f204d8fd89
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T23024a.hs
@@ -0,0 +1,82 @@
+{-# OPTIONS_GHC -fspecialize-aggressively -fexpose-all-unfoldings -Wno-missing-methods #-}
+{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances,
+ DataKinds, MultiParamTypeClasses, RankNTypes, MonoLocalBinds #-}
+module T23024a where
+
+import System.IO.Unsafe
+import Control.Monad.ST ( ST, runST )
+import Foreign.ForeignPtr
+import Foreign.Storable
+import GHC.ForeignPtr ( unsafeWithForeignPtr )
+
+class MyNum a where
+ fi :: a
+
+class (MyNum a, Eq a) => MyReal a
+
+class (MyReal a) => MyRealFrac a where
+ fun :: a -> ()
+
+class (MyRealFrac a, MyNum a) => MyRealFloat a
+
+instance MyNum Double
+instance MyReal Double
+instance MyRealFloat Double
+instance MyRealFrac Double
+
+newtype Vector a = Vector (ForeignPtr a)
+
+class GVector v a where
+instance Storable a => GVector Vector a
+
+vunstream :: () -> ST s (v a)
+vunstream () = vunstream ()
+
+empty :: GVector v a => v a
+empty = runST (vunstream ())
+{-# NOINLINE empty #-}
+
+instance (Storable a, Eq a) => Eq (Vector a) where
+ xs == ys = idx xs == idx ys
+
+{-# NOINLINE idx #-}
+idx (Vector fp) = unsafePerformIO
+ $ unsafeWithForeignPtr fp $ \p ->
+ peekElemOff p 0
+
+instance MyNum (Vector Double)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (Vector a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyRealFrac (Vector a)
+instance (MyNum (Vector a), Storable a, MyRealFloat a) => MyRealFloat (Vector a)
+
+newtype ORArray a = A a
+
+instance (Eq a) => Eq (ORArray a) where
+ A x == A y = x == y
+
+instance (MyNum (Vector a)) => MyNum (ORArray a)
+instance (MyNum (Vector a), Storable a, Eq a) => MyReal (ORArray a)
+instance (MyRealFrac (Vector a), Storable a, Eq a) => MyRealFrac (ORArray a)
+instance (MyRealFloat (Vector a), Storable a, Eq a) => MyRealFloat (ORArray a)
+
+newtype Ast r = AstConst (ORArray r)
+
+instance Eq (Ast a) where
+ (==) = undefined
+
+instance MyNum (ORArray a) => MyNum (Ast a) where
+ fi = AstConst fi
+
+instance MyNum (ORArray a) => MyReal (Ast a)
+instance MyRealFrac (ORArray a) => MyRealFrac (Ast a) where
+ {-# INLINE fun #-}
+ fun x = ()
+
+instance MyRealFloat (ORArray a) => MyRealFloat (Ast a)
+
+class (MyRealFloat a) => Tensor a
+instance (MyRealFloat a, MyNum (Vector a), Storable a) => Tensor (Ast a)
+
+gradientFromDelta :: Storable a => Ast a -> Vector a
+gradientFromDelta _ = empty
+{-# NOINLINE gradientFromDelta #-}
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index 3f44a71732..c4e78e0f75 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -475,3 +475,4 @@ test('T22761', normal, multimod_compile, ['T22761', '-O2 -v0'])
test('T23012', normal, compile, ['-O'])
test('RewriteHigherOrderPatterns', normal, compile, ['-O -ddump-rule-rewrites -dsuppress-all -dsuppress-uniques'])
+test('T23024', normal, multimod_compile, ['T23024', '-O -v0'])