summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexander Vershilov <alexander.vershilov@gmail.com>2015-01-12 05:29:18 -0600
committerAustin Seipp <aseipp@pobox.com>2015-01-19 08:03:29 -0600
commitcd66ec3620cbf56fb856712633b045991adf28f0 (patch)
tree8d9aa21cecad4fb78d6bad553955092d2512fa4f
parent5541b6c34161278180c45d378941d53ed20d9a5a (diff)
downloadhaskell-cd66ec3620cbf56fb856712633b045991adf28f0.tar.gz
Trac #9878: Have StaticPointers support dynamic loading.
Summary: A mutex is used to protect the SPT. unsafeLookupStaticPtr and staticPtrKeys in GHC.StaticPtr are made monadic. SPT entries are removed in a destructor function of modules. Authored-by: Facundo Domínguez <facundo.dominguez@tweag.io> Authored-by: Alexander Vershilov <alexander.vershilov@tweag.io> Test Plan: ./validate Reviewers: austin, simonpj, hvr Subscribers: carter, thomie, qnikst, mboes Differential Revision: https://phabricator.haskell.org/D587 GHC Trac Issues: #9878 (cherry picked from commit 7637810a93441d29bc84bbeeeced0615bbb9d9e4)
-rw-r--r--compiler/deSugar/StaticPtrTable.hs23
-rw-r--r--includes/rts/StaticPtrTable.h8
-rw-r--r--libraries/base/GHC/StaticPtr.hs33
-rw-r--r--rts/Linker.c1
-rw-r--r--rts/StaticPtrTable.c61
-rw-r--r--testsuite/tests/codeGen/should_run/CgStaticPointers.hs11
-rw-r--r--testsuite/tests/rts/GcStaticPointers.hs2
-rw-r--r--testsuite/tests/rts/ListStaticPointers.hs10
8 files changed, 113 insertions, 36 deletions
diff --git a/compiler/deSugar/StaticPtrTable.hs b/compiler/deSugar/StaticPtrTable.hs
index 858a0e8f7b..d1e8e051d3 100644
--- a/compiler/deSugar/StaticPtrTable.hs
+++ b/compiler/deSugar/StaticPtrTable.hs
@@ -26,6 +26,20 @@
--
-- where the constants are fingerprints produced from the static forms.
--
+-- There is also a finalization function for the time when the module is
+-- unloaded.
+--
+-- > static void hs_hpc_fini_Main(void) __attribute__((destructor));
+-- > static void hs_hpc_fini_Main(void) {
+-- >
+-- > static StgWord64 k0[2] = {16252233372134256ULL,7370534374096082ULL};
+-- > hs_spt_remove(k0);
+-- >
+-- > static StgWord64 k1[2] = {12545634534567898ULL,5409674567544151ULL};
+-- > hs_spt_remove(k1);
+-- >
+-- > }
+--
module StaticPtrTable (sptInitCode) where
import CoreSyn
@@ -62,6 +76,15 @@ sptInitCode this_mod entries = vcat
<> semi
| (i, (fp, (n, _))) <- zip [0..] entries
]
+ , text "static void hs_spt_fini_" <> ppr this_mod
+ <> text "(void) __attribute__((destructor));"
+ , text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)"
+ , braces $ vcat $
+ [ text "StgWord64 k" <> int i <> text "[2] = "
+ <> pprFingerprint fp <> semi
+ $$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi
+ | (i, (fp, _)) <- zip [0..] entries
+ ]
]
where
diff --git a/includes/rts/StaticPtrTable.h b/includes/rts/StaticPtrTable.h
index 80bb9c469a..9c03d05ed3 100644
--- a/includes/rts/StaticPtrTable.h
+++ b/includes/rts/StaticPtrTable.h
@@ -28,4 +28,12 @@
* */
void hs_spt_insert (StgWord64 key[2],void* spe_closure);
+/** Removes an entry from the Static Pointer Table.
+ *
+ * This function is called from the code generated by
+ * compiler/deSugar/StaticPtrTable.sptInitCode
+ *
+ * */
+void hs_spt_remove (StgWord64 key[2]);
+
#endif /* RTS_STATICPTRTABLE_H */
diff --git a/libraries/base/GHC/StaticPtr.hs b/libraries/base/GHC/StaticPtr.hs
index 65cdf41ceb..ab7998402f 100644
--- a/libraries/base/GHC/StaticPtr.hs
+++ b/libraries/base/GHC/StaticPtr.hs
@@ -24,9 +24,9 @@
--
-- To solve such concern, the references provided by this module offer a key
-- that can be used to locate the values on each process. Each process maintains
--- a global and immutable table of references which can be looked up with a
--- given key. This table is known as the Static Pointer Table. The reference can
--- then be dereferenced to obtain the value.
+-- a global table of references which can be looked up with a given key. This
+-- table is known as the Static Pointer Table. The reference can then be
+-- dereferenced to obtain the value.
--
-----------------------------------------------------------------------------
@@ -48,7 +48,6 @@ import Foreign.Ptr (castPtr)
import GHC.Exts (addrToAny#)
import GHC.Ptr (Ptr(..), nullPtr)
import GHC.Fingerprint (Fingerprint(..))
-import System.IO.Unsafe (unsafePerformIO)
-- | A reference to a value of type 'a'.
@@ -74,8 +73,15 @@ staticKey (StaticPtr k _ _) = k
-- This function is unsafe because the program behavior is undefined if the type
-- of the returned 'StaticPtr' does not match the expected one.
--
-unsafeLookupStaticPtr :: StaticKey -> Maybe (StaticPtr a)
-unsafeLookupStaticPtr k = unsafePerformIO $ sptLookup k
+unsafeLookupStaticPtr :: StaticKey -> IO (Maybe (StaticPtr a))
+unsafeLookupStaticPtr (Fingerprint w1 w2) = do
+ ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
+ if (ptr == nullPtr)
+ then return Nothing
+ else case addrToAny# addr of
+ (# spe #) -> return (Just spe)
+
+foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
-- | Miscelaneous information available for debugging purposes.
data StaticPtrInfo = StaticPtrInfo
@@ -96,20 +102,9 @@ data StaticPtrInfo = StaticPtrInfo
staticPtrInfo :: StaticPtr a -> StaticPtrInfo
staticPtrInfo (StaticPtr _ n _) = n
--- | Like 'unsafeLookupStaticPtr' but evaluates in 'IO'.
-sptLookup :: StaticKey -> IO (Maybe (StaticPtr a))
-sptLookup (Fingerprint w1 w2) = do
- ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr)
- if (ptr == nullPtr)
- then return Nothing
- else case addrToAny# addr of
- (# spe #) -> return (Just spe)
-
-foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a)
-
-- | A list of all known keys.
-staticPtrKeys :: [StaticKey]
-staticPtrKeys = unsafePerformIO $ do
+staticPtrKeys :: IO [StaticKey]
+staticPtrKeys = do
keyCount <- hs_spt_key_count
allocaArray (fromIntegral keyCount) $ \p -> do
count <- hs_spt_keys p keyCount
diff --git a/rts/Linker.c b/rts/Linker.c
index 29a4b7500d..2ba84f8040 100644
--- a/rts/Linker.c
+++ b/rts/Linker.c
@@ -1428,6 +1428,7 @@ typedef struct _RtsSymbolVal {
SymI_HasProto(atomic_dec) \
SymI_HasProto(hs_spt_lookup) \
SymI_HasProto(hs_spt_insert) \
+ SymI_HasProto(hs_spt_remove) \
SymI_HasProto(hs_spt_keys) \
SymI_HasProto(hs_spt_key_count) \
RTS_USER_SIGNALS_SYMBOLS \
diff --git a/rts/StaticPtrTable.c b/rts/StaticPtrTable.c
index bd450809d0..f7fe06647a 100644
--- a/rts/StaticPtrTable.c
+++ b/rts/StaticPtrTable.c
@@ -8,12 +8,18 @@
*
*/
-#include "Rts.h"
#include "StaticPtrTable.h"
+#include "Rts.h"
+#include "RtsUtils.h"
#include "Hash.h"
+#include "Stable.h"
static HashTable * spt = NULL;
+#ifdef THREADED_RTS
+static Mutex spt_lock;
+#endif
+
/// Hash function for the SPT.
static int hashFingerprint(HashTable *table, StgWord64 key[2]) {
// Take half of the key to compute the hash.
@@ -28,21 +34,59 @@ static int compareFingerprint(StgWord64 ptra[2], StgWord64 ptrb[2]) {
void hs_spt_insert(StgWord64 key[2],void *spe_closure) {
// hs_spt_insert is called from constructor functions, so
// the SPT needs to be initialized here.
- if (spt == NULL)
+ if (spt == NULL) {
spt = allocHashTable_( (HashFunction *)hashFingerprint
, (CompareFunction *)compareFingerprint
);
+#ifdef THREADED_RTS
+ initMutex(&spt_lock);
+#endif
+ }
+
+ StgStablePtr * entry = stgMallocBytes( sizeof(StgStablePtr)
+ , "hs_spt_insert: entry"
+ );
+ *entry = getStablePtr(spe_closure);
+ ACQUIRE_LOCK(&spt_lock);
+ insertHashTable(spt, (StgWord)key, entry);
+ RELEASE_LOCK(&spt_lock);
+}
- getStablePtr(spe_closure);
- insertHashTable(spt, (StgWord)key, spe_closure);
+static void freeSptEntry(void* entry) {
+ freeStablePtr(*(StgStablePtr*)entry);
+ stgFree(entry);
+}
+
+void hs_spt_remove(StgWord64 key[2]) {
+ if (spt) {
+ ACQUIRE_LOCK(&spt_lock);
+ StgStablePtr* entry = removeHashTable(spt, (StgWord)key, NULL);
+ RELEASE_LOCK(&spt_lock);
+
+ if (entry)
+ freeSptEntry(entry);
+ }
}
StgPtr hs_spt_lookup(StgWord64 key[2]) {
- return spt ? lookupHashTable(spt, (StgWord)key) : NULL;
+ if (spt) {
+ ACQUIRE_LOCK(&spt_lock);
+ const StgStablePtr * entry = lookupHashTable(spt, (StgWord)key);
+ RELEASE_LOCK(&spt_lock);
+ const StgPtr ret = entry ? deRefStablePtr(*entry) : NULL;
+ return ret;
+ } else
+ return NULL;
}
int hs_spt_keys(StgPtr keys[], int szKeys) {
- return spt ? keysHashTable(spt, (StgWord*)keys, szKeys) : 0;
+ if (spt) {
+ ACQUIRE_LOCK(&spt_lock);
+ const int ret = keysHashTable(spt, (StgWord*)keys, szKeys);
+ RELEASE_LOCK(&spt_lock);
+ return ret;
+ } else
+ return 0;
}
int hs_spt_key_count() {
@@ -51,7 +95,10 @@ int hs_spt_key_count() {
void exitStaticPtrTable() {
if (spt) {
- freeHashTable(spt, NULL);
+ freeHashTable(spt, freeSptEntry);
spt = NULL;
+#ifdef THREADED_RTS
+ closeMutex(&spt_lock);
+#endif
}
}
diff --git a/testsuite/tests/codeGen/should_run/CgStaticPointers.hs b/testsuite/tests/codeGen/should_run/CgStaticPointers.hs
index 5576f431e8..f7776b0c06 100644
--- a/testsuite/tests/codeGen/should_run/CgStaticPointers.hs
+++ b/testsuite/tests/codeGen/should_run/CgStaticPointers.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE StaticPointers #-}
-- | A test to use symbols produced by the static form.
@@ -9,15 +10,15 @@ import GHC.StaticPtr
main :: IO ()
main = do
- print $ lookupKey (static (id . id)) (1 :: Int)
- print $ lookupKey (static method :: StaticPtr (Char -> Int)) 'a'
+ lookupKey (static (id . id)) >>= \f -> print $ f (1 :: Int)
+ lookupKey (static method :: StaticPtr (Char -> Int)) >>= \f -> print $ f 'a'
print $ deRefStaticPtr (static g)
print $ deRefStaticPtr p0 'a'
print $ deRefStaticPtr (static t_field) $ T 'b'
-lookupKey :: StaticPtr a -> a
-lookupKey p = case unsafeLookupStaticPtr (staticKey p) of
- Just p -> deRefStaticPtr p
+lookupKey :: StaticPtr a -> IO a
+lookupKey p = unsafeLookupStaticPtr (staticKey p) >>= \case
+ Just p -> return $ deRefStaticPtr p
Nothing -> error $ "couldn't find " ++ show (staticPtrInfo p)
g :: String
diff --git a/testsuite/tests/rts/GcStaticPointers.hs b/testsuite/tests/rts/GcStaticPointers.hs
index c498af5842..3bf02d9da9 100644
--- a/testsuite/tests/rts/GcStaticPointers.hs
+++ b/testsuite/tests/rts/GcStaticPointers.hs
@@ -26,7 +26,7 @@ main = do
print z
performGC
threadDelay 1000000
- let Just p = unsafeLookupStaticPtr nats_key
+ Just p <- unsafeLookupStaticPtr nats_key
print (deRefStaticPtr (unsafeCoerce p) !! 800 :: Integer)
-- Uncommenting the next line keeps 'nats' alive and would prevent a segfault
-- if 'nats' were garbage collected.
diff --git a/testsuite/tests/rts/ListStaticPointers.hs b/testsuite/tests/rts/ListStaticPointers.hs
index 5ddb63613f..01c747d45a 100644
--- a/testsuite/tests/rts/ListStaticPointers.hs
+++ b/testsuite/tests/rts/ListStaticPointers.hs
@@ -7,10 +7,12 @@ import Data.List ((\\))
import GHC.StaticPtr
import System.Exit
-main = when (not $ eqBags staticPtrKeys expected) $ do
- print ("expected", expected)
- print ("found", staticPtrKeys)
- exitFailure
+main = do
+ found <- staticPtrKeys
+ when (not $ eqBags found expected) $ do
+ print ("expected", expected)
+ print ("found", found)
+ exitFailure
where
expected =