summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorThéo Degioanni <theo.degioanni@nextsilicon.com>2023-05-16 08:35:00 +0000
committerTobias Gysi <tobias.gysi@nextsilicon.com>2023-05-16 08:35:13 +0000
commitead8e9d7953e817c52fdfaf7196dfeb2199dab26 (patch)
treea4316f426b23b274c408c59041c0621f85c44355 /mlir
parentdadb77b626cd04d749487b3e2711fb23e5d17200 (diff)
downloadllvm-ead8e9d7953e817c52fdfaf7196dfeb2199dab26.tar.gz
[mlir] [mem2reg] Adapt to be pattern-friendly.
This revision modifies the mem2reg interfaces and algorithm to be more omfortable to use as a pattern. The motivation behind this is that currently the pattern needs to be applied to the scope op of the region in which allocators should be promoted. However, a more natural way to apply the pattern would be to apply it on the allocator directly. This is not only clearer but easier to parallelize. This revision changes the mem2reg pattern to operate this way. This required restraining the interfaces to only mutate IR using RewriterBase, as the previously used escape hatch is not granular enough to match on the region that is modified only. This has the unfortunate cost of preventing batching allocator promotion and making the block argument adding logic more complex. Because batching no longer made any sense, I made the internal analyzer/promoter decoupling private again. This also adds statistics to the mem2reg infrastructure. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D150432
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Interfaces/MemorySlotInterfaces.h1
-rw-r--r--mlir/include/mlir/Interfaces/MemorySlotInterfaces.td42
-rw-r--r--mlir/include/mlir/Transforms/Mem2Reg.h122
-rw-r--r--mlir/include/mlir/Transforms/Passes.td15
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp39
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp23
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp258
-rw-r--r--mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir2
-rw-r--r--mlir/test/Dialect/LLVMIR/mem2reg.mlir2
-rw-r--r--mlir/test/Dialect/MemRef/mem2reg-statistics.mlir60
-rw-r--r--mlir/test/Dialect/MemRef/mem2reg.mlir18
11 files changed, 357 insertions, 225 deletions
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index be761ae427ac..c0f8b2f8ee9c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index f98bdbabc4c0..73061f79521a 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -31,6 +31,8 @@ def PromotableAllocationOpInterface
Promotion of the slot will lead to the slot pointer no longer being
used, leaving the content of the memory slot unreachable.
+
+ No IR mutation is allowed in this method.
}], "::llvm::SmallVector<::mlir::MemorySlot>", "getPromotableSlots",
(ins)
>,
@@ -38,34 +40,42 @@ def PromotableAllocationOpInterface
Provides the default Value of this memory slot. The provided Value
will be used as the reaching definition of loads done before any store.
This Value must outlive the promotion and dominate all the uses of this
- slot's pointer. The provided builder can be used to create the default
+ slot's pointer. The provided rewriter can be used to create the default
value on the fly.
- The builder is located at the beginning of the block where the slot
- pointer is defined.
+ The rewriter is located at the beginning of the block where the slot
+ pointer is defined. All IR mutations must happen through the rewriter.
}], "::mlir::Value", "getDefaultValue",
- (ins "const ::mlir::MemorySlot &":$slot, "::mlir::OpBuilder &":$builder)
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
Hook triggered for every new block argument added to a block.
This will only be called for slots declared by this operation.
- The builder is located at the beginning of the block on call.
+ The rewriter is located at the beginning of the block on call. All IR
+ mutations must happen through the rewriter.
}],
"void", "handleBlockArgument",
(ins
"const ::mlir::MemorySlot &":$slot,
"::mlir::BlockArgument":$argument,
- "::mlir::OpBuilder &":$builder
+ "::mlir::RewriterBase &":$rewriter
)
>,
InterfaceMethod<[{
Hook triggered once the promotion of a slot is complete. This can
also clean up the created default value if necessary.
This will only be called for slots declared by this operation.
+
+ All IR mutations must happen through the rewriter.
}],
"void", "handlePromotionComplete",
- (ins "const ::mlir::MemorySlot &":$slot, "::mlir::Value":$defaultValue)
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::Value":$defaultValue,
+ "::mlir::RewriterBase &":$rewriter)
>,
];
}
@@ -87,6 +97,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
let methods = [
InterfaceMethod<[{
Gets whether this operation loads from the specified slot.
+
+ No IR mutation is allowed in this method.
}],
"bool", "loadsFrom",
(ins "const ::mlir::MemorySlot &":$slot)
@@ -96,6 +108,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
value if this operation does not store to this slot. An operation
storing a value to a slot must always be able to provide the value it
stores. This method is only called on operations that use the slot.
+
+ No IR mutation is allowed in this method.
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot)
@@ -107,6 +121,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
If the removal procedure of the use will require that other uses get
removed, that dependency should be added to the `newBlockingUses`
argument. Dependent uses must only be uses of results of this operation.
+
+ No IR mutation is allowed in this method.
}], "bool", "canUsesBeRemoved",
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
@@ -132,13 +148,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
have been done at the point of calling this method, but it will be done
eventually.
- The builder is located after the promotable operation on call.
+ The rewriter is located after the promotable operation on call. All IR
+ mutations must happen through the rewriter.
}],
"::mlir::DeletionKind",
"removeBlockingUses",
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
- "::mlir::OpBuilder &":$builder,
+ "::mlir::RewriterBase &":$rewriter,
"::mlir::Value":$reachingDefinition)
>,
];
@@ -160,6 +177,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
If the removal procedure of the use will require that other uses get
removed, that dependency should be added to the `newBlockingUses`
argument. Dependent uses must only be uses of results of this operation.
+
+ No IR mutation is allowed in this method.
}], "bool", "canUsesBeRemoved",
(ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
"::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
@@ -185,12 +204,13 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
have been done at the point of calling this method, but it will be done
eventually.
- The builder is located after the promotable operation on call.
+ The rewriter is located after the promotable operation on call. All IR
+ mutations must happen through the rewriter.
}],
"::mlir::DeletionKind",
"removeBlockingUses",
(ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
- "::mlir::OpBuilder &":$builder)
+ "::mlir::RewriterBase &":$rewriter)
>,
];
}
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index a34ea68e750b..46b2a1f56d21 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -13,129 +13,39 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/Statistic.h"
namespace mlir {
-/// Information computed during promotion analysis used to perform actual
-/// promotion.
-struct MemorySlotPromotionInfo {
- /// Blocks for which at least two definitions of the slot values clash.
- SmallPtrSet<Block *, 8> mergePoints;
- /// Contains, for each operation, which uses must be eliminated by promotion.
- /// This is a DAG structure because if an operation must eliminate some of
- /// its uses, it is because the defining ops of the blocking uses requested
- /// it. The defining ops therefore must also have blocking uses or be the
- /// starting point of the bloccking uses.
- DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
-};
-
-/// Computes information for basic slot promotion. This will check that direct
-/// slot promotion can be performed, and provide the information to execute the
-/// promotion. This does not mutate IR.
-class MemorySlotPromotionAnalyzer {
-public:
- MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
- : slot(slot), dominance(dominance) {}
-
- /// Computes the information for slot promotion if promotion is possible,
- /// returns nothing otherwise.
- std::optional<MemorySlotPromotionInfo> computeInfo();
-
-private:
- /// Computes the transitive uses of the slot that block promotion. This finds
- /// uses that would block the promotion, checks that the operation has a
- /// solution to remove the blocking use, and potentially forwards the analysis
- /// if the operation needs further blocking uses resolved to resolve its own
- /// uses (typically, removing its users because it will delete itself to
- /// resolve its own blocking uses). This will fail if one of the transitive
- /// users cannot remove a requested use, and should prevent promotion.
- LogicalResult computeBlockingUses(
- DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
-
- /// Computes in which blocks the value stored in the slot is actually used,
- /// meaning blocks leading to a load. This method uses `definingBlocks`, the
- /// set of blocks containing a store to the slot (defining the value of the
- /// slot).
- SmallPtrSet<Block *, 16>
- computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
-
- /// Computes the points in which multiple re-definitions of the slot's value
- /// (stores) may conflict.
- void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
-
- /// Ensures predecessors of merge points can properly provide their current
- /// definition of the value stored in the slot to the merge point. This can
- /// notably be an issue if the terminator used does not have the ability to
- /// forward values through block operands.
- bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
-
- MemorySlot slot;
- DominanceInfo &dominance;
-};
-
-/// The MemorySlotPromoter handles the state of promoting a memory slot. It
-/// wraps a slot and its associated allocator. This will perform the mutation of
-/// IR.
-class MemorySlotPromoter {
-public:
- MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
- OpBuilder &builder, DominanceInfo &dominance,
- MemorySlotPromotionInfo info);
-
- /// Actually promotes the slot by mutating IR. Promoting a slot does not
- /// invalidate the MemorySlotPromotionInfo of other slots.
- void promoteSlot();
-
-private:
- /// Computes the reaching definition for all the operations that require
- /// promotion. `reachingDef` is the value the slot should contain at the
- /// beginning of the block. This method returns the reached definition at the
- /// end of the block.
- Value computeReachingDefInBlock(Block *block, Value reachingDef);
-
- /// Computes the reaching definition for all the operations that require
- /// promotion. `reachingDef` corresponds to the initial value the
- /// slot will contain before any write, typically a poison value.
- void computeReachingDefInRegion(Region *region, Value reachingDef);
-
- /// Removes the blocking uses of the slot, in topological order.
- void removeBlockingUses();
-
- /// Lazily-constructed default value representing the content of the slot when
- /// no store has been executed. This function may mutate IR.
- Value getLazyDefaultValue();
-
- MemorySlot slot;
- PromotableAllocationOpInterface allocator;
- OpBuilder &builder;
- /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
- /// initialize it on demand.
- Value defaultValue;
- /// Contains the reaching definition at this operation. Reaching definitions
- /// are only computed for promotable memory operations with blocking uses.
- DenseMap<PromotableMemOpInterface, Value> reachingDefs;
- DominanceInfo &dominance;
- MemorySlotPromotionInfo info;
+struct Mem2RegStatistics {
+ llvm::Statistic *promotedAmount = nullptr;
+ llvm::Statistic *newBlockArgumentAmount = nullptr;
};
/// Pattern applying mem2reg to the regions of the operations on which it
/// matches.
-class Mem2RegPattern : public RewritePattern {
+class Mem2RegPattern
+ : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
public:
- using RewritePattern::RewritePattern;
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {}
+ Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
- LogicalResult matchAndRewrite(Operation *op,
+ LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
PatternRewriter &rewriter) const override;
+
+private:
+ Mem2RegStatistics statistics;
};
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
/// at least one memory slot was promoted.
LogicalResult
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
- OpBuilder &builder, DominanceInfo &dominance);
+ RewriterBase &rewriter,
+ Mem2RegStatistics statistics = {});
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1cc357ca1f9f..62b8dd075f21 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -189,6 +189,21 @@ def Mem2Reg : Pass<"mem2reg"> {
This pass only supports unstructured control-flow. Promotion of operations
within subregions will not happen.
}];
+
+ let options = [
+ Option<"enableRegionSimplification", "region-simplify", "bool",
+ /*default=*/"true",
+ "Perform control flow optimizations to the region tree">,
+ ];
+
+ let statistics = [
+ Statistic<"promotedAmount",
+ "promoted slots",
+ "Number of promoted memory slot">,
+ Statistic<"newBlockArgumentAmount",
+ "new block args",
+ "Total number of block arguments added">,
+ ];
}
def PrintOpStats : Pass<"print-op-stats"> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index e4fd2a755d90..51c49892d71c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -35,24 +35,25 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
}
Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
- OpBuilder &builder) {
- return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+ RewriterBase &rewriter) {
+ return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
}
void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
BlockArgument argument,
- OpBuilder &builder) {
+ RewriterBase &rewriter) {
for (Operation *user : getOperation()->getUsers())
if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
- builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
- declareOp.getVarInfo());
+ rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+ declareOp.getVarInfo());
}
void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue) {
+ Value defaultValue,
+ RewriterBase &rewriter) {
if (defaultValue && defaultValue.use_empty())
- defaultValue.getDefiningOp()->erase();
- erase();
+ rewriter.eraseOp(defaultValue.getDefiningOp());
+ rewriter.eraseOp(*this);
}
//===----------------------------------------------------------------------===//
@@ -87,10 +88,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- OpBuilder &builder, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- getResult().replaceAllUsesWith(reachingDefinition);
+ rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
return DeletionKind::Delete;
}
@@ -110,13 +111,13 @@ bool LLVM::StoreOp::canUsesBeRemoved(
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- OpBuilder &builder, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the stored slot
// pointer.
for (Operation *user : slot.ptr.getUsers())
if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
- builder.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
- declareOp.getVarInfo());
+ rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
+ declareOp.getVarInfo());
return DeletionKind::Delete;
}
@@ -140,7 +141,7 @@ bool LLVM::BitcastOp::canUsesBeRemoved(
}
DeletionKind LLVM::BitcastOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
@@ -151,7 +152,7 @@ bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
}
DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
@@ -162,7 +163,7 @@ bool LLVM::LifetimeStartOp::canUsesBeRemoved(
}
DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
@@ -173,7 +174,7 @@ bool LLVM::LifetimeEndOp::canUsesBeRemoved(
}
DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
@@ -184,7 +185,7 @@ bool LLVM::DbgDeclareOp::canUsesBeRemoved(
}
DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
@@ -209,6 +210,6 @@ bool LLVM::GEPOp::canUsesBeRemoved(
}
DeletionKind LLVM::GEPOp::removeBlockingUses(
- const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+ const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
index b5f5272d6421..12d9ebd5a02a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
@@ -40,29 +40,30 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
}
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
- OpBuilder &builder) {
+ RewriterBase &rewriter) {
assert(isSupportedElementType(slot.elemType));
// TODO: support more types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](MemRefType t) {
- return builder.create<memref::AllocaOp>(getLoc(), t);
+ return rewriter.create<memref::AllocaOp>(getLoc(), t);
})
.Default([&](Type t) {
- return builder.create<arith::ConstantOp>(getLoc(), t,
- builder.getZeroAttr(t));
+ return rewriter.create<arith::ConstantOp>(getLoc(), t,
+ rewriter.getZeroAttr(t));
});
}
void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue) {
+ Value defaultValue,
+ RewriterBase &rewriter) {
if (defaultValue.use_empty())
- defaultValue.getDefiningOp()->erase();
- erase();
+ rewriter.eraseOp(defaultValue.getDefiningOp());
+ rewriter.eraseOp(*this);
}
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
BlockArgument argument,
- OpBuilder &builder) {}
+ RewriterBase &rewriter) {}
//===----------------------------------------------------------------------===//
// LoadOp/StoreOp interfaces
@@ -86,10 +87,10 @@ bool memref::LoadOp::canUsesBeRemoved(
DeletionKind memref::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- OpBuilder &builder, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- getResult().replaceAllUsesWith(reachingDefinition);
+ rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
return DeletionKind::Delete;
}
@@ -113,6 +114,6 @@ bool memref::StoreOp::canUsesBeRemoved(
DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- OpBuilder &builder, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 45d6f7d0c1ed..3b303f9836cf 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -10,6 +10,8 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -92,11 +94,121 @@ using namespace mlir;
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.
+namespace {
+
+/// Information computed during promotion analysis used to perform actual
+/// promotion.
+struct MemorySlotPromotionInfo {
+ /// Blocks for which at least two definitions of the slot values clash.
+ SmallPtrSet<Block *, 8> mergePoints;
+ /// Contains, for each operation, which uses must be eliminated by promotion.
+ /// This is a DAG structure because if an operation must eliminate some of
+ /// its uses, it is because the defining ops of the blocking uses requested
+ /// it. The defining ops therefore must also have blocking uses or be the
+ /// starting point of the bloccking uses.
+ DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+};
+
+/// Computes information for basic slot promotion. This will check that direct
+/// slot promotion can be performed, and provide the information to execute the
+/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer {
+public:
+ MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
+ : slot(slot), dominance(dominance) {}
+
+ /// Computes the information for slot promotion if promotion is possible,
+ /// returns nothing otherwise.
+ std::optional<MemorySlotPromotionInfo> computeInfo();
+
+private:
+ /// Computes the transitive uses of the slot that block promotion. This finds
+ /// uses that would block the promotion, checks that the operation has a
+ /// solution to remove the blocking use, and potentially forwards the analysis
+ /// if the operation needs further blocking uses resolved to resolve its own
+ /// uses (typically, removing its users because it will delete itself to
+ /// resolve its own blocking uses). This will fail if one of the transitive
+ /// users cannot remove a requested use, and should prevent promotion.
+ LogicalResult computeBlockingUses(
+ DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+
+ /// Computes in which blocks the value stored in the slot is actually used,
+ /// meaning blocks leading to a load. This method uses `definingBlocks`, the
+ /// set of blocks containing a store to the slot (defining the value of the
+ /// slot).
+ SmallPtrSet<Block *, 16>
+ computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
+
+ /// Computes the points in which multiple re-definitions of the slot's value
+ /// (stores) may conflict.
+ void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+
+ /// Ensures predecessors of merge points can properly provide their current
+ /// definition of the value stored in the slot to the merge point. This can
+ /// notably be an issue if the terminator used does not have the ability to
+ /// forward values through block operands.
+ bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
+
+ MemorySlot slot;
+ DominanceInfo &dominance;
+};
+
+/// The MemorySlotPromoter handles the state of promoting a memory slot. It
+/// wraps a slot and its associated allocator. This will perform the mutation of
+/// IR.
+class MemorySlotPromoter {
+public:
+ MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
+ RewriterBase &rewriter, DominanceInfo &dominance,
+ MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics);
+
+ /// Actually promotes the slot by mutating IR. Promoting a slot DOES
+ /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
+ /// promotion info should NOT be performed in batches.
+ void promoteSlot();
+
+private:
+ /// Computes the reaching definition for all the operations that require
+ /// promotion. `reachingDef` is the value the slot should contain at the
+ /// beginning of the block. This method returns the reached definition at the
+ /// end of the block.
+ Value computeReachingDefInBlock(Block *block, Value reachingDef);
+
+ /// Computes the reaching definition for all the operations that require
+ /// promotion. `reachingDef` corresponds to the initial value the
+ /// slot will contain before any write, typically a poison value.
+ void computeReachingDefInRegion(Region *region, Value reachingDef);
+
+ /// Removes the blocking uses of the slot, in topological order.
+ void removeBlockingUses();
+
+ /// Lazily-constructed default value representing the content of the slot when
+ /// no store has been executed. This function may mutate IR.
+ Value getLazyDefaultValue();
+
+ MemorySlot slot;
+ PromotableAllocationOpInterface allocator;
+ RewriterBase &rewriter;
+ /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
+ /// initialize it on demand.
+ Value defaultValue;
+ /// Contains the reaching definition at this operation. Reaching definitions
+ /// are only computed for promotable memory operations with blocking uses.
+ DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+ DominanceInfo &dominance;
+ MemorySlotPromotionInfo info;
+ const Mem2RegStatistics &statistics;
+};
+
+} // namespace
+
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
- OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info)
- : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
- info(std::move(info)) {
+ RewriterBase &rewriter, DominanceInfo &dominance,
+ MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ : slot(slot), allocator(allocator), rewriter(rewriter),
+ dominance(dominance), info(std::move(info)), statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -114,9 +226,9 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
if (defaultValue)
return defaultValue;
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(slot.ptr.getParentBlock());
- return defaultValue = allocator.getDefaultValue(slot, builder);
+ RewriterBase::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+ return defaultValue = allocator.getDefaultValue(slot, rewriter);
}
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
@@ -341,11 +453,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
Block *block = job.block->getBlock();
if (info.mergePoints.contains(block)) {
- BlockArgument blockArgument =
- block->addArgument(slot.elemType, slot.ptr.getLoc());
- builder.setInsertionPointToStart(block);
- allocator.handleBlockArgument(slot, blockArgument, builder);
+ // If the block is a merge point, we need to add a block argument to hold
+ // the selected reaching definition. This has to be a bit complicated
+ // because of RewriterBase limitations: we need to create a new block with
+ // the extra block argument, move the content of the block to the new
+ // block, and replace the block with the new block in the merge point set.
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ for (BlockArgument arg : block->getArguments()) {
+ argTypes.push_back(arg.getType());
+ argLocs.push_back(arg.getLoc());
+ }
+ argTypes.push_back(slot.elemType);
+ argLocs.push_back(slot.ptr.getLoc());
+ Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);
+
+ info.mergePoints.erase(block);
+ info.mergePoints.insert(newBlock);
+
+ rewriter.replaceAllUsesWith(block, newBlock);
+ rewriter.mergeBlocks(block, newBlock,
+ newBlock->getArguments().drop_back());
+
+ block = newBlock;
+
+ BlockArgument blockArgument = block->getArguments().back();
+ rewriter.setInsertionPointToStart(block);
+ allocator.handleBlockArgument(slot, blockArgument, rewriter);
job.reachingDef = blockArgument;
+
+ if (statistics.newBlockArgumentAmount)
+ (*statistics.newBlockArgumentAmount)++;
}
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
@@ -355,8 +493,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
if (info.mergePoints.contains(blockOperand.get())) {
if (!job.reachingDef)
job.reachingDef = getLazyDefaultValue();
- terminator.getSuccessorOperands(blockOperand.getOperandNumber())
- .append(job.reachingDef);
+ rewriter.updateRootInPlace(terminator, [&]() {
+ terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+ .append(job.reachingDef);
+ });
}
}
}
@@ -382,24 +522,24 @@ void MemorySlotPromoter::removeBlockingUses() {
if (!reachingDef)
reachingDef = getLazyDefaultValue();
- builder.setInsertionPointAfter(toPromote);
+ rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], builder, reachingDef) ==
- DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter,
+ reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);
continue;
}
auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
- builder.setInsertionPointAfter(toPromote);
+ rewriter.setInsertionPointAfter(toPromote);
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
- builder) == DeletionKind::Delete)
+ rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
}
for (Operation *toEraseOp : toErase)
- toEraseOp->erase();
+ rewriter.eraseOp(toEraseOp);
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
@@ -421,87 +561,73 @@ void MemorySlotPromoter::promoteSlot() {
assert(succOperands.size() == mergePoint->getNumArguments() ||
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
- succOperands.append(getLazyDefaultValue());
+ rewriter.updateRootInPlace(
+ user, [&]() { succOperands.append(getLazyDefaultValue()); });
}
}
LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
<< "\n");
- allocator.handlePromotionComplete(slot, defaultValue);
+ if (statistics.promotedAmount)
+ (*statistics.promotedAmount)++;
+
+ allocator.handlePromotionComplete(slot, defaultValue, rewriter);
}
LogicalResult mlir::tryToPromoteMemorySlots(
- ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
- DominanceInfo &dominance) {
- // Actual promotion may invalidate the dominance analysis, so slot promotion
- // is prepated in batches.
- SmallVector<MemorySlotPromoter> toPromote;
+ ArrayRef<PromotableAllocationOpInterface> allocators,
+ RewriterBase &rewriter, Mem2RegStatistics statistics) {
+ DominanceInfo dominance;
+
+ bool promotedAny = false;
+
for (PromotableAllocationOpInterface allocator : allocators) {
for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
continue;
+ DominanceInfo dominance;
MemorySlotPromotionAnalyzer analyzer(slot, dominance);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
- if (info)
- toPromote.emplace_back(slot, allocator, builder, dominance,
- std::move(*info));
+ if (info) {
+ MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ std::move(*info), statistics)
+ .promoteSlot();
+ promotedAny = true;
+ }
}
}
- for (MemorySlotPromoter &promoter : toPromote)
- promoter.promoteSlot();
-
- return success(!toPromote.empty());
+ return success(promotedAny);
}
-LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
+LogicalResult
+Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
+ PatternRewriter &rewriter) const {
hasBoundedRewriteRecursion();
-
- if (op->getNumRegions() == 0)
- return failure();
-
- DominanceInfo dominance;
-
- SmallVector<PromotableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to promote the slots of.
- for (Region &region : op->getRegions())
- for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
- allocators.emplace_back(allocator);
-
- // Because pattern rewriters are normally not expressive enough to support a
- // transformation like mem2reg, this uses an escape hatch to mark modified
- // operations manually and operate outside of its context.
- rewriter.startRootUpdate(op);
-
- OpBuilder builder(rewriter.getContext());
-
- if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
- rewriter.cancelRootUpdate(op);
- return failure();
- }
-
- rewriter.finalizeRootUpdate(op);
- return success();
+ return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
}
namespace {
struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+ using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
+
void runOnOperation() override {
Operation *scopeOp = getOperation();
- bool changed = false;
+
+ Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = enableRegionSimplification;
RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<Mem2RegPattern>(&getContext());
+ rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
FrozenRewritePatternSet frozen(std::move(rewritePatterns));
- (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
- &changed);
- if (!changed)
- markAllAnalysesPreserved();
+ if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
+ signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
index d8d04dfcfec5..0c1908ec8fdc 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg{region-simplify=false}))' | FileCheck %s
llvm.func @use(i64)
llvm.func @use_ptr(!llvm.ptr)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 090f9133f7a9..fc696c5073c3 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @default_value
llvm.func @default_value() -> i32 {
diff --git a/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir
new file mode 100644
index 000000000000..29ca51194ffd
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file --mlir-pass-statistics 2>&1 >/dev/null | FileCheck %s
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @basic() -> i32 {
+ %0 = arith.constant 5 : i32
+ %1 = memref.alloca() : memref<i32>
+ memref.store %0, %1[] : memref<i32>
+ %2 = memref.load %1[] : memref<i32>
+ return %2 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 0 promoted slots
+func.func @no_alloca() -> i32 {
+ %0 = arith.constant 5 : i32
+ return %0 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 2 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+ %alloca = memref.alloca() : memref<i64>
+ memref.store %arg2, %alloca[] : memref<i64>
+ cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ %use = memref.load %alloca[] : memref<i64>
+ call @use(%use) : (i64) -> ()
+ memref.store %arg0, %alloca[] : memref<i64>
+ cf.br ^bb2
+^bb2:
+ cf.br ^bb1
+}
+
+func.func @use(%arg: i64) { return }
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 3 promoted slots
+func.func @recursive(%arg: i64) -> i64 {
+ %alloca0 = memref.alloca() : memref<memref<memref<i64>>>
+ %alloca1 = memref.alloca() : memref<memref<i64>>
+ %alloca2 = memref.alloca() : memref<i64>
+ memref.store %arg, %alloca2[] : memref<i64>
+ memref.store %alloca2, %alloca1[] : memref<memref<i64>>
+ memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
+ %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
+ %load1 = memref.load %load0[] : memref<memref<i64>>
+ %load2 = memref.load %load1[] : memref<i64>
+ return %load2 : i64
+}
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index 86707ac0b497..d300699f6f34 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg{region-simplify=false}))' --split-input-file | FileCheck %s
// CHECK-LABEL: func.func @basic
func.func @basic() -> i32 {
@@ -148,20 +148,18 @@ func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
// CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
func.func @promotable_nonpromotable_intertwined() -> i32 {
- // CHECK: %[[VAL:.*]] = arith.constant 5 : i32
- %0 = arith.constant 5 : i32
// CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
- %1 = memref.alloca() : memref<i32>
+ %0 = memref.alloca() : memref<i32>
// CHECK-NOT: = memref.alloca() : memref<memref<i32>>
- %2 = memref.alloca() : memref<memref<i32>>
- memref.store %1, %2[] : memref<memref<i32>>
- %3 = memref.load %2[] : memref<memref<i32>>
+ %1 = memref.alloca() : memref<memref<i32>>
+ memref.store %0, %1[] : memref<memref<i32>>
+ %2 = memref.load %1[] : memref<memref<i32>>
// CHECK: call @use(%[[NON_PROMOTED]])
- call @use(%1) : (memref<i32>) -> ()
+ call @use(%0) : (memref<i32>) -> ()
// CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
- %4 = memref.load %1[] : memref<i32>
+ %3 = memref.load %0[] : memref<i32>
// CHECK: return %[[RES]] : i32
- return %4 : i32
+ return %3 : i32
}
func.func @use(%arg: memref<i32>) { return }