diff options
Diffstat (limited to 'mlir/lib/Transforms/Mem2Reg.cpp')
-rw-r--r-- | mlir/lib/Transforms/Mem2Reg.cpp | 258 |
1 files changed, 192 insertions, 66 deletions
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 45d6f7d0c1ed..3b303f9836cf 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -10,6 +10,8 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -92,11 +94,121 @@ using namespace mlir; /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. +namespace { + +/// Information computed during promotion analysis used to perform actual +/// promotion. +struct MemorySlotPromotionInfo { + /// Blocks for which at least two definitions of the slot values clash. + SmallPtrSet<Block *, 8> mergePoints; + /// Contains, for each operation, which uses must be eliminated by promotion. + /// This is a DAG structure because if an operation must eliminate some of + /// its uses, it is because the defining ops of the blocking uses requested + /// it. The defining ops therefore must also have blocking uses or be the + /// starting point of the bloccking uses. + DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses; +}; + +/// Computes information for basic slot promotion. This will check that direct +/// slot promotion can be performed, and provide the information to execute the +/// promotion. This does not mutate IR. +class MemorySlotPromotionAnalyzer { +public: + MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} + + /// Computes the information for slot promotion if promotion is possible, + /// returns nothing otherwise. + std::optional<MemorySlotPromotionInfo> computeInfo(); + +private: + /// Computes the transitive uses of the slot that block promotion. This finds + /// uses that would block the promotion, checks that the operation has a + /// solution to remove the blocking use, and potentially forwards the analysis + /// if the operation needs further blocking uses resolved to resolve its own + /// uses (typically, removing its users because it will delete itself to + /// resolve its own blocking uses). This will fail if one of the transitive + /// users cannot remove a requested use, and should prevent promotion. + LogicalResult computeBlockingUses( + DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses); + + /// Computes in which blocks the value stored in the slot is actually used, + /// meaning blocks leading to a load. This method uses `definingBlocks`, the + /// set of blocks containing a store to the slot (defining the value of the + /// slot). + SmallPtrSet<Block *, 16> + computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks); + + /// Computes the points in which multiple re-definitions of the slot's value + /// (stores) may conflict. + void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints); + + /// Ensures predecessors of merge points can properly provide their current + /// definition of the value stored in the slot to the merge point. This can + /// notably be an issue if the terminator used does not have the ability to + /// forward values through block operands. + bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints); + + MemorySlot slot; + DominanceInfo &dominance; +}; + +/// The MemorySlotPromoter handles the state of promoting a memory slot. It +/// wraps a slot and its associated allocator. This will perform the mutation of +/// IR. +class MemorySlotPromoter { +public: + MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, + RewriterBase &rewriter, DominanceInfo &dominance, + MemorySlotPromotionInfo info, + const Mem2RegStatistics &statistics); + + /// Actually promotes the slot by mutating IR. Promoting a slot DOES + /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of + /// promotion info should NOT be performed in batches. + void promoteSlot(); + +private: + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` is the value the slot should contain at the + /// beginning of the block. This method returns the reached definition at the + /// end of the block. + Value computeReachingDefInBlock(Block *block, Value reachingDef); + + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` corresponds to the initial value the + /// slot will contain before any write, typically a poison value. + void computeReachingDefInRegion(Region *region, Value reachingDef); + + /// Removes the blocking uses of the slot, in topological order. + void removeBlockingUses(); + + /// Lazily-constructed default value representing the content of the slot when + /// no store has been executed. This function may mutate IR. + Value getLazyDefaultValue(); + + MemorySlot slot; + PromotableAllocationOpInterface allocator; + RewriterBase &rewriter; + /// Potentially non-initialized default value. Use `getLazyDefaultValue` to + /// initialize it on demand. + Value defaultValue; + /// Contains the reaching definition at this operation. Reaching definitions + /// are only computed for promotable memory operations with blocking uses. + DenseMap<PromotableMemOpInterface, Value> reachingDefs; + DominanceInfo &dominance; + MemorySlotPromotionInfo info; + const Mem2RegStatistics &statistics; +}; + +} // namespace + MemorySlotPromoter::MemorySlotPromoter( MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info) - : slot(slot), allocator(allocator), builder(builder), dominance(dominance), - info(std::move(info)) { + RewriterBase &rewriter, DominanceInfo &dominance, + MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics) + : slot(slot), allocator(allocator), rewriter(rewriter), + dominance(dominance), info(std::move(info)), statistics(statistics) { #ifndef NDEBUG auto isResultOrNewBlockArgument = [&]() { if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr)) @@ -114,9 +226,9 @@ Value MemorySlotPromoter::getLazyDefaultValue() { if (defaultValue) return defaultValue; - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(slot.ptr.getParentBlock()); - return defaultValue = allocator.getDefaultValue(slot, builder); + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(slot.ptr.getParentBlock()); + return defaultValue = allocator.getDefaultValue(slot, rewriter); } LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( @@ -341,11 +453,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region, Block *block = job.block->getBlock(); if (info.mergePoints.contains(block)) { - BlockArgument blockArgument = - block->addArgument(slot.elemType, slot.ptr.getLoc()); - builder.setInsertionPointToStart(block); - allocator.handleBlockArgument(slot, blockArgument, builder); + // If the block is a merge point, we need to add a block argument to hold + // the selected reaching definition. This has to be a bit complicated + // because of RewriterBase limitations: we need to create a new block with + // the extra block argument, move the content of the block to the new + // block, and replace the block with the new block in the merge point set. + SmallVector<Type> argTypes; + SmallVector<Location> argLocs; + for (BlockArgument arg : block->getArguments()) { + argTypes.push_back(arg.getType()); + argLocs.push_back(arg.getLoc()); + } + argTypes.push_back(slot.elemType); + argLocs.push_back(slot.ptr.getLoc()); + Block *newBlock = rewriter.createBlock(block, argTypes, argLocs); + + info.mergePoints.erase(block); + info.mergePoints.insert(newBlock); + + rewriter.replaceAllUsesWith(block, newBlock); + rewriter.mergeBlocks(block, newBlock, + newBlock->getArguments().drop_back()); + + block = newBlock; + + BlockArgument blockArgument = block->getArguments().back(); + rewriter.setInsertionPointToStart(block); + allocator.handleBlockArgument(slot, blockArgument, rewriter); job.reachingDef = blockArgument; + + if (statistics.newBlockArgumentAmount) + (*statistics.newBlockArgumentAmount)++; } job.reachingDef = computeReachingDefInBlock(block, job.reachingDef); @@ -355,8 +493,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region, if (info.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); - terminator.getSuccessorOperands(blockOperand.getOperandNumber()) - .append(job.reachingDef); + rewriter.updateRootInPlace(terminator, [&]() { + terminator.getSuccessorOperands(blockOperand.getOperandNumber()) + .append(job.reachingDef); + }); } } } @@ -382,24 +522,24 @@ void MemorySlotPromoter::removeBlockingUses() { if (!reachingDef) reachingDef = getLazyDefaultValue(); - builder.setInsertionPointAfter(toPromote); + rewriter.setInsertionPointAfter(toPromote); if (toPromoteMemOp.removeBlockingUses( - slot, info.userToBlockingUses[toPromote], builder, reachingDef) == - DeletionKind::Delete) + slot, info.userToBlockingUses[toPromote], rewriter, + reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); continue; } auto toPromoteBasic = cast<PromotableOpInterface>(toPromote); - builder.setInsertionPointAfter(toPromote); + rewriter.setInsertionPointAfter(toPromote); if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], - builder) == DeletionKind::Delete) + rewriter) == DeletionKind::Delete) toErase.push_back(toPromote); } for (Operation *toEraseOp : toErase) - toEraseOp->erase(); + rewriter.eraseOp(toEraseOp); assert(slot.ptr.use_empty() && "after promotion, the slot pointer should not be used anymore"); @@ -421,87 +561,73 @@ void MemorySlotPromoter::promoteSlot() { assert(succOperands.size() == mergePoint->getNumArguments() || succOperands.size() + 1 == mergePoint->getNumArguments()); if (succOperands.size() + 1 == mergePoint->getNumArguments()) - succOperands.append(getLazyDefaultValue()); + rewriter.updateRootInPlace( + user, [&]() { succOperands.append(getLazyDefaultValue()); }); } } LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr << "\n"); - allocator.handlePromotionComplete(slot, defaultValue); + if (statistics.promotedAmount) + (*statistics.promotedAmount)++; + + allocator.handlePromotionComplete(slot, defaultValue, rewriter); } LogicalResult mlir::tryToPromoteMemorySlots( - ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder, - DominanceInfo &dominance) { - // Actual promotion may invalidate the dominance analysis, so slot promotion - // is prepated in batches. - SmallVector<MemorySlotPromoter> toPromote; + ArrayRef<PromotableAllocationOpInterface> allocators, + RewriterBase &rewriter, Mem2RegStatistics statistics) { + DominanceInfo dominance; + + bool promotedAny = false; + for (PromotableAllocationOpInterface allocator : allocators) { for (MemorySlot slot : allocator.getPromotableSlots()) { if (slot.ptr.use_empty()) continue; + DominanceInfo dominance; MemorySlotPromotionAnalyzer analyzer(slot, dominance); std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo(); - if (info) - toPromote.emplace_back(slot, allocator, builder, dominance, - std::move(*info)); + if (info) { + MemorySlotPromoter(slot, allocator, rewriter, dominance, + std::move(*info), statistics) + .promoteSlot(); + promotedAny = true; + } } } - for (MemorySlotPromoter &promoter : toPromote) - promoter.promoteSlot(); - - return success(!toPromote.empty()); + return success(promotedAny); } -LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { +LogicalResult +Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator, + PatternRewriter &rewriter) const { hasBoundedRewriteRecursion(); - - if (op->getNumRegions() == 0) - return failure(); - - DominanceInfo dominance; - - SmallVector<PromotableAllocationOpInterface> allocators; - // Build a list of allocators to attempt to promote the slots of. - for (Region ®ion : op->getRegions()) - for (auto allocator : region.getOps<PromotableAllocationOpInterface>()) - allocators.emplace_back(allocator); - - // Because pattern rewriters are normally not expressive enough to support a - // transformation like mem2reg, this uses an escape hatch to mark modified - // operations manually and operate outside of its context. - rewriter.startRootUpdate(op); - - OpBuilder builder(rewriter.getContext()); - - if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) { - rewriter.cancelRootUpdate(op); - return failure(); - } - - rewriter.finalizeRootUpdate(op); - return success(); + return tryToPromoteMemorySlots({allocator}, rewriter, statistics); } namespace { struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> { + using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase; + void runOnOperation() override { Operation *scopeOp = getOperation(); - bool changed = false; + + Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount}; + + GreedyRewriteConfig config; + config.enableRegionSimplification = enableRegionSimplification; RewritePatternSet rewritePatterns(&getContext()); - rewritePatterns.add<Mem2RegPattern>(&getContext()); + rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics); FrozenRewritePatternSet frozen(std::move(rewritePatterns)); - (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(), - &changed); - if (!changed) - markAllAnalysesPreserved(); + if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config))) + signalPassFailure(); } }; |