author	Simon Pilgrim <llvm-dev@redking.me.uk>	2021-10-28 14:07:17 +0100
committer	Simon Pilgrim <llvm-dev@redking.me.uk>	2021-10-28 14:07:17 +0100
commit	d29ccbecd093c881c599fd41db5d68dae744f91f (patch)
tree	b8b54458532fc8ac8108095379ddb8e7f7cdec01
parent	fbf1745722a0df95608128561d58744ae7b6f311 (diff)
download	llvm-d29ccbecd093c881c599fd41db5d68dae744f91f.tar.gz
[X86][AVX] Attempt to fold a scaled index into a gather/scatter scale immediate (PR13310)
If the index operand for a gather/scatter intrinsic is being scaled (a self-addition or a shl-by-immediate), then we may be able to fold that scaling into the intrinsic's scale immediate value instead.

Fixes PR13310.

Differential Revision: https://reviews.llvm.org/D108539
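For readers skimming the change below, here is a minimal standalone sketch of the legality check the combine performs; the helper name foldIndexScale is hypothetical and not part of the patch, which inlines this logic in combineX86GatherScatter:

// Illustrative sketch only: the combine multiplies the existing scale
// immediate by (1 << ShiftAmt), where ShiftAmt comes from a shl-by-immediate
// on the index (or is 1 for a self-addition, since x + x == x << 1), and
// performs the fold only if the result is still a legal x86 scale (1/2/4/8).
#include <cstdint>

static uint64_t foldIndexScale(uint64_t ScaleAmt, unsigned ShiftAmt) {
  uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
  bool IsPow2 = NewScaleAmt != 0 && (NewScaleAmt & (NewScaleAmt - 1)) == 0;
  return (IsPow2 && NewScaleAmt <= 8) ? NewScaleAmt : 0; // 0 means "no fold"
}

// e.g. foldIndexScale(1, 2) == 4  (scale 1, index shl by 2 -> scale 4)
//      foldIndexScale(2, 1) == 4  (scale 2, self-added index -> scale 4)
//      foldIndexScale(4, 2) == 0  (16 exceeds the maximum x86 scale of 8)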
-rw-r--r--	llvm/lib/Target/X86/X86ISelLowering.cpp	38
-rw-r--r--	llvm/test/CodeGen/X86/masked_gather_scatter.ll	58
2 files changed, 60 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 15eec7a69726..e922cb356dfe 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50227,9 +50227,40 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
+ SDValue Index = MemOp->getIndex();
+ SDValue Scale = MemOp->getScale();
+ SDValue Mask = MemOp->getMask();
+
+ // Attempt to fold an index scale into the scale value directly.
+ // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively?
+ if ((Index.getOpcode() == X86ISD::VSHLI ||
+ (Index.getOpcode() == ISD::ADD &&
+ Index.getOperand(0) == Index.getOperand(1))) &&
+ isa<ConstantSDNode>(Scale)) {
+ unsigned ShiftAmt =
+ Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1);
+ uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
+ uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt);
+ if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
+ SDValue NewIndex = Index.getOperand(0);
+ SDValue NewScale =
+ DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType());
+ if (N->getOpcode() == X86ISD::MGATHER)
+ return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG,
+ MemOp->getOperand(1), Mask,
+ MemOp->getBasePtr(), NewIndex, NewScale,
+ MemOp->getChain(), Subtarget);
+ if (N->getOpcode() == X86ISD::MSCATTER)
+ return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG,
+ MemOp->getOperand(1), Mask, MemOp->getBasePtr(),
+ NewIndex, NewScale, MemOp->getChain(), Subtarget);
+ }
+ }
+
// With vector masks we only demand the upper bit of the mask.
- SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
@@ -52886,7 +52917,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, DCI);
case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget);
case X86ISD::MGATHER:
- case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
+ case X86ISD::MSCATTER:
+ return combineX86GatherScatter(N, DAG, DCI, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index d6c3f8625ffe..fbe02af64e3d 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -808,20 +808,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
; KNL_64-NEXT: vmovd %esi, %xmm0
; KNL_64-NEXT: vpbroadcastd %xmm0, %ymm0
; KNL_64-NEXT: vpmovsxdq %ymm0, %zmm0
-; KNL_64-NEXT: vpsllq $2, %zmm0, %zmm0
; KNL_64-NEXT: kxnorw %k0, %k0, %k1
; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1
-; KNL_64-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
; KNL_64-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: test14:
; KNL_32: # %bb.0:
; KNL_32-NEXT: vmovd %xmm0, %eax
-; KNL_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
+; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
; KNL_32-NEXT: kxnorw %k0, %k0, %k1
; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; KNL_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1}
+; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
; KNL_32-NEXT: retl
;
; SKX-LABEL: test14:
@@ -829,20 +828,19 @@ define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
; SKX-NEXT: vmovq %xmm0, %rax
; SKX-NEXT: vpbroadcastd %esi, %ymm0
; SKX-NEXT: vpmovsxdq %ymm0, %zmm0
-; SKX-NEXT: vpsllq $2, %zmm0, %zmm0
; SKX-NEXT: kxnorw %k0, %k0, %k1
; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1
-; SKX-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1}
+; SKX-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1}
; SKX-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
; SKX-NEXT: retq
;
; SKX_32-LABEL: test14:
; SKX_32: # %bb.0:
; SKX_32-NEXT: vmovd %xmm0, %eax
-; SKX_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
+; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1
; SKX_32-NEXT: kxnorw %k0, %k0, %k1
; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; SKX_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1}
+; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1}
; SKX_32-NEXT: retl
%broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1
@@ -4988,38 +4986,38 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
;
; PR13310
-; FIXME: Failure to fold scaled-index into gather/scatter scale operand.
+; Failure to fold scaled-index into gather/scatter scale operand.
;
define <8 x float> @scaleidx_x86gather(float* %base, <8 x i32> %index, <8 x i32> %imask) nounwind {
; KNL_64-LABEL: scaleidx_x86gather:
; KNL_64: # %bb.0:
-; KNL_64-NEXT: vpslld $2, %ymm0, %ymm2
-; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
+; KNL_64-NEXT: vxorps %xmm2, %xmm2, %xmm2
+; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
+; KNL_64-NEXT: vmovaps %ymm2, %ymm0
; KNL_64-NEXT: retq
;
; KNL_32-LABEL: scaleidx_x86gather:
; KNL_32: # %bb.0:
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
-; KNL_32-NEXT: vpslld $2, %ymm0, %ymm2
-; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0
+; KNL_32-NEXT: vxorps %xmm2, %xmm2, %xmm2
+; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
+; KNL_32-NEXT: vmovaps %ymm2, %ymm0
; KNL_32-NEXT: retl
;
; SKX-LABEL: scaleidx_x86gather:
; SKX: # %bb.0:
-; SKX-NEXT: vpslld $2, %ymm0, %ymm2
-; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0
+; SKX-NEXT: vxorps %xmm2, %xmm2, %xmm2
+; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2
+; SKX-NEXT: vmovaps %ymm2, %ymm0
; SKX-NEXT: retq
;
; SKX_32-LABEL: scaleidx_x86gather:
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
-; SKX_32-NEXT: vpslld $2, %ymm0, %ymm2
-; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0
+; SKX_32-NEXT: vxorps %xmm2, %xmm2, %xmm2
+; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2
+; SKX_32-NEXT: vmovaps %ymm2, %ymm0
; SKX_32-NEXT: retl
%ptr = bitcast float* %base to i8*
%mask = bitcast <8 x i32> %imask to <8 x float>
@@ -5070,8 +5068,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; KNL_64-LABEL: scaleidx_x86scatter:
; KNL_64: # %bb.0:
; KNL_64-NEXT: kmovw %esi, %k1
-; KNL_64-NEXT: vpaddd %zmm1, %zmm1, %zmm1
-; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
+; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
; KNL_64-NEXT: vzeroupper
; KNL_64-NEXT: retq
;
@@ -5079,16 +5076,14 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; KNL_32: # %bb.0:
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; KNL_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1
-; KNL_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1
-; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
+; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
; KNL_32-NEXT: vzeroupper
; KNL_32-NEXT: retl
;
; SKX-LABEL: scaleidx_x86scatter:
; SKX: # %bb.0:
; SKX-NEXT: kmovw %esi, %k1
-; SKX-NEXT: vpaddd %zmm1, %zmm1, %zmm1
-; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1}
+; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1}
; SKX-NEXT: vzeroupper
; SKX-NEXT: retq
;
@@ -5096,8 +5091,7 @@ define void @scaleidx_x86scatter(<16 x float> %value, float* %base, <16 x i32> %
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
; SKX_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1
-; SKX_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1
-; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1}
+; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1}
; SKX_32-NEXT: vzeroupper
; SKX_32-NEXT: retl
%ptr = bitcast float* %base to i8*
@@ -5135,18 +5129,16 @@ define void @scaleidx_scatter(<8 x float> %value, float* %base, <8 x i32> %index
;
; SKX-LABEL: scaleidx_scatter:
; SKX: # %bb.0:
-; SKX-NEXT: vpaddd %ymm1, %ymm1, %ymm1
; SKX-NEXT: kmovw %esi, %k1
-; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,4) {%k1}
+; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,8) {%k1}
; SKX-NEXT: vzeroupper
; SKX-NEXT: retq
;
; SKX_32-LABEL: scaleidx_scatter:
; SKX_32: # %bb.0:
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
-; SKX_32-NEXT: vpaddd %ymm1, %ymm1, %ymm1
; SKX_32-NEXT: kmovb {{[0-9]+}}(%esp), %k1
-; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,4) {%k1}
+; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,8) {%k1}
; SKX_32-NEXT: vzeroupper
; SKX_32-NEXT: retl
%scaledindex = mul <8 x i32> %index, <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>