summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-05-15 15:26:13 +0200
committerMatthias Springer <me@m-sp.org>2023-05-15 15:31:56 +0200
commit1f479c1e46d111a6f001cf4ee24290f60f13257d (patch)
treecf07fbf5c61a63cdd5c9896abc0ed138be8b657a /mlir/lib
parent03b97e0c07fdcb3eac8580a60a4a3f88cee668b1 (diff)
downloadllvm-1f479c1e46d111a6f001cf4ee24290f60f13257d.tar.gz
[mlir][bufferization] Improve findValueInReverseUseDefChain signature
Instead of passing traversal options as a long list of arguments, store them in a TraversalConfig object and pass that object. Differential Revision: https://reviews.llvm.org/D143927
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp53
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp5
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp7
3 files changed, 40 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index cb69a9e5879c..712693ddd53a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -367,10 +367,7 @@ BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
- if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
- if (isOpAllowed(bufferizableOp.getOperation()))
- return bufferizableOp;
- return nullptr;
+ return dynCastBufferizableOp(getOwnerOfValue(value));
}
void BufferizationOptions::setFunctionBoundaryTypeConversion(
@@ -500,7 +497,7 @@ bool AnalysisState::isValueRead(Value value) const {
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly, bool alwaysIncludeLeaves) const {
+ TraversalConfig config) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -512,7 +509,7 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
}
if (llvm::isa<BlockArgument>(value)) {
- if (alwaysIncludeLeaves)
+ if (config.alwaysIncludeLeaves)
result.insert(value);
continue;
}
@@ -520,26 +517,43 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
OpResult opResult = llvm::cast<OpResult>(value);
BufferizableOpInterface bufferizableOp =
options.dynCastBufferizableOp(opResult.getDefiningOp());
- AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+ if (!config.followUnknownOps && !bufferizableOp) {
+ // Stop iterating if `followUnknownOps` is unset and the op is either
+ // not bufferizable or excluded in the OpFilter.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
- // Stop iterating in either one of these cases:
- // * The current op is not bufferizable or excluded in the filter.
- // * There are no OpOperands to follow.
- if (!bufferizableOp || aliases.getNumAliases() == 0) {
- if (alwaysIncludeLeaves)
+ AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+ if (aliases.getNumAliases() == 0) {
+ // The traversal ends naturally if there are no more OpOperands that
+ // could be followed.
+ if (config.alwaysIncludeLeaves)
result.insert(value);
continue;
}
for (AliasingOpOperand a : aliases) {
- if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) {
+ if (config.followEquivalentOnly &&
+ a.relation != BufferRelation::Equivalent) {
// Stop iterating if `followEquivalentOnly` is set but the alias is not
// equivalent.
- if (alwaysIncludeLeaves)
+ if (config.alwaysIncludeLeaves)
result.insert(value);
} else {
workingSet.insert(a.opOperand->get());
}
+
+ if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
+ // Stop iterating if `followInPlaceOnly` is set but the alias is
+ // out-of-place.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
+
+ workingSet.insert(a.opOperand->get());
}
}
@@ -548,9 +562,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
// Find the values that define the contents of the given value.
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+ TraversalConfig config;
+ config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
- /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+ value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -927,12 +942,12 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
return false;
return state.bufferizesToMemoryWrite(v);
};
+ TraversalConfig config;
+ config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
.findValueInReverseUseDefChain(alias.opOperand->get(),
- isMemoryWriteInsideOp,
- /*followEquivalentOnly=*/false,
- /*alwaysIncludeLeaves=*/false)
+ isMemoryWriteInsideOp, config)
.empty())
return true;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 58475d225ce8..76d424867af6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -132,10 +132,13 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
+ TraversalConfig config;
+ config.followEquivalentOnly = true;
+ config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
operand.get(), /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
- /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false);
+ config);
for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6da512699cc7..a9f05b21282d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -775,11 +775,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Find the values that define the contents of the given value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(Value value) {
- if (!cachedDefinitions.count(value)) {
- cachedDefinitions[value] = findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
- /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
- }
+ if (!cachedDefinitions.count(value))
+ cachedDefinitions[value] = findDefinitions(value);
return cachedDefinitions[value];
}