diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 3d25f9d7d0e3..95d388cace0f 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7796,10 +7796,37 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op, SelectionDAG &DAG) const { EVT Ty = Op.getValueType(); auto Idx = Op.getConstantOperandAPInt(2); + int64_t IdxVal = Idx.getSExtValue(); + assert(Ty.isScalableVector() && + "Only expect scalable vectors for custom lowering of VECTOR_SPLICE"); + + // We can use the splice instruction for certain index values where we are + // able to efficiently generate the correct predicate. The index will be + // inverted and used directly as the input to the ptrue instruction, i.e. + // -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the + // splice predicate. However, we can only do this if we can guarantee that + // there are enough elements in the vector, hence we check the index <= min + // number of elements. + Optional<unsigned> PredPattern; + if (Ty.isScalableVector() && IdxVal < 0 && + (PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) != + None) { + SDLoc DL(Op); + + // Create a predicate where all but the last -IdxVal elements are false. + EVT PredVT = Ty.changeVectorElementType(MVT::i1); + SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern); + Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred); + + // Now splice the two inputs together using the predicate. + return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0), + Op.getOperand(1)); + } // This will select to an EXT instruction, which has a maximum immediate // value of 255, hence 2048-bits is the maximum value we can lower. - if (Idx.sge(-1) && Idx.slt(2048 / Ty.getVectorElementType().getSizeInBits())) + if (IdxVal >= 0 && + IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits())) return Op; return SDValue(); @@ -11011,10 +11038,10 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op, if (Vec0.isUndef()) return Op; - unsigned int PredPattern = + Optional<unsigned> PredPattern = getSVEPredPatternFromNumElements(InVT.getVectorNumElements()); auto PredTy = VT.changeVectorElementType(MVT::i1); - SDValue PTrue = getPTrue(DAG, DL, PredTy, PredPattern); + SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern); SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1); return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0); } @@ -12319,7 +12346,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad( Value *PTrue = nullptr; if (UseScalable) { - unsigned PgPattern = + Optional<unsigned> PgPattern = getSVEPredPatternFromNumElements(FVTy->getNumElements()); if (Subtarget->getMinSVEVectorSizeInBits() == Subtarget->getMaxSVEVectorSizeInBits() && @@ -12327,7 +12354,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad( PgPattern = AArch64SVEPredPattern::all; auto *PTruePat = - ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), PgPattern); + ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern); PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy}, {PTruePat}); } @@ -12499,7 +12526,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, Value *PTrue = nullptr; if (UseScalable) { - unsigned PgPattern = + Optional<unsigned> PgPattern = getSVEPredPatternFromNumElements(SubVecTy->getNumElements()); if (Subtarget->getMinSVEVectorSizeInBits() == Subtarget->getMaxSVEVectorSizeInBits() && @@ -12508,7 +12535,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI, PgPattern = AArch64SVEPredPattern::all; auto *PTruePat = - ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), PgPattern); + ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern); PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy}, {PTruePat}); } @@ -18752,7 +18779,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL, DAG.getTargetLoweringInfo().isTypeLegal(VT) && "Expected legal fixed length vector!"); - unsigned PgPattern = + Optional<unsigned> PgPattern = getSVEPredPatternFromNumElements(VT.getVectorNumElements()); assert(PgPattern && "Unexpected element count for SVE predicate"); @@ -18788,7 +18815,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL, break; } - return getPTrue(DAG, DL, MaskVT, PgPattern); + return getPTrue(DAG, DL, MaskVT, *PgPattern); } static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL, |