summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Mem2Reg.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Mem2Reg.cpp')
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp258
1 files changed, 192 insertions, 66 deletions
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 45d6f7d0c1ed..3b303f9836cf 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -10,6 +10,8 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -92,11 +94,121 @@ using namespace mlir;
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.
+namespace {
+
+/// Information computed during promotion analysis used to perform actual
+/// promotion.
+struct MemorySlotPromotionInfo {
+ /// Blocks for which at least two definitions of the slot values clash.
+ SmallPtrSet<Block *, 8> mergePoints;
+ /// Contains, for each operation, which uses must be eliminated by promotion.
+ /// This is a DAG structure because if an operation must eliminate some of
+ /// its uses, it is because the defining ops of the blocking uses requested
+ /// it. The defining ops therefore must also have blocking uses or be the
+ /// starting point of the bloccking uses.
+ DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+};
+
+/// Computes information for basic slot promotion. This will check that direct
+/// slot promotion can be performed, and provide the information to execute the
+/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer {
+public:
+ MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
+ : slot(slot), dominance(dominance) {}
+
+ /// Computes the information for slot promotion if promotion is possible,
+ /// returns nothing otherwise.
+ std::optional<MemorySlotPromotionInfo> computeInfo();
+
+private:
+ /// Computes the transitive uses of the slot that block promotion. This finds
+ /// uses that would block the promotion, checks that the operation has a
+ /// solution to remove the blocking use, and potentially forwards the analysis
+ /// if the operation needs further blocking uses resolved to resolve its own
+ /// uses (typically, removing its users because it will delete itself to
+ /// resolve its own blocking uses). This will fail if one of the transitive
+ /// users cannot remove a requested use, and should prevent promotion.
+ LogicalResult computeBlockingUses(
+ DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+
+ /// Computes in which blocks the value stored in the slot is actually used,
+ /// meaning blocks leading to a load. This method uses `definingBlocks`, the
+ /// set of blocks containing a store to the slot (defining the value of the
+ /// slot).
+ SmallPtrSet<Block *, 16>
+ computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
+
+ /// Computes the points in which multiple re-definitions of the slot's value
+ /// (stores) may conflict.
+ void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+
+ /// Ensures predecessors of merge points can properly provide their current
+ /// definition of the value stored in the slot to the merge point. This can
+ /// notably be an issue if the terminator used does not have the ability to
+ /// forward values through block operands.
+ bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
+
+ MemorySlot slot;
+ DominanceInfo &dominance;
+};
+
+/// The MemorySlotPromoter handles the state of promoting a memory slot. It
+/// wraps a slot and its associated allocator. This will perform the mutation of
+/// IR.
+class MemorySlotPromoter {
+public:
+ MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
+ RewriterBase &rewriter, DominanceInfo &dominance,
+ MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics);
+
+ /// Actually promotes the slot by mutating IR. Promoting a slot DOES
+ /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
+ /// promotion info should NOT be performed in batches.
+ void promoteSlot();
+
+private:
+ /// Computes the reaching definition for all the operations that require
+ /// promotion. `reachingDef` is the value the slot should contain at the
+ /// beginning of the block. This method returns the reached definition at the
+ /// end of the block.
+ Value computeReachingDefInBlock(Block *block, Value reachingDef);
+
+ /// Computes the reaching definition for all the operations that require
+ /// promotion. `reachingDef` corresponds to the initial value the
+ /// slot will contain before any write, typically a poison value.
+ void computeReachingDefInRegion(Region *region, Value reachingDef);
+
+ /// Removes the blocking uses of the slot, in topological order.
+ void removeBlockingUses();
+
+ /// Lazily-constructed default value representing the content of the slot when
+ /// no store has been executed. This function may mutate IR.
+ Value getLazyDefaultValue();
+
+ MemorySlot slot;
+ PromotableAllocationOpInterface allocator;
+ RewriterBase &rewriter;
+ /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
+ /// initialize it on demand.
+ Value defaultValue;
+ /// Contains the reaching definition at this operation. Reaching definitions
+ /// are only computed for promotable memory operations with blocking uses.
+ DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+ DominanceInfo &dominance;
+ MemorySlotPromotionInfo info;
+ const Mem2RegStatistics &statistics;
+};
+
+} // namespace
+
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
- OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info)
- : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
- info(std::move(info)) {
+ RewriterBase &rewriter, DominanceInfo &dominance,
+ MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ : slot(slot), allocator(allocator), rewriter(rewriter),
+ dominance(dominance), info(std::move(info)), statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -114,9 +226,9 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
if (defaultValue)
return defaultValue;
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(slot.ptr.getParentBlock());
- return defaultValue = allocator.getDefaultValue(slot, builder);
+ RewriterBase::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+ return defaultValue = allocator.getDefaultValue(slot, rewriter);
}
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
@@ -341,11 +453,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
Block *block = job.block->getBlock();
if (info.mergePoints.contains(block)) {
- BlockArgument blockArgument =
- block->addArgument(slot.elemType, slot.ptr.getLoc());
- builder.setInsertionPointToStart(block);
- allocator.handleBlockArgument(slot, blockArgument, builder);
+ // If the block is a merge point, we need to add a block argument to hold
+ // the selected reaching definition. This has to be a bit complicated
+ // because of RewriterBase limitations: we need to create a new block with
+ // the extra block argument, move the content of the block to the new
+ // block, and replace the block with the new block in the merge point set.
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ for (BlockArgument arg : block->getArguments()) {
+ argTypes.push_back(arg.getType());
+ argLocs.push_back(arg.getLoc());
+ }
+ argTypes.push_back(slot.elemType);
+ argLocs.push_back(slot.ptr.getLoc());
+ Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);
+
+ info.mergePoints.erase(block);
+ info.mergePoints.insert(newBlock);
+
+ rewriter.replaceAllUsesWith(block, newBlock);
+ rewriter.mergeBlocks(block, newBlock,
+ newBlock->getArguments().drop_back());
+
+ block = newBlock;
+
+ BlockArgument blockArgument = block->getArguments().back();
+ rewriter.setInsertionPointToStart(block);
+ allocator.handleBlockArgument(slot, blockArgument, rewriter);
job.reachingDef = blockArgument;
+
+ if (statistics.newBlockArgumentAmount)
+ (*statistics.newBlockArgumentAmount)++;
}
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
@@ -355,8 +493,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
if (info.mergePoints.contains(blockOperand.get())) {
if (!job.reachingDef)
job.reachingDef = getLazyDefaultValue();
- terminator.getSuccessorOperands(blockOperand.getOperandNumber())
- .append(job.reachingDef);
+ rewriter.updateRootInPlace(terminator, [&]() {
+ terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+ .append(job.reachingDef);
+ });
}
}
}
@@ -382,24 +522,24 @@ void MemorySlotPromoter::removeBlockingUses() {
if (!reachingDef)
reachingDef = getLazyDefaultValue();
- builder.setInsertionPointAfter(toPromote);
+ rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], builder, reachingDef) ==
- DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter,
+ reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);
continue;
}
auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
- builder.setInsertionPointAfter(toPromote);
+ rewriter.setInsertionPointAfter(toPromote);
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
- builder) == DeletionKind::Delete)
+ rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
}
for (Operation *toEraseOp : toErase)
- toEraseOp->erase();
+ rewriter.eraseOp(toEraseOp);
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
@@ -421,87 +561,73 @@ void MemorySlotPromoter::promoteSlot() {
assert(succOperands.size() == mergePoint->getNumArguments() ||
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
- succOperands.append(getLazyDefaultValue());
+ rewriter.updateRootInPlace(
+ user, [&]() { succOperands.append(getLazyDefaultValue()); });
}
}
LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
<< "\n");
- allocator.handlePromotionComplete(slot, defaultValue);
+ if (statistics.promotedAmount)
+ (*statistics.promotedAmount)++;
+
+ allocator.handlePromotionComplete(slot, defaultValue, rewriter);
}
LogicalResult mlir::tryToPromoteMemorySlots(
- ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
- DominanceInfo &dominance) {
- // Actual promotion may invalidate the dominance analysis, so slot promotion
- // is prepated in batches.
- SmallVector<MemorySlotPromoter> toPromote;
+ ArrayRef<PromotableAllocationOpInterface> allocators,
+ RewriterBase &rewriter, Mem2RegStatistics statistics) {
+ DominanceInfo dominance;
+
+ bool promotedAny = false;
+
for (PromotableAllocationOpInterface allocator : allocators) {
for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
continue;
+ DominanceInfo dominance;
MemorySlotPromotionAnalyzer analyzer(slot, dominance);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
- if (info)
- toPromote.emplace_back(slot, allocator, builder, dominance,
- std::move(*info));
+ if (info) {
+ MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ std::move(*info), statistics)
+ .promoteSlot();
+ promotedAny = true;
+ }
}
}
- for (MemorySlotPromoter &promoter : toPromote)
- promoter.promoteSlot();
-
- return success(!toPromote.empty());
+ return success(promotedAny);
}
-LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+LogicalResult
+Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
+ PatternRewriter &rewriter) const {
hasBoundedRewriteRecursion();
-
- if (op->getNumRegions() == 0)
- return failure();
-
- DominanceInfo dominance;
-
- SmallVector<PromotableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to promote the slots of.
- for (Region &region : op->getRegions())
- for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
- allocators.emplace_back(allocator);
-
- // Because pattern rewriters are normally not expressive enough to support a
- // transformation like mem2reg, this uses an escape hatch to mark modified
- // operations manually and operate outside of its context.
- rewriter.startRootUpdate(op);
-
- OpBuilder builder(rewriter.getContext());
-
- if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
- rewriter.cancelRootUpdate(op);
- return failure();
- }
-
- rewriter.finalizeRootUpdate(op);
- return success();
+ return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
}
namespace {
struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+ using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
+
void runOnOperation() override {
Operation *scopeOp = getOperation();
- bool changed = false;
+
+ Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = enableRegionSimplification;
RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<Mem2RegPattern>(&getContext());
+ rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
FrozenRewritePatternSet frozen(std::move(rewritePatterns));
- (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
- &changed);
- if (!changed)
- markAllAnalysesPreserved();
+ if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
+ signalPassFailure();
}
};