summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- compiler/codeGen/StgCmmPrim.hs 73
1 files changed, 59 insertions, 14 deletions
diff --git a/compiler/codeGen/StgCmmPrim.hs b/compiler/codeGen/StgCmmPrim.hs
index 5250c9378e..523fcb21f9 100644
--- a/compiler/codeGen/StgCmmPrim.hs
+++ b/compiler/codeGen/StgCmmPrim.hs
@@ -509,7 +509,8 @@ emitPrimOp _ [res] Word2DoubleOp [w] = emitPrimCall [res]
(MO_UF_Conv W64) [w]
-- SIMD primops
-emitPrimOp dflags [res] (VecBroadcastOp vcat n w) [e] =
+emitPrimOp dflags [res] (VecBroadcastOp vcat n w) [e] = do
+ checkVecCompatibility dflags vcat n w
doVecPackOp (vecElemInjectCast dflags vcat w) ty zeros (replicate n e) res
where
zeros :: CmmExpr
@@ -525,6 +526,7 @@ emitPrimOp dflags [res] (VecBroadcastOp vcat n w) [e] =
ty = vecVmmType vcat n w
emitPrimOp dflags [res] (VecPackOp vcat n w) es = do
+ checkVecCompatibility dflags vcat n w
when (length es /= n) $
panic "emitPrimOp: VecPackOp has wrong number of arguments"
doVecPackOp (vecElemInjectCast dflags vcat w) ty zeros es res
@@ -542,6 +544,7 @@ emitPrimOp dflags [res] (VecPackOp vcat n w) es = do
ty = vecVmmType vcat n w
emitPrimOp dflags res (VecUnpackOp vcat n w) [arg] = do
+ checkVecCompatibility dflags vcat n w
when (length res /= n) $
panic "emitPrimOp: VecUnpackOp has wrong number of results"
doVecUnpackOp (vecElemProjectCast dflags vcat w) ty arg res
@@ -549,49 +552,57 @@ emitPrimOp dflags res (VecUnpackOp vcat n w) [arg] = do
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp dflags [res] (VecInsertOp vcat n w) [v,e,i] =
+emitPrimOp dflags [res] (VecInsertOp vcat n w) [v,e,i] = do
+ checkVecCompatibility dflags vcat n w
doVecInsertOp (vecElemInjectCast dflags vcat w) ty v e i res
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecIndexByteArrayOp vcat n w) args =
+emitPrimOp dflags res (VecIndexByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexByteArrayOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecReadByteArrayOp vcat n w) args =
+emitPrimOp dflags res (VecReadByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexByteArrayOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecWriteByteArrayOp vcat n w) args =
+emitPrimOp dflags res (VecWriteByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doWriteByteArrayOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecIndexOffAddrOp vcat n w) args =
+emitPrimOp dflags res (VecIndexOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexOffAddrOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecReadOffAddrOp vcat n w) args =
+emitPrimOp dflags res (VecReadOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexOffAddrOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecWriteOffAddrOp vcat n w) args =
+emitPrimOp dflags res (VecWriteOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doWriteOffAddrOp Nothing ty res args
where
ty :: CmmType
ty = vecVmmType vcat n w
-emitPrimOp _ res (VecIndexScalarByteArrayOp vcat n w) args =
+emitPrimOp dflags res (VecIndexScalarByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexByteArrayOpAs Nothing vecty ty res args
where
vecty :: CmmType
@@ -600,7 +611,8 @@ emitPrimOp _ res (VecIndexScalarByteArrayOp vcat n w) args =
ty :: CmmType
ty = vecCmmCat vcat w
-emitPrimOp _ res (VecReadScalarByteArrayOp vcat n w) args =
+emitPrimOp dflags res (VecReadScalarByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexByteArrayOpAs Nothing vecty ty res args
where
vecty :: CmmType
@@ -609,13 +621,15 @@ emitPrimOp _ res (VecReadScalarByteArrayOp vcat n w) args =
ty :: CmmType
ty = vecCmmCat vcat w
-emitPrimOp _ res (VecWriteScalarByteArrayOp vcat _ w) args =
+emitPrimOp dflags res (VecWriteScalarByteArrayOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doWriteByteArrayOp Nothing ty res args
where
ty :: CmmType
ty = vecCmmCat vcat w
-emitPrimOp _ res (VecIndexScalarOffAddrOp vcat n w) args =
+emitPrimOp dflags res (VecIndexScalarOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexOffAddrOpAs Nothing vecty ty res args
where
vecty :: CmmType
@@ -624,7 +638,8 @@ emitPrimOp _ res (VecIndexScalarOffAddrOp vcat n w) args =
ty :: CmmType
ty = vecCmmCat vcat w
-emitPrimOp _ res (VecReadScalarOffAddrOp vcat n w) args =
+emitPrimOp dflags res (VecReadScalarOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doIndexOffAddrOpAs Nothing vecty ty res args
where
vecty :: CmmType
@@ -633,7 +648,8 @@ emitPrimOp _ res (VecReadScalarOffAddrOp vcat n w) args =
ty :: CmmType
ty = vecCmmCat vcat w
-emitPrimOp _ res (VecWriteScalarOffAddrOp vcat _ w) args =
+emitPrimOp dflags res (VecWriteScalarOffAddrOp vcat n w) args = do
+ checkVecCompatibility dflags vcat n w
doWriteOffAddrOp Nothing ty res args
where
ty :: CmmType
@@ -1220,6 +1236,35 @@ vecElemProjectCast dflags WordVec W32 = Just (mo_u_32ToWord dflags)
vecElemProjectCast _ WordVec W64 = Nothing
vecElemProjectCast _ _ _ = Nothing
+-- Check to make sure that we can generate code for the specified vector type
+-- given the current set of dynamic flags.
+checkVecCompatibility :: DynFlags -> PrimOpVecCat -> Length -> Width -> FCode ()
+checkVecCompatibility dflags vcat l w = do
+ when (hscTarget dflags /= HscLlvm) $ do
+ sorry $ unlines ["SIMD vector instructions require the LLVM back-end."
+ ,"Please use -fllvm."]
+ check vecWidth vcat l w
+ where
+ check :: Width -> PrimOpVecCat -> Length -> Width -> FCode ()
+ check W128 FloatVec 4 W32 | not (isSseEnabled dflags) =
+ sorry $ "128-bit wide single-precision floating point " ++
+ "SIMD vector instructions require at least -msse."
+ check W128 _ _ _ | not (isSse2Enabled dflags) =
+ sorry $ "128-bit wide integer and double precision " ++
+ "SIMD vector instructions require at least -msse2."
+ check W256 FloatVec _ _ | not (isAvxEnabled dflags) =
+ sorry $ "256-bit wide floating point " ++
+ "SIMD vector instructions require at least -mavx."
+ check W256 _ _ _ | not (isAvx2Enabled dflags) =
+ sorry $ "256-bit wide integer " ++
+ "SIMD vector instructions require at least -mavx2."
+ check W512 _ _ _ | not (isAvx512fEnabled dflags) =
+ sorry $ "512-bit wide " ++
+ "SIMD vector instructions require -mavx512f."
+ check _ _ _ _ = return ()
+
+ vecWidth = typeWidth (vecVmmType vcat l w)
+
------------------------------------------------------------------------------
-- Helpers for translating vector packing and unpacking.