diff options
author | Amara Emerson <amara@apple.com> | 2021-08-18 00:19:58 -0700 |
---|---|---|
committer | Amara Emerson <amara@apple.com> | 2021-08-19 16:38:52 -0700 |
commit | 95ac3d15e9fe86d9b51b51d02cb3c1640bf30dee (patch) | |
tree | eb928600c2d369838959eb0ada71055ff07e0dcb /llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp | |
parent | fbb8e772ec501a1b71643db90e9c6445e17d7cac (diff) | |
download | llvm-95ac3d15e9fe86d9b51b51d02cb3c1640bf30dee.tar.gz |
[AArch64][GlobalISel] Add G_VECREDUCE fewerElements support for full scalarization.
For some reductions like G_VECREDUCE_OR on AArch64, we need to scalarize
completely if the source is <= 64b. This change adds support for that in
the legalizer. If the source has a power-of-2 number of elements, then we can do
a tree reduction using the scalar operation on the individual elements.
Otherwise, we just create a sequential chain of operations.
For AArch64, we only need to scalarize if the input is <= 64b. If it's greater than
64b then we can first do a fewerElements step to 64b, taking advantage of vector
instructions until we reach the point of scalarization.
I also had to relax the verifier checks for reductions because the intrinsics
support <1 x EltTy> types, which we lower to scalars for GlobalISel.
Differential Revision: https://reviews.llvm.org/D108276
Diffstat (limited to 'llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp')
-rw-r--r-- | llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp | 124 |
1 files changed, 94 insertions, 30 deletions
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp index 611cf1055a1b..463437a4db08 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -3489,6 +3489,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) { return lowerRotate(MI); case G_ISNAN: return lowerIsNaN(MI); + GISEL_VECREDUCE_CASES_NONSEQ + return lowerVectorReduction(MI); } } @@ -4637,35 +4639,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle( return Legalized; } -LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions( - MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) { - unsigned Opc = MI.getOpcode(); - assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD && - Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL && - "Sequential reductions not expected"); - - if (TypeIdx != 1) - return UnableToLegalize; - - // The semantics of the normal non-sequential reductions allow us to freely - // re-associate the operation. 
- Register SrcReg = MI.getOperand(1).getReg(); - LLT SrcTy = MRI.getType(SrcReg); - Register DstReg = MI.getOperand(0).getReg(); - LLT DstTy = MRI.getType(DstReg); - - if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0) - return UnableToLegalize; - - SmallVector<Register> SplitSrcs; - const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements(); - extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs); - SmallVector<Register> PartialReductions; - for (unsigned Part = 0; Part < NumParts; ++Part) { - PartialReductions.push_back( - MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0)); - } - +static unsigned getScalarOpcForReduction(unsigned Opc) { unsigned ScalarOpc; switch (Opc) { case TargetOpcode::G_VECREDUCE_FADD: @@ -4708,10 +4682,81 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions( ScalarOpc = TargetOpcode::G_UMIN; break; default: - LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n"); + llvm_unreachable("Unhandled reduction"); + } + return ScalarOpc; +} + +LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions( + MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) { + unsigned Opc = MI.getOpcode(); + assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD && + Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL && + "Sequential reductions not expected"); + + if (TypeIdx != 1) return UnableToLegalize; + + // The semantics of the normal non-sequential reductions allow us to freely + // re-associate the operation. + Register SrcReg = MI.getOperand(1).getReg(); + LLT SrcTy = MRI.getType(SrcReg); + Register DstReg = MI.getOperand(0).getReg(); + LLT DstTy = MRI.getType(DstReg); + + if (NarrowTy.isVector() && + (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)) + return UnableToLegalize; + + unsigned ScalarOpc = getScalarOpcForReduction(Opc); + SmallVector<Register> SplitSrcs; + // If NarrowTy is a scalar then we're being asked to scalarize. 
+ const unsigned NumParts = + NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements() + : SrcTy.getNumElements(); + + extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs); + if (NarrowTy.isScalar()) { + if (DstTy != NarrowTy) + return UnableToLegalize; // FIXME: handle implicit extensions. + + if (isPowerOf2_32(NumParts)) { + // Generate a tree of scalar operations to reduce the critical path. + SmallVector<Register> PartialResults; + unsigned NumPartsLeft = NumParts; + while (NumPartsLeft > 1) { + for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) { + PartialResults.emplace_back( + MIRBuilder + .buildInstr(ScalarOpc, {NarrowTy}, + {SplitSrcs[Idx], SplitSrcs[Idx + 1]}) + .getReg(0)); + } + SplitSrcs = PartialResults; + PartialResults.clear(); + NumPartsLeft = SplitSrcs.size(); + } + assert(SplitSrcs.size() == 1); + MIRBuilder.buildCopy(DstReg, SplitSrcs[0]); + MI.eraseFromParent(); + return Legalized; + } + // If we can't generate a tree, then just do sequential operations. + Register Acc = SplitSrcs[0]; + for (unsigned Idx = 1; Idx < NumParts; ++Idx) + Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]}) + .getReg(0); + MIRBuilder.buildCopy(DstReg, Acc); + MI.eraseFromParent(); + return Legalized; + } + SmallVector<Register> PartialReductions; + for (unsigned Part = 0; Part < NumParts; ++Part) { + PartialReductions.push_back( + MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0)); } + // If the types involved are powers of 2, we can generate intermediate vector // ops, before generating a final reduction operation. 
if (isPowerOf2_32(SrcTy.getNumElements()) && @@ -7389,3 +7434,22 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerIsNaN(MachineInstr &MI) { MI.eraseFromParent(); return Legalized; } + +LegalizerHelper::LegalizeResult +LegalizerHelper::lowerVectorReduction(MachineInstr &MI) { + Register SrcReg = MI.getOperand(1).getReg(); + LLT SrcTy = MRI.getType(SrcReg); + LLT DstTy = MRI.getType(SrcReg); + + // The source could be a scalar if the IR type was <1 x sN>. + if (SrcTy.isScalar()) { + if (DstTy.getSizeInBits() > SrcTy.getSizeInBits()) + return UnableToLegalize; // FIXME: handle extension. + // This can be just a plain copy. + Observer.changingInstr(MI); + MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY)); + Observer.changedInstr(MI); + return Legalized; + } + return UnableToLegalize;; +} |