summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-05-15 14:31:26 +0200
committerMatthias Springer <me@m-sp.org>2023-05-15 14:33:06 +0200
commit38bef476552021b7ad45d1aa989d250bcd0a38ff (patch)
treecfe15d118c24f8ecd9a875d919fdb46092936ee3 /mlir
parentd0e89116aff224ac2d8d3f88029ae44e12c9b6cc (diff)
downloadllvm-38bef476552021b7ad45d1aa989d250bcd0a38ff.tar.gz
[mlir][bufferization] Fix unknown ops in BufferViewFlowAnalysis
If an op is unknown to the analysis, it must be treated conservatively: assume that every operand aliases with every result. Differential Revision: https://reviews.llvm.org/D150546
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp134
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir21
2 files changed, 95 insertions, 60 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index b4cfe89d7ced..d964f801668f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,7 +8,6 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
@@ -58,74 +57,89 @@ void BufferViewFlowAnalysis::build(Operation *op) {
this->dependencies[value].insert(dep);
};
- // Add additional dependencies created by view changes to the alias list.
- op->walk([&](ViewLikeOpInterface viewInterface) {
- dependencies[viewInterface.getViewSource()].insert(
- viewInterface->getResult(0));
- });
+ op->walk([&](Operation *op) {
+ // TODO: We should have an op interface instead of a hard-coded list of
+ // interfaces/ops.
- // Query all branch interfaces to link block argument dependencies.
- op->walk([&](BranchOpInterface branchInterface) {
- Block *parentBlock = branchInterface->getBlock();
- for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
- it != e; ++it) {
- // Query the branch op interface to get the successor operands.
- auto successorOperands =
- branchInterface.getSuccessorOperands(it.getIndex());
- // Build the actual mapping of values to their immediate dependencies.
- registerDependencies(successorOperands.getForwardedOperands(),
- (*it)->getArguments().drop_front(
- successorOperands.getProducedOperandCount()));
+ // Add additional dependencies created by view changes to the alias list.
+ if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
+ dependencies[viewInterface.getViewSource()].insert(
+ viewInterface->getResult(0));
+ return WalkResult::advance();
}
- });
- // Query the RegionBranchOpInterface to find potential successor regions.
- op->walk([&](RegionBranchOpInterface regionInterface) {
- // Extract all entry regions and wire all initial entry successor inputs.
- SmallVector<RegionSuccessor, 2> entrySuccessors;
- regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
- entrySuccessors);
- for (RegionSuccessor &entrySuccessor : entrySuccessors) {
- // Wire the entry region's successor arguments with the initial
- // successor inputs.
- assert(entrySuccessor.getSuccessor() &&
- "Invalid entry region without an attached successor region");
- registerDependencies(
- regionInterface.getSuccessorEntryOperands(
- entrySuccessor.getSuccessor()->getRegionNumber()),
- entrySuccessor.getSuccessorInputs());
+ if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
+ // Query all branch interfaces to link block argument dependencies.
+ Block *parentBlock = branchInterface->getBlock();
+ for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
+ it != e; ++it) {
+ // Query the branch op interface to get the successor operands.
+ auto successorOperands =
+ branchInterface.getSuccessorOperands(it.getIndex());
+ // Build the actual mapping of values to their immediate dependencies.
+ registerDependencies(successorOperands.getForwardedOperands(),
+ (*it)->getArguments().drop_front(
+ successorOperands.getProducedOperandCount()));
+ }
+ return WalkResult::advance();
}
- // Wire flow between regions and from region exits.
- for (Region &region : regionInterface->getRegions()) {
- // Iterate over all successor region entries that are reachable from the
- // current region.
- SmallVector<RegionSuccessor, 2> successorRegions;
- regionInterface.getSuccessorRegions(region.getRegionNumber(),
- successorRegions);
- for (RegionSuccessor &successorRegion : successorRegions) {
- // Determine the current region index (if any).
- std::optional<unsigned> regionIndex;
- Region *regionSuccessor = successorRegion.getSuccessor();
- if (regionSuccessor)
- regionIndex = regionSuccessor->getRegionNumber();
- // Iterate over all immediate terminator operations and wire the
- // successor inputs with the successor operands of each terminator.
- for (Block &block : region) {
- auto successorOperands = getRegionBranchSuccessorOperands(
- block.getTerminator(), regionIndex);
- if (successorOperands) {
- registerDependencies(*successorOperands,
- successorRegion.getSuccessorInputs());
+ if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
+ // Query the RegionBranchOpInterface to find potential successor regions.
+ // Extract all entry regions and wire all initial entry successor inputs.
+ SmallVector<RegionSuccessor, 2> entrySuccessors;
+ regionInterface.getSuccessorRegions(/*index=*/std::nullopt,
+ entrySuccessors);
+ for (RegionSuccessor &entrySuccessor : entrySuccessors) {
+ // Wire the entry region's successor arguments with the initial
+ // successor inputs.
+ assert(entrySuccessor.getSuccessor() &&
+ "Invalid entry region without an attached successor region");
+ registerDependencies(
+ regionInterface.getSuccessorEntryOperands(
+ entrySuccessor.getSuccessor()->getRegionNumber()),
+ entrySuccessor.getSuccessorInputs());
+ }
+
+ // Wire flow between regions and from region exits.
+ for (Region &region : regionInterface->getRegions()) {
+ // Iterate over all successor region entries that are reachable from the
+ // current region.
+ SmallVector<RegionSuccessor, 2> successorRegions;
+ regionInterface.getSuccessorRegions(region.getRegionNumber(),
+ successorRegions);
+ for (RegionSuccessor &successorRegion : successorRegions) {
+ // Determine the current region index (if any).
+ std::optional<unsigned> regionIndex;
+ Region *regionSuccessor = successorRegion.getSuccessor();
+ if (regionSuccessor)
+ regionIndex = regionSuccessor->getRegionNumber();
+ // Iterate over all immediate terminator operations and wire the
+ // successor inputs with the successor operands of each terminator.
+ for (Block &block : region) {
+ auto successorOperands = getRegionBranchSuccessorOperands(
+ block.getTerminator(), regionIndex);
+ if (successorOperands) {
+ registerDependencies(*successorOperands,
+ successorRegion.getSuccessorInputs());
+ }
}
}
}
+
+ return WalkResult::advance();
}
- });
- // TODO: This should be an interface.
- op->walk([&](arith::SelectOp selectOp) {
- registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
- registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
+ // Unknown op: Assume that all operands alias with all results.
+ for (Value operand : op->getOperands()) {
+ if (!isa<BaseMemRefType>(operand.getType()))
+ continue;
+ for (Value result : op->getResults()) {
+ if (!isa<BaseMemRefType>(result.getType()))
+ continue;
+ registerDependencies({operand}, {result});
+ }
+ }
+ return WalkResult::advance();
});
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
index 384657222725..3fbe3913c654 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1317,6 +1317,27 @@ func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
// -----
+func.func @f(%arg0: memref<f64>) -> memref<f64> {
+ return %arg0 : memref<f64>
+}
+
+// CHECK-LABEL: func @function_call
+// CHECK: memref.alloc
+// CHECK: memref.alloc
+// CHECK: call
+// CHECK: test.copy
+// CHECK: memref.dealloc
+// CHECK: memref.dealloc
+func.func @function_call() {
+ %alloc = memref.alloc() : memref<f64>
+ %alloc2 = memref.alloc() : memref<f64>
+ %ret = call @f(%alloc) : (memref<f64>) -> memref<f64>
+ test.copy(%ret, %alloc2) : (memref<f64>, memref<f64>)
+ return
+}
+
+// -----
+
// Memref allocated in `then` region and passed back to the parent if op.
#set = affine_set<() : (0 >= 0)>
// CHECK-LABEL: func @test_affine_if_1