summaryrefslogtreecommitdiff
path: root/testsuite/tests/concurrent/prog003/MVarListLockCoupling.hs
blob: 0820ccd0d2c967e977a0e6a0f61d409b9a4ca0a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
{-# LANGUAGE BangPatterns,CPP #-}
module MVarListLockCoupling where

import Control.Monad
import Data.IORef
import Control.Concurrent
import Control.Concurrent.Chan
import System.Environment
import Data.Time


-- #define USE_UNPACK
-- #define USE_STRICT

#if defined(USE_UNPACK)
#define UNPACK(p) {-# UNPACK #-} !(p)
#elif defined(USE_STRICT)
#define UNPACK(p) !(p)
#else
#define UNPACK(p) p
#endif


data List a = Node { val :: a
                   , next :: UNPACK(MVar (List a)) }
            | Null
            | Head { next :: UNPACK(MVar (List a)) } deriving Eq

data ListHandle a = ListHandle { headList :: IORef (MVar (List a)), 
                             tailList :: IORef (MVar (List a)) }


-- we assume a static head pointer, pointing to the first node which must be Head
-- the deleted field of Head is always False, it's only there to make some of the code
-- more uniform
-- tail points to the last node which must be Null

-- head is static, therefore IORef
-- tail will be adjusted, therefore MVar


type Iterator a = IORef (MVar (List a))

-- iterators are private

-------------------------------------------
-- auxilliary functions



while b cmd = if b then do {cmd; while b cmd}
              else return ()

repeatUntil cmd = do { b <- cmd; if b then return ()
                                  else repeatUntil cmd }

atomCAS :: Eq a => IORef a -> a -> a -> IO Bool
atomCAS ptr old new =
   atomicModifyIORef ptr (\ cur -> if cur == old
                                   then (new, True)
                                   else (cur, False))

atomicWrite :: IORef a -> a -> IO ()
atomicWrite ptr x =
   atomicModifyIORef ptr (\ _ -> (x,()))


----------------------------------------------
-- functions operating on lists


-- we create a new list
newList :: IO (ListHandle a)
newList = 
   do null <- newMVar Null
      hd <- newMVar (Head {next = null })
      hdPtr <- newIORef hd
      tailPtr <- newIORef null
      return (ListHandle {headList = hdPtr, tailList = tailPtr})


-- we add a new node, by overwriting the null tail node
-- we only need to adjust tailList but not headList because
-- of the static Head
-- we return the location of the newly added node
addToTail :: Eq a => ListHandle a -> a -> IO ()
addToTail (ListHandle {tailList = tailPtrPtr}) x =
   do null <- newMVar Null
      tailPtr <- readIORef tailPtrPtr
      takeMVar tailPtr
      writeIORef tailPtrPtr null 
      putMVar tailPtr (Node {val = x, next = null})


find :: Eq a => ListHandle a -> a -> IO Bool
find (ListHandle { headList = head }) x =
  let go prevPtr prevNode =
           do let curPtr = next prevNode -- head/node/delnode have all next
              curNode <- takeMVar curPtr
              case curNode of
                Node {val = y, next = nextNode } ->
                   if (x == y)
                   then -- node found 
                      do putMVar prevPtr prevNode
                         putMVar curPtr curNode
                         return True
                   else 
                      do putMVar prevPtr prevNode
                         go curPtr curNode -- continue
                Null -> do putMVar prevPtr prevNode
                           putMVar curPtr curNode
                           return False -- reached end of list
  in do startPtr <- readIORef head
        startNode <- takeMVar startPtr
        go startPtr startNode

delete :: Eq a => ListHandle a -> a -> IO Bool
delete (ListHandle { headList = head }) x =
  let go prevPtr prevNode =
        do do let curPtr = next prevNode -- head/node/delnode have all next
              curNode <- takeMVar curPtr
              case curNode of
                Node {val = y, next = nextNode } ->
                   if (x == y)
                   then -- delink node
                     do case prevNode of
                          Node {} -> do putMVar prevPtr (Node {val = val prevNode, 
                                                               next = nextNode})
                                        putMVar curPtr curNode
                                        return True
                          Head {} -> do putMVar prevPtr (Head {next = nextNode})
                                        putMVar curPtr curNode
                                        return True
                   else do putMVar prevPtr prevNode                         
                           go curPtr curNode -- continue
                Null -> do putMVar curPtr curNode
                           putMVar prevPtr prevNode                         
                           return False -- reached end of list

  in do startPtr <- readIORef head
        startNode <- takeMVar startPtr
        go startPtr startNode



--printing and counting

printList :: Show a => ListHandle a -> IO ()
printList (ListHandle {headList = ptrPtr}) =
  do startptr <- (
          do ptr <- readIORef ptrPtr
             Head {next = startptr} <- readMVar ptr
             return startptr)
     printListHelp startptr


printListHelp :: Show a => MVar (List a) -> IO ()
printListHelp curNodePtr =
   do { curNode <- readMVar curNodePtr
      ; case curNode of
          Null -> putStr "Nil"
          Node {val = curval, next = curnext} ->
             do { putStr (show curval  ++ " -> ")
                ;  printListHelp curnext }
      } 

cntList :: Show a => ListHandle a -> IO Int
cntList (ListHandle {headList = ptrPtr}) =
  do startptr <- (
          do ptr <- readIORef ptrPtr
             Head {next = startptr} <- readMVar ptr
             return startptr)
     cntListHelp startptr 0


cntListHelp :: Show a => MVar (List a) -> Int -> IO Int
cntListHelp curNodePtr i =
   do { curNode <- readMVar curNodePtr
      ; case curNode of
          Null -> return i
          Node {val = curval, next = curnext} -> 
                cntListHelp curnext (i+1)
      }