author     Jing Bao <jing.bao@intel.com>       2021-12-14 08:42:38 -0800
committer  Thomas Lively <tlively@google.com>  2021-12-14 08:42:39 -0800
commit     2a4a229d6dcceecbb8bab094b6880e2445a6e465 (patch)
tree       0a2761cde55e42ad567cf7bd3013417dc290699f
parent     0b9b1c8c49fd317ce70d028b041572e1f24f5995 (diff)
download   llvm-2a4a229d6dcceecbb8bab094b6880e2445a6e465.tar.gz
[WebAssembly] Custom optimization for truncate
When possible, optimize vector TRUNCATE to the Wasm SIMD narrowing instructions (i16x8.narrow_i32x4_u, i8x16.narrow_i16x8_u) rather than emitting long sequences of extract_lane and replace_lane. Closes #50350.
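
To illustrate the idea (not part of the patch): an unsigned-saturating narrow acts as an exact bit-level truncate once each lane is masked to the destination width, which is why the combine inserts an ISD::AND (a v128.and in the output) before emitting NARROW_U. A minimal scalar sketch in C++, modeling one lane of i16x8.narrow_i32x4_u:

#include <cassert>
#include <cstdint>

// One lane of i16x8.narrow_i32x4_u: saturate a signed i32 to [0, 0xFFFF].
static uint16_t narrow_u_lane(int32_t v) {
  if (v < 0)
    return 0;      // saturate from below
  if (v > 0xFFFF)
    return 0xFFFF; // saturate from above
  return static_cast<uint16_t>(v);
}

int main() {
  int32_t x = 0x12345;
  assert(static_cast<uint16_t>(x) == 0x2345);  // plain truncate keeps low bits
  assert(narrow_u_lane(x) == 0xFFFF);          // narrowing alone saturates
  assert(narrow_u_lane(x & 0xFFFF) == 0x2345); // mask first == exact truncate
}
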
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyISD.def             1
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp  112
-rw-r--r--  llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td        8
-rw-r--r--  llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll         26
4 files changed, 139 insertions, 8 deletions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1fa0ea3867c7..a3a33f4a5b3a 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -31,6 +31,7 @@ HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
HANDLE_NODETYPE(VEC_SHR_S)
HANDLE_NODETYPE(VEC_SHR_U)
+HANDLE_NODETYPE(NARROW_U)
HANDLE_NODETYPE(EXTEND_LOW_S)
HANDLE_NODETYPE(EXTEND_LOW_U)
HANDLE_NODETYPE(EXTEND_HIGH_S)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 0c3ee545f8c5..38ed4c73fb93 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -176,6 +176,8 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setTargetDAGCombine(ISD::FP_ROUND);
setTargetDAGCombine(ISD::CONCAT_VECTORS);
+ setTargetDAGCombine(ISD::TRUNCATE);
+
// Support saturating add for i8x16 and i16x8
for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
for (auto T : {MVT::v16i8, MVT::v8i16})
@@ -2609,6 +2611,114 @@ performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(Op, SDLoc(N), ResVT, Source);
}
+// Helper to extract VectorWidth bits from Vec, starting from IdxVal.
+static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
+ const SDLoc &DL, unsigned VectorWidth) {
+ EVT VT = Vec.getValueType();
+ EVT ElVT = VT.getVectorElementType();
+ unsigned Factor = VT.getSizeInBits() / VectorWidth;
+ EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
+ VT.getVectorNumElements() / Factor);
+
+ // Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR
+ unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits();
+ assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
+
+ // This is the index of the first element of the VectorWidth-bit chunk
+ // we want. Since ElemsPerChunk is a power of 2 just need to clear bits.
+ IdxVal &= ~(ElemsPerChunk - 1);
+
+ // If the input is a buildvector just emit a smaller one.
+ if (Vec.getOpcode() == ISD::BUILD_VECTOR)
+ return DAG.getBuildVector(ResultVT, DL,
+ Vec->ops().slice(IdxVal, ElemsPerChunk));
+
+ SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, DL);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx);
+}
+
+// Helper to recursively truncate vector elements in half with NARROW_U. DstVT
+// is the expected destination value type after recursion. In is the initial
+// input. Note that the input should have enough leading zero bits to prevent
+// NARROW_U from saturating results.
+static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL,
+ SelectionDAG &DAG) {
+ EVT SrcVT = In.getValueType();
+
+ // No truncation required, we might get here due to recursive calls.
+ if (SrcVT == DstVT)
+ return In;
+
+ unsigned SrcSizeInBits = SrcVT.getSizeInBits();
+ unsigned NumElems = SrcVT.getVectorNumElements();
+ if (!isPowerOf2_32(NumElems))
+ return SDValue();
+ assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
+ assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
+
+ LLVMContext &Ctx = *DAG.getContext();
+ EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
+
+ // Narrow to the largest type possible:
+ // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
+ EVT InVT = MVT::i16, OutVT = MVT::i8;
+ if (SrcVT.getScalarSizeInBits() > 16) {
+ InVT = MVT::i32;
+ OutVT = MVT::i16;
+ }
+ unsigned SubSizeInBits = SrcSizeInBits / 2;
+ InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
+ OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
+
+ // Split lower/upper subvectors.
+ SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
+ SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
+
+ // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
+ if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
+ Lo = DAG.getBitcast(InVT, Lo);
+ Hi = DAG.getBitcast(InVT, Hi);
+ SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi);
+ return DAG.getBitcast(DstVT, Res);
+ }
+
+ // Recursively narrow lower/upper subvectors, concat result and narrow again.
+ EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
+ Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG);
+ Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG);
+
+ PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
+ SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
+ return truncateVectorWithNARROW(DstVT, Res, DL, DAG);
+}
+
+static SDValue performTruncateCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ auto &DAG = DCI.DAG;
+
+ SDValue In = N->getOperand(0);
+ EVT InVT = In.getValueType();
+ if (!InVT.isSimple())
+ return SDValue();
+
+ EVT OutVT = N->getValueType(0);
+ if (!OutVT.isVector())
+ return SDValue();
+
+ EVT OutSVT = OutVT.getVectorElementType();
+ EVT InSVT = InVT.getVectorElementType();
+ // Currently only cover truncate to v16i8 or v8i16.
+ if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) &&
+ (OutSVT == MVT::i8 || OutSVT == MVT::i16) && OutVT.is128BitVector()))
+ return SDValue();
+
+ SDLoc DL(N);
+ APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(),
+ OutVT.getScalarSizeInBits());
+ In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT));
+ return truncateVectorWithNARROW(OutVT, In, DL, DAG);
+}
+
SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -2625,5 +2735,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_ROUND:
case ISD::CONCAT_VECTORS:
return performVectorTruncZeroCombine(N, DCI);
+ case ISD::TRUNCATE:
+ return performTruncateCombine(N, DCI);
}
}
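
For sources wider than 256 bits, truncateVectorWithNARROW above halves the lane width one NARROW_U level at a time (i32 lanes narrow to i16, then i16 to i8). A per-lane scalar model of that strategy in C++ — the subvector splitting, concatenation, and bitcasts of the real code are elided, and the helper names are hypothetical:

#include <cassert>
#include <cstdint>
#include <vector>

// Unsigned-saturating narrow of one lane down to halfBits bits.
static int64_t narrow_u(int64_t v, unsigned halfBits) {
  int64_t hi = (1LL << halfBits) - 1;
  return v < 0 ? 0 : (v > hi ? hi : v);
}

// Mask to the destination width (the combine's ISD::AND), then keep
// halving the lane width; the mask guarantees no step ever saturates.
static std::vector<int64_t> truncByNarrow(std::vector<int64_t> lanes,
                                          unsigned srcBits, unsigned dstBits) {
  for (int64_t &l : lanes)
    l &= (1LL << dstBits) - 1;
  for (unsigned w = srcBits / 2; w >= dstBits; w /= 2)
    for (int64_t &l : lanes)
      l = narrow_u(l, w);
  return lanes;
}

int main() {
  // i32 lanes down to i8 lanes: two narrow levels (32->16, then 16->8).
  std::vector<int64_t> r = truncByNarrow({0x12345, -1, 7, 300}, 32, 8);
  assert(r[0] == 0x45 && r[1] == 0xFF && r[2] == 7 && r[3] == 0x2C);
}
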
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 30b99c3a69a9..5bb12c7fbdc7 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1278,6 +1278,14 @@ multiclass SIMDNarrow<Vec vec, bits<32> baseInst> {
defm "" : SIMDNarrow<I16x8, 101>;
defm "" : SIMDNarrow<I32x4, 133>;
+// WebAssemblyISD::NARROW_U
+def wasm_narrow_t : SDTypeProfile<1, 2, []>;
+def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>;
+def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))),
+ (NARROW_U_I8x16 $left, $right)>;
+def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))),
+ (NARROW_U_I16x8 $left, $right)>;
+
// Bitcasts are nops
// Matching bitcast t1 to t1 causes strange errors, so avoid repeating types
foreach t1 = AllVecs in
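
A note on operand placement for the two patterns above: per the Wasm SIMD spec, the narrow instructions put the lanes of the first operand in the low half of the result and the lanes of the second in the high half, which is what lets the combine feed the lower and upper subvectors in as $left and $right. A vector-level C++ sketch of i16x8.narrow_i32x4_u under that reading:

#include <array>
#include <cstdint>

static uint16_t sat_u16(int32_t v) {
  return v < 0 ? 0 : v > 0xFFFF ? 0xFFFF : static_cast<uint16_t>(v);
}

// narrow(left, right): left fills result lanes 0..3, right fills lanes 4..7.
std::array<uint16_t, 8> narrow_i32x4_u(const std::array<int32_t, 4> &left,
                                       const std::array<int32_t, 4> &right) {
  std::array<uint16_t, 8> out;
  for (int i = 0; i < 4; ++i) {
    out[i] = sat_u16(left[i]);
    out[4 + i] = sat_u16(right[i]);
  }
  return out;
}
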
diff --git a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
index c1fd8ef01e38..a595ffe51e2e 100644
--- a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
+++ b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
@@ -532,7 +532,7 @@ entry:
define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-LABEL: stest_f16i16:
; CHECK: .functype stest_f16i16 (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT: .local v128, v128
+; CHECK-NEXT: .local v128, v128, v128
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 5
; CHECK-NEXT: call __truncsfhf2
@@ -578,6 +578,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: local.tee 9
; CHECK-NEXT: i32x4.max_s
+; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT: local.tee 10
+; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i32.trunc_sat_f32_s
; CHECK-NEXT: i32x4.splat
@@ -594,7 +597,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: local.get 10
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -666,7 +671,7 @@ define <8 x i16> @utesth_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.replace_lane 3
; CHECK-NEXT: local.get 8
; CHECK-NEXT: i32x4.min_u
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptoui <8 x half> %x to <8 x i32>
@@ -741,7 +746,7 @@ define <8 x i16> @ustest_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -2106,7 +2111,7 @@ entry:
define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-LABEL: stest_f16i16_mm:
; CHECK: .functype stest_f16i16_mm (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT: .local v128, v128
+; CHECK-NEXT: .local v128, v128, v128
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 5
; CHECK-NEXT: call __truncsfhf2
@@ -2152,6 +2157,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: local.tee 9
; CHECK-NEXT: i32x4.max_s
+; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT: local.tee 10
+; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i32.trunc_sat_f32_s
; CHECK-NEXT: i32x4.splat
@@ -2168,7 +2176,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: local.get 10
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -2238,7 +2248,7 @@ define <8 x i16> @utesth_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.replace_lane 3
; CHECK-NEXT: local.get 8
; CHECK-NEXT: i32x4.min_u
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptoui <8 x half> %x to <8 x i32>
@@ -2312,7 +2322,7 @@ define <8 x i16> @ustest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>