summaryrefslogtreecommitdiff
path: root/testsuite/tests/dph/smvm/Main.hs
blob: e30938bc2138246d1d5e88c54c64b5ca95946165 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
{-# LANGUAGE TypeOperators #-}

import SMVMVect (smvm)

import Control.Exception (evaluate)
import System.IO
import System.Environment

import qualified Data.Array.Parallel.Unlifted as U
import Data.Array.Parallel.Prelude
import Data.Array.Parallel.PArray as P


-- | Load a sparse matrix and a dense vector from the file at the given path.
--
-- The flat unlifted arrays produced by 'loadSM'' are lifted into parallel
-- arrays:
--
-- * the zipped (index, value) pairs are nested into one inner 'PArray' per
--   segment, using the segment descriptor read from the file;
-- * the trailing vector is converted directly.
--
-- The result is @(matrix, vector)@.
loadSM :: String -> IO (PArray (PArray (Int, Double)), PArray Double)
loadSM s =
  loadSM' s >>= \(segd, pairs, vec) ->
    return (toMatrix segd pairs, fromUArrPA' vec)
  where
    -- Split the flat pair array into rows according to the segment descriptor.
    toMatrix segd pairs = nestUSegdPA' segd (fromUArrPA_2' pairs)


-- | Read the raw, unlifted representation of a sparse matrix and a dense
-- vector from a binary file.
--
-- The file is read with 'U.hGet', in this order:
--
--   1. segment lengths (used to build the segment descriptor),
--   2. the index array,
--   3. the value array,
--   4. the dense vector.
--
-- Indices and values are zipped into a single array of pairs.
--
-- The handle is managed by 'withBinaryFile', so it is closed when reading
-- finishes, even if an exception is raised. Previously the handle was never
-- closed.
loadSM' :: String
        -> IO ( U.Segd
              , U.Array (Int, Double)
              , U.Array Double)
loadSM' fname =
  withBinaryFile fname ReadMode $ \h -> do
    lengths <- U.hGet h
    indices <- U.hGet h
    values  <- U.hGet h
    dv      <- U.hGet h
    -- Force every array read from the handle while it is still open; the
    -- handle is closed as soon as this action returns.
    _ <- evaluate lengths
    _ <- evaluate indices
    _ <- evaluate values
    _ <- evaluate dv
    -- These are built only from the already-forced arrays, so they do not
    -- depend on the (soon to be closed) handle.
    let segd = U.lengthsToSegd lengths
        m    = U.zip indices values
    return (segd, m, dv)

-- | Entry point, invoked as @prog <matrix-file> <expected-output-file>@.
--
-- Loads the matrix and vector from the first file and multiplies them with
-- 'smvm'. It renders each result element, followed by a @SUM = @ line, and
-- prints whether that rendering equals the contents of the second file
-- exactly.
--
-- With any other number of arguments, it fails with a usage message instead
-- of an opaque pattern-match error.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [inFile, outFile] -> run inFile outFile
    _ -> do
      prog <- getProgName
      ioError (userError ("usage: " ++ prog ++ " <matrix-file> <expected-output-file>"))
  where
    run inFile outFile = do
      (m, v) <- loadSM inFile
      let result = smvm m v
          -- ignore wibbles in low-order bits: keep only the first 12
          -- characters of each shown number
          trunc x = take 12 (show x)
          output  = unlines (map trunc (P.toList result))
                 ++ "SUM = "
                 ++ trunc (sum (P.toList result))
                 ++ "\n"
      -- check our result against the provided outFile
      outputCheck <- readFile outFile
      print (output == outputCheck)