diff options
author | Peter Trommler <ptrommler@acm.org> | 2021-04-17 17:59:44 +0200 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2021-08-02 04:11:27 -0400 |
commit | b4d39adbb5884c764c6c11b2614a340c78cc078e (patch) | |
tree | 57eb45d9078c90c34f8743b961bf87789e292ae8 | |
parent | 7e8c578ed9d3469d6a5c1481f9482982c42f10ea (diff) | |
download | haskell-b4d39adbb5884c764c6c11b2614a340c78cc078e.tar.gz |
PrimOps: Add CAS op for all int sizes
PPC NCG: Implement CAS inline for 32 and 64 bit
testsuite: Add tests for smaller atomic CAS
X86 NCG: Catch calls to CAS C fallback
Primops: Add atomicCasWord[8|16|32|64]Addr#
Add tests for atomicCasWord[8|16|32|64]Addr#
Add changelog entry for new primops
X86 NCG: Fix MO-Cmpxchg W64 on 32-bit arch
ghc-prim: 64-bit CAS C fallback on all archs
-rw-r--r-- | compiler/GHC/Builtin/primops.txt.pp | 96 | ||||
-rw-r--r-- | compiler/GHC/CmmToAsm/PPC/CodeGen.hs | 33 | ||||
-rw-r--r-- | compiler/GHC/CmmToAsm/X86/CodeGen.hs | 7 | ||||
-rw-r--r-- | compiler/GHC/StgToCmm/Prim.hs | 18 | ||||
-rw-r--r-- | includes/stg/MiscClosures.h | 4 | ||||
-rw-r--r-- | libraries/ghc-prim/cbits/atomic.c | 2 | ||||
-rw-r--r-- | libraries/ghc-prim/changelog.md | 11 | ||||
-rw-r--r-- | rts/PrimOps.cmm | 52 | ||||
-rw-r--r-- | rts/RtsSymbols.c | 4 | ||||
-rw-r--r-- | rts/package.conf.in | 4 | ||||
-rw-r--r-- | rts/rts.cabal.in | 4 | ||||
-rw-r--r-- | testsuite/tests/concurrent/should_run/AtomicPrimops.hs | 242 | ||||
-rw-r--r-- | testsuite/tests/concurrent/should_run/AtomicPrimops.stdout | 8 |
13 files changed, 473 insertions, 12 deletions
diff --git a/compiler/GHC/Builtin/primops.txt.pp b/compiler/GHC/Builtin/primops.txt.pp index 5f5cd64cfa..b07c344e18 100644 --- a/compiler/GHC/Builtin/primops.txt.pp +++ b/compiler/GHC/Builtin/primops.txt.pp @@ -1927,6 +1927,46 @@ primop CasByteArrayOp_Int "casIntArray#" GenPrimOp with has_side_effects = True can_fail = True +primop CasByteArrayOp_Int8 "casInt8Array#" GenPrimOp + MutableByteArray# s -> Int# -> Int8# -> Int8# -> State# s -> (# State# s, Int8# #) + {Given an array, an offset in bytes, the expected old value, and + the new value, perform an atomic compare and swap i.e. write the new + value if the current value matches the provided old value. Returns + the value of the element before the operation. Implies a full memory + barrier.} + with has_side_effects = True + can_fail = True + +primop CasByteArrayOp_Int16 "casInt16Array#" GenPrimOp + MutableByteArray# s -> Int# -> Int16# -> Int16# -> State# s -> (# State# s, Int16# #) + {Given an array, an offset in 16 bit units, the expected old value, and + the new value, perform an atomic compare and swap i.e. write the new + value if the current value matches the provided old value. Returns + the value of the element before the operation. Implies a full memory + barrier.} + with has_side_effects = True + can_fail = True + +primop CasByteArrayOp_Int32 "casInt32Array#" GenPrimOp + MutableByteArray# s -> Int# -> Int32# -> Int32# -> State# s -> (# State# s, Int32# #) + {Given an array, an offset in 32 bit units, the expected old value, and + the new value, perform an atomic compare and swap i.e. write the new + value if the current value matches the provided old value. Returns + the value of the element before the operation. Implies a full memory + barrier.} + with has_side_effects = True + can_fail = True + +primop CasByteArrayOp_Int64 "casInt64Array#" GenPrimOp + MutableByteArray# s -> Int# -> INT64 -> INT64 -> State# s -> (# State# s, INT64 #) + {Given an array, an offset in 64 bit units, the expected old value, and + the new value, perform an atomic compare and swap i.e. write the new + value if the current value matches the provided old value. Returns + the value of the element before the operation. Implies a full memory + barrier.} + with has_side_effects = True + can_fail = True + primop FetchAddByteArrayOp_Int "fetchAddIntArray#" GenPrimOp MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #) {Given an array, and offset in machine words, and a value to add, @@ -2387,6 +2427,62 @@ primop CasAddrOp_Word "atomicCasWordAddr#" GenPrimOp with has_side_effects = True can_fail = True +primop CasAddrOp_Word8 "atomicCasWord8Addr#" GenPrimOp + Addr# -> Word8# -> Word8# -> State# s -> (# State# s, Word8# #) + { Compare and swap on a 8 bit-sized and aligned memory location. + + Use as: \s -> atomicCasWordAddr8# location expected desired s + + This version always returns the old value read. This follows the normal + protocol for CAS operations (and matches the underlying instruction on + most architectures). + + Implies a full memory barrier.} + with has_side_effects = True + can_fail = True + +primop CasAddrOp_Word16 "atomicCasWord16Addr#" GenPrimOp + Addr# -> Word16# -> Word16# -> State# s -> (# State# s, Word16# #) + { Compare and swap on a 16 bit-sized and aligned memory location. + + Use as: \s -> atomicCasWordAddr16# location expected desired s + + This version always returns the old value read. This follows the normal + protocol for CAS operations (and matches the underlying instruction on + most architectures). + + Implies a full memory barrier.} + with has_side_effects = True + can_fail = True + +primop CasAddrOp_Word32 "atomicCasWord32Addr#" GenPrimOp + Addr# -> Word32# -> Word32# -> State# s -> (# State# s, Word32# #) + { Compare and swap on a 32 bit-sized and aligned memory location. + + Use as: \s -> atomicCasWordAddr32# location expected desired s + + This version always returns the old value read. This follows the normal + protocol for CAS operations (and matches the underlying instruction on + most architectures). + + Implies a full memory barrier.} + with has_side_effects = True + can_fail = True + +primop CasAddrOp_Word64 "atomicCasWord64Addr#" GenPrimOp + Addr# -> WORD64 -> WORD64 -> State# s -> (# State# s, WORD64 #) + { Compare and swap on a 64 bit-sized and aligned memory location. + + Use as: \s -> atomicCasWordAddr64# location expected desired s + + This version always returns the old value read. This follows the normal + protocol for CAS operations (and matches the underlying instruction on + most architectures). + + Implies a full memory barrier.} + with has_side_effects = True + can_fail = True + primop FetchAddAddrOp_Word "fetchAddWordAddr#" GenPrimOp Addr# -> Word# -> State# s -> (# State# s, Word# #) {Given an address, and a value to add, diff --git a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs index 67bc3d9bdb..1c3b244980 100644 --- a/compiler/GHC/CmmToAsm/PPC/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/PPC/CodeGen.hs @@ -1221,7 +1221,38 @@ genCCall (PrimTarget (MO_AtomicRead width)) [dst] [addr] genCCall (PrimTarget (MO_AtomicWrite width)) [] [addr, val] = do code <- assignMem_IntCode (intFormat width) addr val - return $ unitOL(HWSYNC) `appOL` code + return $ unitOL HWSYNC `appOL` code + +genCCall (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] + | width == W32 || width == W64 + = do + platform <- getPlatform + (old_reg, old_code) <- getSomeReg old + (new_reg, new_code) <- getSomeReg new + (addr_reg, addr_code) <- getSomeReg addr + lbl_retry <- getBlockIdNat + lbl_eq <- getBlockIdNat + lbl_end <- getBlockIdNat + let reg_dst = getRegisterReg platform (CmmLocal dst) + code = toOL + [ HWSYNC + , BCC ALWAYS lbl_retry Nothing + , NEWBLOCK lbl_retry + , LDR format reg_dst (AddrRegReg r0 addr_reg) + , CMP format reg_dst (RIReg old_reg) + , BCC NE lbl_end Nothing + , BCC ALWAYS lbl_eq Nothing + , NEWBLOCK lbl_eq + , STC format new_reg (AddrRegReg r0 addr_reg) + , BCC NE lbl_retry Nothing + , BCC ALWAYS lbl_end Nothing + , NEWBLOCK lbl_end + , ISYNC + ] + return $ addr_code `appOL` new_code `appOL` old_code `appOL` code + where + format = intFormat width + genCCall (PrimTarget (MO_Clz width)) [dst] [src] = do platform <- getPlatform diff --git a/compiler/GHC/CmmToAsm/X86/CodeGen.hs b/compiler/GHC/CmmToAsm/X86/CodeGen.hs index 5e7c261cbb..1ab24c4a25 100644 --- a/compiler/GHC/CmmToAsm/X86/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/X86/CodeGen.hs @@ -2595,10 +2595,11 @@ genCCall' _ _ (PrimTarget (MO_AtomicWrite width)) [] [addr, val] _ = do code <- assignMem_IntCode (intFormat width) addr val return $ code `snocOL` MFENCE -genCCall' _ is32Bit (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] _ = do +genCCall' _ is32Bit (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] _ -- On x86 we don't have enough registers to use cmpxchg with a -- complicated addressing mode, so on that architecture we -- pre-compute the address first. + | not (is32Bit && width == W64) = do Amode amode addr_code <- getSimpleAmode is32Bit addr newval <- getNewRegNat format newval_code <- getAnyReg new @@ -3441,7 +3442,9 @@ outOfLineCmmOp bid mop res args MO_AtomicRMW _ _ -> fsLit "atomicrmw" MO_AtomicRead _ -> fsLit "atomicread" MO_AtomicWrite _ -> fsLit "atomicwrite" - MO_Cmpxchg _ -> fsLit "cmpxchg" + MO_Cmpxchg w -> cmpxchgLabel w -- for W64 on 32-bit + -- TODO: implement + -- cmpxchg8b instr MO_Xchg _ -> should_be_inline MO_UF_Conv _ -> unsupported diff --git a/compiler/GHC/StgToCmm/Prim.hs b/compiler/GHC/StgToCmm/Prim.hs index d61880a0e2..290ace9f01 100644 --- a/compiler/GHC/StgToCmm/Prim.hs +++ b/compiler/GHC/StgToCmm/Prim.hs @@ -872,6 +872,14 @@ emitPrimOp dflags primop = case primop of emitPrimCall [res] (MO_Cmpxchg (wordWidth platform)) [dst, expected, new] CasAddrOp_Word -> \[dst, expected, new] -> opIntoRegs $ \[res] -> emitPrimCall [res] (MO_Cmpxchg (wordWidth platform)) [dst, expected, new] + CasAddrOp_Word8 -> \[dst, expected, new] -> opIntoRegs $ \[res] -> + emitPrimCall [res] (MO_Cmpxchg W8) [dst, expected, new] + CasAddrOp_Word16 -> \[dst, expected, new] -> opIntoRegs $ \[res] -> + emitPrimCall [res] (MO_Cmpxchg W16) [dst, expected, new] + CasAddrOp_Word32 -> \[dst, expected, new] -> opIntoRegs $ \[res] -> + emitPrimCall [res] (MO_Cmpxchg W32) [dst, expected, new] + CasAddrOp_Word64 -> \[dst, expected, new] -> opIntoRegs $ \[res] -> + emitPrimCall [res] (MO_Cmpxchg W64) [dst, expected, new] -- SIMD primops (VecBroadcastOp vcat n w) -> \[e] -> opIntoRegs $ \[res] -> do @@ -1075,6 +1083,14 @@ emitPrimOp dflags primop = case primop of doAtomicWriteByteArray mba ix (bWord platform) val CasByteArrayOp_Int -> \[mba, ix, old, new] -> opIntoRegs $ \[res] -> doCasByteArray res mba ix (bWord platform) old new + CasByteArrayOp_Int8 -> \[mba, ix, old, new] -> opIntoRegs $ \[res] -> + doCasByteArray res mba ix b8 old new + CasByteArrayOp_Int16 -> \[mba, ix, old, new] -> opIntoRegs $ \[res] -> + doCasByteArray res mba ix b16 old new + CasByteArrayOp_Int32 -> \[mba, ix, old, new] -> opIntoRegs $ \[res] -> + doCasByteArray res mba ix b32 old new + CasByteArrayOp_Int64 -> \[mba, ix, old, new] -> opIntoRegs $ \[res] -> + doCasByteArray res mba ix b64 old new -- The rest just translate straightforwardly @@ -3092,7 +3108,7 @@ doCasByteArray doCasByteArray res mba idx idx_ty old new = do profile <- getProfile platform <- getPlatform - let width = (typeWidth idx_ty) + let width = typeWidth idx_ty addr = cmmIndexOffExpr platform (arrWordsHdrSize profile) width mba idx emitPrimCall diff --git a/includes/stg/MiscClosures.h b/includes/stg/MiscClosures.h index d8aefd8035..30469c603d 100644 --- a/includes/stg/MiscClosures.h +++ b/includes/stg/MiscClosures.h @@ -444,6 +444,10 @@ RTS_FUN_DECL(stg_shrinkMutableByteArrayzh); RTS_FUN_DECL(stg_resizzeMutableByteArrayzh); RTS_FUN_DECL(stg_shrinkSmallMutableArrayzh); RTS_FUN_DECL(stg_casIntArrayzh); +RTS_FUN_DECL(stg_casInt8Arrayzh); +RTS_FUN_DECL(stg_casInt16Arrayzh); +RTS_FUN_DECL(stg_casInt32Arrayzh); +RTS_FUN_DECL(stg_casInt64Arrayzh); RTS_FUN_DECL(stg_newArrayzh); RTS_FUN_DECL(stg_newArrayArrayzh); RTS_FUN_DECL(stg_copyArrayzh); diff --git a/libraries/ghc-prim/cbits/atomic.c b/libraries/ghc-prim/cbits/atomic.c index 18451016ea..af26e16268 100644 --- a/libraries/ghc-prim/cbits/atomic.c +++ b/libraries/ghc-prim/cbits/atomic.c @@ -309,14 +309,12 @@ hs_cmpxchg32(StgWord x, StgWord old, StgWord new) return __sync_val_compare_and_swap((volatile StgWord32 *) x, (StgWord32) old, (StgWord32) new); } -#if WORD_SIZE_IN_BITS == 64 extern StgWord hs_cmpxchg64(StgWord x, StgWord64 old, StgWord64 new); StgWord hs_cmpxchg64(StgWord x, StgWord64 old, StgWord64 new) { return __sync_val_compare_and_swap((volatile StgWord64 *) x, old, new); } -#endif // Atomic exchange operations diff --git a/libraries/ghc-prim/changelog.md b/libraries/ghc-prim/changelog.md index 5d27ec197a..63f2881dcb 100644 --- a/libraries/ghc-prim/changelog.md +++ b/libraries/ghc-prim/changelog.md @@ -87,6 +87,17 @@ - `extend{Int,Word}<N>#` -> `extend<N>To{Int,Word}#` - `narrow{Int,Word}<N>#` -> `intTo{Int,Word}<N>#` +- Add primops for atomic compare and swap for sizes other that wordsize: + + casInt8Array# :: MutableByteArray# s -> Int# -> Int8# -> Int8# -> State# s -> (# State# s, Int8# #) + casInt16Array# :: MutableByteArray# s -> Int# -> Int16# -> Int16# -> State# s -> (# State# s, Int16# #) + casInt32Array# :: MutableByteArray# s -> Int# -> Int32# -> Int32# -> State# s -> (# State# s, Int32# #) + casInt64Array# :: MutableByteArray# s -> Int# -> Int64# -> Int64# -> State# s -> (# State# s, Int64# #) + atomicCasWord8Addr# :: Addr# -> Word8# -> Word8# -> State# s -> (# State# s, Word8# #) + atomicCasWord16Addr# :: Addr# -> Word16# -> Word16# -> State# s -> (# State# s, Word16# #) + atomicCasWord32Addr# :: Addr# -> Word32# -> Word32# -> State# s -> (# State# s, Word32# #) + atomicCasWord64Addr# :: Addr# -> WORD64 -> WORD64 -> State# s -> (# State# s, WORD64 #) + ## 0.7.0 (edit as necessary) - Shipped with GHC 9.0.1 diff --git a/rts/PrimOps.cmm b/rts/PrimOps.cmm index 85c708cf92..8f99105b18 100644 --- a/rts/PrimOps.cmm +++ b/rts/PrimOps.cmm @@ -264,6 +264,58 @@ stg_casIntArrayzh( gcptr arr, W_ ind, W_ old, W_ new ) } +stg_casInt8Arrayzh( gcptr arr, W_ ind, I8 old, I8 new ) +/* MutableByteArray# s -> Int# -> Int8# -> Int8# -> State# s -> (# State# s, Int8# #) */ +{ + W_ p; + I8 h; + + p = arr + SIZEOF_StgArrBytes + ind; + (h) = prim %cmpxchg8(p, old, new); + + return(h); +} + + +stg_casInt16Arrayzh( gcptr arr, W_ ind, I16 old, I16 new ) +/* MutableByteArray# s -> Int# -> Int16# -> Int16# -> State# s -> (# State# s, Int16# #) */ +{ + W_ p; + I16 h; + + p = arr + SIZEOF_StgArrBytes + ind*2; + (h) = prim %cmpxchg16(p, old, new); + + return(h); +} + + +stg_casInt32Arrayzh( gcptr arr, W_ ind, I32 old, I32 new ) +/* MutableByteArray# s -> Int# -> Int32# -> Int32# -> State# s -> (# State# s, Int32# #) */ +{ + W_ p; + I32 h; + + p = arr + SIZEOF_StgArrBytes + ind*4; + (h) = prim %cmpxchg32(p, old, new); + + return(h); +} + + +stg_casInt64Arrayzh( gcptr arr, W_ ind, I64 old, I64 new ) +/* MutableByteArray# s -> Int# -> Int64# -> Int64# -> State# s -> (# State# s, Int64# #) */ +{ + W_ p; + I64 h; + + p = arr + SIZEOF_StgArrBytes + ind*8; + (h) = prim %cmpxchg64(p, old, new); + + return(h); +} + + stg_newArrayzh ( W_ n /* words */, gcptr init ) { W_ words, size, p; diff --git a/rts/RtsSymbols.c b/rts/RtsSymbols.c index 678527e328..38e1b8071c 100644 --- a/rts/RtsSymbols.c +++ b/rts/RtsSymbols.c @@ -721,6 +721,10 @@ SymI_HasProto(stg_newBCOzh) \ SymI_HasProto(stg_newByteArrayzh) \ SymI_HasProto(stg_casIntArrayzh) \ + SymI_HasProto(stg_casInt8Arrayzh) \ + SymI_HasProto(stg_casInt16Arrayzh) \ + SymI_HasProto(stg_casInt32Arrayzh) \ + SymI_HasProto(stg_casInt64Arrayzh) \ SymI_HasProto(stg_newMVarzh) \ SymI_HasProto(stg_newMutVarzh) \ SymI_HasProto(stg_newTVarzh) \ diff --git a/rts/package.conf.in b/rts/package.conf.in index 9bc48d57ca..b0796595ff 100644 --- a/rts/package.conf.in +++ b/rts/package.conf.in @@ -170,9 +170,7 @@ ld-options: , "-Wl,-u,_hs_cmpxchg8" , "-Wl,-u,_hs_cmpxchg16" , "-Wl,-u,_hs_cmpxchg32" -#if WORD_SIZE_IN_BITS == 64 , "-Wl,-u,_hs_cmpxchg64" -#endif , "-Wl,-u,_hs_xchg8" , "-Wl,-u,_hs_xchg16" , "-Wl,-u,_hs_xchg32" @@ -284,9 +282,7 @@ ld-options: , "-Wl,-u,hs_cmpxchg8" , "-Wl,-u,hs_cmpxchg16" , "-Wl,-u,hs_cmpxchg32" -#if WORD_SIZE_IN_BITS == 64 , "-Wl,-u,hs_cmpxchg64" -#endif , "-Wl,-u,hs_xchg8" , "-Wl,-u,hs_xchg16" , "-Wl,-u,hs_xchg32" diff --git a/rts/rts.cabal.in b/rts/rts.cabal.in index 3ceae1cbdc..a08e007c2a 100644 --- a/rts/rts.cabal.in +++ b/rts/rts.cabal.in @@ -220,7 +220,6 @@ library "-Wl,-u,_hs_atomic_nand64" "-Wl,-u,_hs_atomic_or64" "-Wl,-u,_hs_atomic_xor64" - "-Wl,-u,_hs_cmpxchg64" "-Wl,-u,_hs_atomicread64" "-Wl,-u,_hs_atomicwrite64" else @@ -231,7 +230,6 @@ library "-Wl,-u,hs_atomic_nand64" "-Wl,-u,hs_atomic_or64" "-Wl,-u,hs_atomic_xor64" - "-Wl,-u,hs_cmpxchg64" "-Wl,-u,hs_atomicread64" "-Wl,-u,hs_atomicwrite64" if flag(leading-underscore) @@ -299,6 +297,7 @@ library "-Wl,-u,_hs_cmpxchg8" "-Wl,-u,_hs_cmpxchg16" "-Wl,-u,_hs_cmpxchg32" + "-Wl,-u,_hs_cmpxchg64" "-Wl,-u,_hs_xchg8" "-Wl,-u,_hs_xchg16" "-Wl,-u,_hs_xchg32" @@ -380,6 +379,7 @@ library "-Wl,-u,hs_cmpxchg8" "-Wl,-u,hs_cmpxchg16" "-Wl,-u,hs_cmpxchg32" + "-Wl,-u,hs_cmpxchg64" "-Wl,-u,hs_xchg8" "-Wl,-u,hs_xchg16" "-Wl,-u,hs_xchg32" diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs index 83e5b514f0..b8adb3c621 100644 --- a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs +++ b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs @@ -10,7 +10,9 @@ import Foreign.Marshal.Alloc import Foreign.Ptr import Foreign.Storable import GHC.Exts +import GHC.Int import GHC.IO +import GHC.Word -- | Iterations per worker. iters :: Word @@ -25,6 +27,10 @@ main = do fetchOrTest fetchXorTest casTest + cas8Test + cas16Test + cas32Test + cas64Test readWriteTest -- Addr# fetchAddSubAddrTest @@ -33,6 +39,10 @@ main = do fetchOrAddrTest fetchXorAddrTest casAddrTest + casAddr8Test + casAddr16Test + casAddr32Test + casAddr64Test readWriteAddrTest loop :: Word -> IO () -> IO () @@ -202,6 +212,62 @@ casTest = do old' <- casIntArray mba ix old (old + n) when (old /= old') $ add mba ix n +cas8Test :: IO () +cas8Test = do + tot <- race 0 + (\ mba -> loop iters $ add mba 0 1) + (\ mba -> loop iters $ add mba 0 2) + assertEq (fromIntegral ((3 * fromIntegral iters) :: Word8)) tot "cas8Test" + where + -- Fetch-and-add implemented using CAS. + add :: MByteArray -> Int -> Int8 -> IO () + add mba ix n = do + old <- readInt8Array mba ix + old' <- casInt8Array mba ix old (old + n) + when (old /= old') $ add mba ix n + +cas16Test :: IO () +cas16Test = do + tot <- race 0 + (\ mba -> loop iters $ add mba 0 1) + (\ mba -> loop iters $ add mba 0 2) + assertEq (fromIntegral ((3 * fromIntegral iters) :: Word16)) tot "cas16Test" + where + -- Fetch-and-add implemented using CAS. + add :: MByteArray -> Int -> Int16 -> IO () + add mba ix n = do + old <- readInt16Array mba ix + old' <- casInt16Array mba ix old (old + n) + when (old /= old') $ add mba ix n + +cas32Test :: IO () +cas32Test = do + tot <- race 0 + (\ mba -> loop iters $ add mba 0 1) + (\ mba -> loop iters $ add mba 0 2) + assertEq (fromIntegral ((3 * fromIntegral iters) :: Word32)) tot "cas32Test" + where + -- Fetch-and-add implemented using CAS. + add :: MByteArray -> Int -> Int32 -> IO () + add mba ix n = do + old <- readInt32Array mba ix + old' <- casInt32Array mba ix old (old + n) + when (old /= old') $ add mba ix n + +cas64Test :: IO () +cas64Test = do + tot <- race 0 + (\ mba -> loop iters $ add mba 0 1) + (\ mba -> loop iters $ add mba 0 2) + assertEq (3 * fromIntegral iters) tot "cas64Test" + where + -- Fetch-and-add implemented using CAS. + add :: MByteArray -> Int -> Int64 -> IO () + add mba ix n = do + old <- readInt64Array mba ix + old' <- casInt64Array mba ix old (old + n) + when (old /= old') $ add mba ix n + -- | Test atomicCasWordAddr# by having two threads concurrently increment a -- counter, checking the sum at the end. casAddrTest :: IO () @@ -219,6 +285,69 @@ casAddrTest = do old' <- atomicCasWordPtr ptr old (old + n) when (old /= old') $ go old' +casAddr8Test :: IO () +casAddr8Test = do + tot <- race8Addr 0 + (\ addr -> loop iters $ add addr 1) + (\ addr -> loop iters $ add addr 2) + assertEq (fromIntegral (fromIntegral (3 * iters) :: Word8)) + (fromIntegral tot) "casAddr8Test" + where + -- Fetch-and-add implemented using CAS. + add :: Ptr Word8 -> Word8 -> IO () + add ptr n = peek ptr >>= go + where + go old = do + old' <- atomicCasWord8Ptr ptr old (old + n) + when (old /= old') $ go old' + +casAddr16Test :: IO () +casAddr16Test = do + tot <- race16Addr 0 + (\ addr -> loop iters $ add addr 1) + (\ addr -> loop iters $ add addr 2) + assertEq (fromIntegral (fromIntegral (3 * iters) :: Word16)) + (fromIntegral tot) "casAddr16Test" + where + -- Fetch-and-add implemented using CAS. + add :: Ptr Word16 -> Word16 -> IO () + add ptr n = peek ptr >>= go + where + go old = do + old' <- atomicCasWord16Ptr ptr old (old + n) + when (old /= old') $ go old' + +casAddr32Test :: IO () +casAddr32Test = do + tot <- race32Addr 0 + (\ addr -> loop iters $ add addr 1) + (\ addr -> loop iters $ add addr 2) + assertEq (fromIntegral (fromIntegral (3 * iters) :: Word32)) + (fromIntegral tot) "casAddr32Test" + where + -- Fetch-and-add implemented using CAS. + add :: Ptr Word32 -> Word32 -> IO () + add ptr n = peek ptr >>= go + where + go old = do + old' <- atomicCasWord32Ptr ptr old (old + n) + when (old /= old') $ go old' + +casAddr64Test :: IO () +casAddr64Test = do + tot <- race64Addr 0 + (\ addr -> loop iters $ add addr 1) + (\ addr -> loop iters $ add addr 2) + assertEq (3 * iters) (fromIntegral tot) "casAddr64Test" + where + -- Fetch-and-add implemented using CAS. + add :: Ptr Word64 -> Word64 -> IO () + add ptr n = peek ptr >>= go + where + go old = do + old' <- atomicCasWord64Ptr ptr old (old + n) + when (old /= old') $ go old' + -- | Tests atomic reads and writes by making sure that one thread sees -- updates that are done on another. This test isn't very good at the @@ -286,6 +415,62 @@ raceAddr n0 thread1 thread2 = do mapM_ takeMVar [done1, done2] peek ptr +race8Addr :: Word8 -- ^ Initial value of array element + -> (Ptr Word8 -> IO ()) -- ^ Thread 1 action + -> (Ptr Word8 -> IO ()) -- ^ Thread 2 action + -> IO Word8 -- ^ Final value of array element +race8Addr n0 thread1 thread2 = do + done1 <- newEmptyMVar + done2 <- newEmptyMVar + ptr <- castPtr <$> callocBytes (sizeOf (undefined :: Word8)) + poke ptr n0 + forkIO $ thread1 ptr >> putMVar done1 () + forkIO $ thread2 ptr >> putMVar done2 () + mapM_ takeMVar [done1, done2] + peek ptr + +race16Addr :: Word16 -- ^ Initial value of array element + -> (Ptr Word16 -> IO ()) -- ^ Thread 1 action + -> (Ptr Word16 -> IO ()) -- ^ Thread 2 action + -> IO Word16 -- ^ Final value of array element +race16Addr n0 thread1 thread2 = do + done1 <- newEmptyMVar + done2 <- newEmptyMVar + ptr <- castPtr <$> callocBytes (sizeOf (undefined :: Word16)) + poke ptr n0 + forkIO $ thread1 ptr >> putMVar done1 () + forkIO $ thread2 ptr >> putMVar done2 () + mapM_ takeMVar [done1, done2] + peek ptr + +race32Addr :: Word32 -- ^ Initial value of array element + -> (Ptr Word32 -> IO ()) -- ^ Thread 1 action + -> (Ptr Word32 -> IO ()) -- ^ Thread 2 action + -> IO Word32 -- ^ Final value of array element +race32Addr n0 thread1 thread2 = do + done1 <- newEmptyMVar + done2 <- newEmptyMVar + ptr <- castPtr <$> callocBytes (sizeOf (undefined :: Word32)) + poke ptr n0 + forkIO $ thread1 ptr >> putMVar done1 () + forkIO $ thread2 ptr >> putMVar done2 () + mapM_ takeMVar [done1, done2] + peek ptr + +race64Addr :: Word64 -- ^ Initial value of array element + -> (Ptr Word64 -> IO ()) -- ^ Thread 1 action + -> (Ptr Word64 -> IO ()) -- ^ Thread 2 action + -> IO Word64 -- ^ Final value of array element +race64Addr n0 thread1 thread2 = do + done1 <- newEmptyMVar + done2 <- newEmptyMVar + ptr <- castPtr <$> callocBytes (sizeOf (undefined :: Word64)) + poke ptr n0 + forkIO $ thread1 ptr >> putMVar done1 () + forkIO $ thread2 ptr >> putMVar done2 () + mapM_ takeMVar [done1, done2] + peek ptr + ------------------------------------------------------------------------ -- Test helper @@ -347,6 +532,26 @@ readIntArray (MBA mba#) (I# ix#) = IO $ \ s# -> case readIntArray# mba# ix# s# of (# s2#, n# #) -> (# s2#, I# n# #) +readInt8Array :: MByteArray -> Int -> IO Int8 +readInt8Array (MBA mba#) (I# ix#) = IO $ \ s# -> + case readInt8Array# mba# ix# s# of + (# s2#, n# #) -> (# s2#, I8# n# #) + +readInt16Array :: MByteArray -> Int -> IO Int16 +readInt16Array (MBA mba#) (I# ix#) = IO $ \ s# -> + case readInt16Array# mba# ix# s# of + (# s2#, n# #) -> (# s2#, I16# n# #) + +readInt32Array :: MByteArray -> Int -> IO Int32 +readInt32Array (MBA mba#) (I# ix#) = IO $ \ s# -> + case readInt32Array# mba# ix# s# of + (# s2#, n# #) -> (# s2#, I32# n# #) + +readInt64Array :: MByteArray -> Int -> IO Int64 +readInt64Array (MBA mba#) (I# ix#) = IO $ \ s# -> + case readInt64Array# mba# ix# s# of + (# s2#, n# #) -> (# s2#, I64# n# #) + atomicWriteIntArray :: MByteArray -> Int -> Int -> IO () atomicWriteIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# -> case atomicWriteIntArray# mba# ix# n# s# of @@ -362,6 +567,26 @@ casIntArray (MBA mba#) (I# ix#) (I# old#) (I# new#) = IO $ \ s# -> case casIntArray# mba# ix# old# new# s# of (# s2#, old2# #) -> (# s2#, I# old2# #) +casInt8Array :: MByteArray -> Int -> Int8 -> Int8 -> IO Int8 +casInt8Array (MBA mba#) (I# ix#) (I8# old#) (I8# new#) = IO $ \ s# -> + case casInt8Array# mba# ix# old# new# s# of + (# s2#, old2# #) -> (# s2#, I8# old2# #) + +casInt16Array :: MByteArray -> Int -> Int16 -> Int16 -> IO Int16 +casInt16Array (MBA mba#) (I# ix#) (I16# old#) (I16# new#) = IO $ \ s# -> + case casInt16Array# mba# ix# old# new# s# of + (# s2#, old2# #) -> (# s2#, I16# old2# #) + +casInt32Array :: MByteArray -> Int -> Int32 -> Int32 -> IO Int32 +casInt32Array (MBA mba#) (I# ix#) (I32# old#) (I32# new#) = IO $ \ s# -> + case casInt32Array# mba# ix# old# new# s# of + (# s2#, old2# #) -> (# s2#, I32# old2# #) + +casInt64Array :: MByteArray -> Int -> Int64 -> Int64 -> IO Int64 +casInt64Array (MBA mba#) (I# ix#) (I64# old#) (I64# new#) = IO $ \ s# -> + case casInt64Array# mba# ix# old# new# s# of + (# s2#, old2# #) -> (# s2#, I64# old2# #) + ------------------------------------------------------------------------ -- Wrappers around Addr# @@ -411,3 +636,20 @@ atomicCasWordPtr :: Ptr Word -> Word -> Word -> IO Word atomicCasWordPtr (Ptr addr#) (W# old#) (W# new#) = IO $ \ s# -> case atomicCasWordAddr# addr# old# new# s# of (# s2#, old2# #) -> (# s2#, W# old2# #) + +atomicCasWord8Ptr :: Ptr Word8 -> Word8 -> Word8 -> IO Word8 +atomicCasWord8Ptr (Ptr addr#) (W8# old#) (W8# new#) = IO $ \ s# -> + case atomicCasWord8Addr# addr# old# new# s# of + (# s2#, old2# #) -> (# s2#, W8# old2# #) +atomicCasWord16Ptr :: Ptr Word16 -> Word16 -> Word16 -> IO Word16 +atomicCasWord16Ptr (Ptr addr#) (W16# old#) (W16# new#) = IO $ \ s# -> + case atomicCasWord16Addr# addr# old# new# s# of + (# s2#, old2# #) -> (# s2#, W16# old2# #) +atomicCasWord32Ptr :: Ptr Word32 -> Word32 -> Word32 -> IO Word32 +atomicCasWord32Ptr (Ptr addr#) (W32# old#) (W32# new#) = IO $ \ s# -> + case atomicCasWord32Addr# addr# old# new# s# of + (# s2#, old2# #) -> (# s2#, W32# old2# #) +atomicCasWord64Ptr :: Ptr Word64 -> Word64 -> Word64 -> IO Word64 +atomicCasWord64Ptr (Ptr addr#) (W64# old#) (W64# new#) = IO $ \ s# -> + case atomicCasWord64Addr# addr# old# new# s# of + (# s2#, old2# #) -> (# s2#, W64# old2# #) diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout index b09c2a8eaa..055f6694a1 100644 --- a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout +++ b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout @@ -4,6 +4,10 @@ fetchNandTest: OK fetchOrTest: OK fetchXorTest: OK casTest: OK +cas8Test: OK +cas16Test: OK +cas32Test: OK +cas64Test: OK readWriteTest: OK fetchAddSubAddrTest: OK fetchAndAddrTest: OK @@ -11,4 +15,8 @@ fetchNandAddrTest: OK fetchOrAddrTest: OK fetchXorAddrTest: OK casAddrTest: OK +casAddr8Test: OK +casAddr16Test: OK +casAddr32Test: OK +casAddr64Test: OK readWriteAddrTest: OK |