diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2021-10-28 14:07:17 +0100 |
---|---|---|
committer | Simon Pilgrim <llvm-dev@redking.me.uk> | 2021-10-28 14:07:17 +0100 |
commit | d29ccbecd093c881c599fd41db5d68dae744f91f (patch) | |
tree | b8b54458532fc8ac8108095379ddb8e7f7cdec01 | |
parent | fbf1745722a0df95608128561d58744ae7b6f311 (diff) | |
download | llvm-d29ccbecd093c881c599fd41db5d68dae744f91f.tar.gz |
[X86][AVX] Attempt to fold a scaled index into a gather/scatter scale immediate (PR13310)
If the index operand of a gather/scatter intrinsic is itself being scaled (by self-addition or by a shl-by-immediate), we may be able to fold that scaling into the intrinsic's scale immediate instead, provided the combined scale is still a power of two no greater than 8.
Fixes PR13310.
Differential Revision: https://reviews.llvm.org/D108539
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 38 | ||||
-rw-r--r-- | llvm/test/CodeGen/X86/masked_gather_scatter.ll | 58 |
2 files changed, 60 insertions, 36 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 15eec7a69726..e922cb356dfe 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -50227,9 +50227,40 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, } static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N); + SDValue Index = MemOp->getIndex(); + SDValue Scale = MemOp->getScale(); + SDValue Mask = MemOp->getMask(); + + // Attempt to fold an index scale into the scale value directly. + // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively? + if ((Index.getOpcode() == X86ISD::VSHLI || + (Index.getOpcode() == ISD::ADD && + Index.getOperand(0) == Index.getOperand(1))) && + isa<ConstantSDNode>(Scale)) { + unsigned ShiftAmt = + Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1); + uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue(); + uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt); + if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) { + SDValue NewIndex = Index.getOperand(0); + SDValue NewScale = + DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType()); + if (N->getOpcode() == X86ISD::MGATHER) + return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG, + MemOp->getOperand(1), Mask, + MemOp->getBasePtr(), NewIndex, NewScale, + MemOp->getChain(), Subtarget); + if (N->getOpcode() == X86ISD::MSCATTER) + return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG, + MemOp->getOperand(1), Mask, MemOp->getBasePtr(), + NewIndex, NewScale, MemOp->getChain(), Subtarget); + } + } + // With vector masks we only demand the upper bit of the mask. 
- SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask(); if (Mask.getScalarValueSizeInBits() != 1) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); @@ -52886,7 +52917,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, DCI); case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget); case X86ISD::MGATHER: - case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI); + case X86ISD::MSCATTER: + return combineX86GatherScatter(N, DAG, DCI, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI); case X86ISD::PCMPEQ: diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll index d6c3f8625ffe..fbe02af64e3d 100644 --- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll +++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll @@ -808,20 +808,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) { ; KNL_64-NEXT: vmovd %esi, %xmm0 ; KNL_64-NEXT: vpbroadcastd %xmm0, %ymm0 ; KNL_64-NEXT: vpmovsxdq %ymm0, %zmm0 -; KNL_64-NEXT: vpsllq $2, %zmm0, %zmm0 ; KNL_64-NEXT: kxnorw %k0, %k0, %k1 ; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1 -; KNL_64-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1} +; KNL_64-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1} ; KNL_64-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0 ; KNL_64-NEXT: retq ; ; KNL_32-LABEL: test14: ; KNL_32: # %bb.0: ; KNL_32-NEXT: vmovd %xmm0, %eax -; KNL_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1 +; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 ; KNL_32-NEXT: kxnorw %k0, %k0, %k1 ; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1} +; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; KNL_32-NEXT: retl ; ; SKX-LABEL: test14: @@ -829,20 +828,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) { 
; SKX-NEXT: vmovq %xmm0, %rax ; SKX-NEXT: vpbroadcastd %esi, %ymm0 ; SKX-NEXT: vpmovsxdq %ymm0, %zmm0 -; SKX-NEXT: vpsllq $2, %zmm0, %zmm0 ; SKX-NEXT: kxnorw %k0, %k0, %k1 ; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1 -; SKX-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1} +; SKX-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1} ; SKX-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0 ; SKX-NEXT: retq ; ; SKX_32-LABEL: test14: ; SKX_32: # %bb.0: ; SKX_32-NEXT: vmovd %xmm0, %eax -; SKX_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1 +; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 ; SKX_32-NEXT: kxnorw %k0, %k0, %k1 ; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1} +; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; SKX_32-NEXT: retl %broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1 @@ -4988,38 +4986,38 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) { ; ; PR13310 -; FIXME: Failure to fold scaled-index into gather/scatter scale operand. +; Failure to fold scaled-index into gather/scatter scale operand. 
; define <8 x float> @scaleidx_x86gather(float* %base, <8 x i32> %index, <8 x i32> %imask) nounwind { ; KNL_64-LABEL: scaleidx_x86gather: ; KNL_64: # %bb.0: -; KNL_64-NEXT: vpslld $2, %ymm0, %ymm2 -; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0 +; KNL_64-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2 +; KNL_64-NEXT: vmovaps %ymm2, %ymm0 ; KNL_64-NEXT: retq ; ; KNL_32-LABEL: scaleidx_x86gather: ; KNL_32: # %bb.0: ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; KNL_32-NEXT: vpslld $2, %ymm0, %ymm2 -; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0 +; KNL_32-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2 +; KNL_32-NEXT: vmovaps %ymm2, %ymm0 ; KNL_32-NEXT: retl ; ; SKX-LABEL: scaleidx_x86gather: ; SKX: # %bb.0: -; SKX-NEXT: vpslld $2, %ymm0, %ymm2 -; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0 +; SKX-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2 +; SKX-NEXT: vmovaps %ymm2, %ymm0 ; SKX-NEXT: retq ; ; SKX_32-LABEL: scaleidx_x86gather: ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; SKX_32-NEXT: vpslld $2, %ymm0, %ymm2 -; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0 +; SKX_32-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2 +; SKX_32-NEXT: vmovaps %ymm2, %ymm0 ; SKX_32-NEXT: retl %ptr = bitcast float* %base to i8* %mask = bitcast <8 x i32> %imask to <8 x float> @@ -5070,8 +5068,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> % ; KNL_64-LABEL: scaleidx_x86scatter: ; KNL_64: # %bb.0: ; KNL_64-NEXT: kmovw %esi, %k1 -; KNL_64-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1} +; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1} ; KNL_64-NEXT: 
vzeroupper ; KNL_64-NEXT: retq ; @@ -5079,16 +5076,14 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> % ; KNL_32: # %bb.0: ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax ; KNL_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1 -; KNL_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1} +; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1} ; KNL_32-NEXT: vzeroupper ; KNL_32-NEXT: retl ; ; SKX-LABEL: scaleidx_x86scatter: ; SKX: # %bb.0: ; SKX-NEXT: kmovw %esi, %k1 -; SKX-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1} +; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1} ; SKX-NEXT: vzeroupper ; SKX-NEXT: retq ; @@ -5096,8 +5091,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> % ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax ; SKX_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1 -; SKX_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1} +; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1} ; SKX_32-NEXT: vzeroupper ; SKX_32-NEXT: retl %ptr = bitcast float* %base to i8* @@ -5135,18 +5129,16 @@ define void @scaleidx_scatter(<8 x float> %value, float* %base, <8 x i32> %index ; ; SKX-LABEL: scaleidx_scatter: ; SKX: # %bb.0: -; SKX-NEXT: vpaddd %ymm1, %ymm1, %ymm1 ; SKX-NEXT: kmovw %esi, %k1 -; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,4) {%k1} +; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,8) {%k1} ; SKX-NEXT: vzeroupper ; SKX-NEXT: retq ; ; SKX_32-LABEL: scaleidx_scatter: ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; SKX_32-NEXT: vpaddd %ymm1, %ymm1, %ymm1 ; SKX_32-NEXT: kmovb {{[0-9]+}}(%esp), %k1 -; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,4) {%k1} +; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,8) {%k1} ; SKX_32-NEXT: vzeroupper ; SKX_32-NEXT: retl %scaledindex = mul <8 x i32> %index, <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2> |