summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.cpp45
1 files changed, 36 insertions, 9 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3d25f9d7d0e3..95d388cace0f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7796,10 +7796,37 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
SelectionDAG &DAG) const {
EVT Ty = Op.getValueType();
auto Idx = Op.getConstantOperandAPInt(2);
+ int64_t IdxVal = Idx.getSExtValue();
+ assert(Ty.isScalableVector() &&
+ "Only expect scalable vectors for custom lowering of VECTOR_SPLICE");
+
+ // We can use the splice instruction for certain index values where we are
+ // able to efficiently generate the correct predicate. The index will be
+ // inverted and used directly as the input to the ptrue instruction, i.e.
+ // -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the
+ // splice predicate. However, we can only do this if we can guarantee that
+ // there are enough elements in the vector, hence we check the index <= min
+ // number of elements.
+ Optional<unsigned> PredPattern;
+ if (Ty.isScalableVector() && IdxVal < 0 &&
+ (PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) !=
+ None) {
+ SDLoc DL(Op);
+
+ // Create a predicate where all but the last -IdxVal elements are false.
+ EVT PredVT = Ty.changeVectorElementType(MVT::i1);
+ SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern);
+ Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred);
+
+ // Now splice the two inputs together using the predicate.
+ return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0),
+ Op.getOperand(1));
+ }
// This will select to an EXT instruction, which has a maximum immediate
// value of 255, hence 2048-bits is the maximum value we can lower.
- if (Idx.sge(-1) && Idx.slt(2048 / Ty.getVectorElementType().getSizeInBits()))
+ if (IdxVal >= 0 &&
+ IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits()))
return Op;
return SDValue();
@@ -11011,10 +11038,10 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
if (Vec0.isUndef())
return Op;
- unsigned int PredPattern =
+ Optional<unsigned> PredPattern =
getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
auto PredTy = VT.changeVectorElementType(MVT::i1);
- SDValue PTrue = getPTrue(DAG, DL, PredTy, PredPattern);
+ SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
}
@@ -12319,7 +12346,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
Value *PTrue = nullptr;
if (UseScalable) {
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(FVTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
@@ -12327,7 +12354,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
PgPattern = AArch64SVEPredPattern::all;
auto *PTruePat =
- ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), PgPattern);
+ ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
@@ -12499,7 +12526,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
Value *PTrue = nullptr;
if (UseScalable) {
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(SubVecTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
@@ -12508,7 +12535,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
PgPattern = AArch64SVEPredPattern::all;
auto *PTruePat =
- ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), PgPattern);
+ ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
@@ -18752,7 +18779,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
"Expected legal fixed length vector!");
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(VT.getVectorNumElements());
assert(PgPattern && "Unexpected element count for SVE predicate");
@@ -18788,7 +18815,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
break;
}
- return getPTrue(DAG, DL, MaskVT, PgPattern);
+ return getPTrue(DAG, DL, MaskVT, *PgPattern);
}
static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,