diff options
-rw-r--r-- | compiler/deSugar/StaticPtrTable.hs | 23 | ||||
-rw-r--r-- | includes/rts/StaticPtrTable.h | 8 | ||||
-rw-r--r-- | libraries/base/GHC/StaticPtr.hs | 33 | ||||
-rw-r--r-- | rts/Linker.c | 1 | ||||
-rw-r--r-- | rts/StaticPtrTable.c | 61 | ||||
-rw-r--r-- | testsuite/tests/codeGen/should_run/CgStaticPointers.hs | 11 | ||||
-rw-r--r-- | testsuite/tests/rts/GcStaticPointers.hs | 2 | ||||
-rw-r--r-- | testsuite/tests/rts/ListStaticPointers.hs | 10 |
8 files changed, 113 insertions, 36 deletions
diff --git a/compiler/deSugar/StaticPtrTable.hs b/compiler/deSugar/StaticPtrTable.hs index 858a0e8f7b..d1e8e051d3 100644 --- a/compiler/deSugar/StaticPtrTable.hs +++ b/compiler/deSugar/StaticPtrTable.hs @@ -26,6 +26,20 @@ -- -- where the constants are fingerprints produced from the static forms. -- +-- There is also a finalization function for the time when the module is +-- unloaded. +-- +-- > static void hs_hpc_fini_Main(void) __attribute__((destructor)); +-- > static void hs_hpc_fini_Main(void) { +-- > +-- > static StgWord64 k0[2] = {16252233372134256ULL,7370534374096082ULL}; +-- > hs_spt_remove(k0); +-- > +-- > static StgWord64 k1[2] = {12545634534567898ULL,5409674567544151ULL}; +-- > hs_spt_remove(k1); +-- > +-- > } +-- module StaticPtrTable (sptInitCode) where import CoreSyn @@ -62,6 +76,15 @@ sptInitCode this_mod entries = vcat <> semi | (i, (fp, (n, _))) <- zip [0..] entries ] + , text "static void hs_spt_fini_" <> ppr this_mod + <> text "(void) __attribute__((destructor));" + , text "static void hs_spt_fini_" <> ppr this_mod <> text "(void)" + , braces $ vcat $ + [ text "StgWord64 k" <> int i <> text "[2] = " + <> pprFingerprint fp <> semi + $$ text "hs_spt_remove" <> parens (char 'k' <> int i) <> semi + | (i, (fp, _)) <- zip [0..] entries + ] ] where diff --git a/includes/rts/StaticPtrTable.h b/includes/rts/StaticPtrTable.h index 87a905c073..d863160342 100644 --- a/includes/rts/StaticPtrTable.h +++ b/includes/rts/StaticPtrTable.h @@ -28,4 +28,12 @@ * */ void hs_spt_insert (StgWord64 key[2],void* spe_closure); +/** Removes an entry from the Static Pointer Table. + * + * This function is called from the code generated by + * compiler/deSugar/StaticPtrTable.sptInitCode + * + * */ +void hs_spt_remove (StgWord64 key[2]); + #endif /* RTS_STATICPTRTABLE_H */ diff --git a/libraries/base/GHC/StaticPtr.hs b/libraries/base/GHC/StaticPtr.hs index b58564e1d2..efaabf2dd2 100644 --- a/libraries/base/GHC/StaticPtr.hs +++ b/libraries/base/GHC/StaticPtr.hs @@ -24,9 +24,9 @@ -- -- To solve such concern, the references provided by this module offer a key -- that can be used to locate the values on each process. Each process maintains --- a global and immutable table of references which can be looked up with a --- given key. This table is known as the Static Pointer Table. The reference can --- then be dereferenced to obtain the value. +-- a global table of references which can be looked up with a given key. This +-- table is known as the Static Pointer Table. The reference can then be +-- dereferenced to obtain the value. -- ----------------------------------------------------------------------------- @@ -48,7 +48,6 @@ import Foreign.Ptr (castPtr) import GHC.Exts (addrToAny#) import GHC.Ptr (Ptr(..), nullPtr) import GHC.Fingerprint (Fingerprint(..)) -import System.IO.Unsafe (unsafePerformIO) -- | A reference to a value of type 'a'. @@ -74,8 +73,15 @@ staticKey (StaticPtr k _ _) = k -- This function is unsafe because the program behavior is undefined if the type -- of the returned 'StaticPtr' does not match the expected one. -- -unsafeLookupStaticPtr :: StaticKey -> Maybe (StaticPtr a) -unsafeLookupStaticPtr k = unsafePerformIO $ sptLookup k +unsafeLookupStaticPtr :: StaticKey -> IO (Maybe (StaticPtr a)) +unsafeLookupStaticPtr (Fingerprint w1 w2) = do + ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr) + if (ptr == nullPtr) + then return Nothing + else case addrToAny# addr of + (# spe #) -> return (Just spe) + +foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a) -- | Miscelaneous information available for debugging purposes. data StaticPtrInfo = StaticPtrInfo @@ -96,20 +102,9 @@ data StaticPtrInfo = StaticPtrInfo staticPtrInfo :: StaticPtr a -> StaticPtrInfo staticPtrInfo (StaticPtr _ n _) = n --- | Like 'unsafeLookupStaticPtr' but evaluates in 'IO'. -sptLookup :: StaticKey -> IO (Maybe (StaticPtr a)) -sptLookup (Fingerprint w1 w2) = do - ptr@(Ptr addr) <- withArray [w1,w2] (hs_spt_lookup . castPtr) - if (ptr == nullPtr) - then return Nothing - else case addrToAny# addr of - (# spe #) -> return (Just spe) - -foreign import ccall unsafe hs_spt_lookup :: Ptr () -> IO (Ptr a) - -- | A list of all known keys. -staticPtrKeys :: [StaticKey] -staticPtrKeys = unsafePerformIO $ do +staticPtrKeys :: IO [StaticKey] +staticPtrKeys = do keyCount <- hs_spt_key_count allocaArray (fromIntegral keyCount) $ \p -> do count <- hs_spt_keys p keyCount diff --git a/rts/Linker.c b/rts/Linker.c index 4a0e5eadb1..6bf06ed944 100644 --- a/rts/Linker.c +++ b/rts/Linker.c @@ -1420,6 +1420,7 @@ typedef struct _RtsSymbolVal { SymI_HasProto(atomic_dec) \ SymI_HasProto(hs_spt_lookup) \ SymI_HasProto(hs_spt_insert) \ + SymI_HasProto(hs_spt_remove) \ SymI_HasProto(hs_spt_keys) \ SymI_HasProto(hs_spt_key_count) \ RTS_USER_SIGNALS_SYMBOLS \ diff --git a/rts/StaticPtrTable.c b/rts/StaticPtrTable.c index bd450809d0..f7fe06647a 100644 --- a/rts/StaticPtrTable.c +++ b/rts/StaticPtrTable.c @@ -8,12 +8,18 @@ * */ -#include "Rts.h" #include "StaticPtrTable.h" +#include "Rts.h" +#include "RtsUtils.h" #include "Hash.h" +#include "Stable.h" static HashTable * spt = NULL; +#ifdef THREADED_RTS +static Mutex spt_lock; +#endif + /// Hash function for the SPT. static int hashFingerprint(HashTable *table, StgWord64 key[2]) { // Take half of the key to compute the hash. @@ -28,21 +34,59 @@ static int compareFingerprint(StgWord64 ptra[2], StgWord64 ptrb[2]) { void hs_spt_insert(StgWord64 key[2],void *spe_closure) { // hs_spt_insert is called from constructor functions, so // the SPT needs to be initialized here. - if (spt == NULL) + if (spt == NULL) { spt = allocHashTable_( (HashFunction *)hashFingerprint , (CompareFunction *)compareFingerprint ); +#ifdef THREADED_RTS + initMutex(&spt_lock); +#endif + } + + StgStablePtr * entry = stgMallocBytes( sizeof(StgStablePtr) + , "hs_spt_insert: entry" + ); + *entry = getStablePtr(spe_closure); + ACQUIRE_LOCK(&spt_lock); + insertHashTable(spt, (StgWord)key, entry); + RELEASE_LOCK(&spt_lock); +} - getStablePtr(spe_closure); - insertHashTable(spt, (StgWord)key, spe_closure); +static void freeSptEntry(void* entry) { + freeStablePtr(*(StgStablePtr*)entry); + stgFree(entry); +} + +void hs_spt_remove(StgWord64 key[2]) { + if (spt) { + ACQUIRE_LOCK(&spt_lock); + StgStablePtr* entry = removeHashTable(spt, (StgWord)key, NULL); + RELEASE_LOCK(&spt_lock); + + if (entry) + freeSptEntry(entry); + } } StgPtr hs_spt_lookup(StgWord64 key[2]) { - return spt ? lookupHashTable(spt, (StgWord)key) : NULL; + if (spt) { + ACQUIRE_LOCK(&spt_lock); + const StgStablePtr * entry = lookupHashTable(spt, (StgWord)key); + RELEASE_LOCK(&spt_lock); + const StgPtr ret = entry ? deRefStablePtr(*entry) : NULL; + return ret; + } else + return NULL; } int hs_spt_keys(StgPtr keys[], int szKeys) { - return spt ? keysHashTable(spt, (StgWord*)keys, szKeys) : 0; + if (spt) { + ACQUIRE_LOCK(&spt_lock); + const int ret = keysHashTable(spt, (StgWord*)keys, szKeys); + RELEASE_LOCK(&spt_lock); + return ret; + } else + return 0; } int hs_spt_key_count() { @@ -51,7 +95,10 @@ int hs_spt_key_count() { void exitStaticPtrTable() { if (spt) { - freeHashTable(spt, NULL); + freeHashTable(spt, freeSptEntry); spt = NULL; +#ifdef THREADED_RTS + closeMutex(&spt_lock); +#endif } } diff --git a/testsuite/tests/codeGen/should_run/CgStaticPointers.hs b/testsuite/tests/codeGen/should_run/CgStaticPointers.hs index 5576f431e8..f7776b0c06 100644 --- a/testsuite/tests/codeGen/should_run/CgStaticPointers.hs +++ b/testsuite/tests/codeGen/should_run/CgStaticPointers.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE StaticPointers #-} -- | A test to use symbols produced by the static form. @@ -9,15 +10,15 @@ import GHC.StaticPtr main :: IO () main = do - print $ lookupKey (static (id . id)) (1 :: Int) - print $ lookupKey (static method :: StaticPtr (Char -> Int)) 'a' + lookupKey (static (id . id)) >>= \f -> print $ f (1 :: Int) + lookupKey (static method :: StaticPtr (Char -> Int)) >>= \f -> print $ f 'a' print $ deRefStaticPtr (static g) print $ deRefStaticPtr p0 'a' print $ deRefStaticPtr (static t_field) $ T 'b' -lookupKey :: StaticPtr a -> a -lookupKey p = case unsafeLookupStaticPtr (staticKey p) of - Just p -> deRefStaticPtr p +lookupKey :: StaticPtr a -> IO a +lookupKey p = unsafeLookupStaticPtr (staticKey p) >>= \case + Just p -> return $ deRefStaticPtr p Nothing -> error $ "couldn't find " ++ show (staticPtrInfo p) g :: String diff --git a/testsuite/tests/rts/GcStaticPointers.hs b/testsuite/tests/rts/GcStaticPointers.hs index c498af5842..3bf02d9da9 100644 --- a/testsuite/tests/rts/GcStaticPointers.hs +++ b/testsuite/tests/rts/GcStaticPointers.hs @@ -26,7 +26,7 @@ main = do print z performGC threadDelay 1000000 - let Just p = unsafeLookupStaticPtr nats_key + Just p <- unsafeLookupStaticPtr nats_key print (deRefStaticPtr (unsafeCoerce p) !! 800 :: Integer) -- Uncommenting the next line keeps 'nats' alive and would prevent a segfault -- if 'nats' were garbage collected. diff --git a/testsuite/tests/rts/ListStaticPointers.hs b/testsuite/tests/rts/ListStaticPointers.hs index 5ddb63613f..01c747d45a 100644 --- a/testsuite/tests/rts/ListStaticPointers.hs +++ b/testsuite/tests/rts/ListStaticPointers.hs @@ -7,10 +7,12 @@ import Data.List ((\\)) import GHC.StaticPtr import System.Exit -main = when (not $ eqBags staticPtrKeys expected) $ do - print ("expected", expected) - print ("found", staticPtrKeys) - exitFailure +main = do + found <- staticPtrKeys + when (not $ eqBags found expected) $ do + print ("expected", expected) + print ("found", found) + exitFailure where expected = |