diff options
author | Matthias Springer <me@m-sp.org> | 2023-05-15 15:26:13 +0200 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2023-05-15 15:31:56 +0200 |
commit | 1f479c1e46d111a6f001cf4ee24290f60f13257d (patch) | |
tree | cf07fbf5c61a63cdd5c9896abc0ed138be8b657a /mlir/lib | |
parent | 03b97e0c07fdcb3eac8580a60a4a3f88cee668b1 (diff) | |
download | llvm-1f479c1e46d111a6f001cf4ee24290f60f13257d.tar.gz |
[mlir][bufferization] Improve findValueInReverseUseDefChain signature
Instead of passing traversal options as a long list of arguments, store them in a TraversalConfig object and pass that object.
Differential Revision: https://reviews.llvm.org/D143927
Diffstat (limited to 'mlir/lib')
3 files changed, 40 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index cb69a9e5879c..712693ddd53a 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -367,10 +367,7 @@ BufferizationOptions::dynCastBufferizableOp(Operation *op) const { BufferizableOpInterface BufferizationOptions::dynCastBufferizableOp(Value value) const { - if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>()) - if (isOpAllowed(bufferizableOp.getOperation())) - return bufferizableOp; - return nullptr; + return dynCastBufferizableOp(getOwnerOfValue(value)); } void BufferizationOptions::setFunctionBoundaryTypeConversion( @@ -500,7 +497,7 @@ bool AnalysisState::isValueRead(Value value) const { // further. llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref<bool(Value)> condition, - bool followEquivalentOnly, bool alwaysIncludeLeaves) const { + TraversalConfig config) const { llvm::SetVector<Value> result, workingSet; workingSet.insert(value); @@ -512,7 +509,7 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( } if (llvm::isa<BlockArgument>(value)) { - if (alwaysIncludeLeaves) + if (config.alwaysIncludeLeaves) result.insert(value); continue; } @@ -520,26 +517,43 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( OpResult opResult = llvm::cast<OpResult>(value); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); - AliasingOpOperandList aliases = getAliasingOpOperands(opResult); + if (!config.followUnknownOps && !bufferizableOp) { + // Stop iterating if `followUnknownOps` is unset and the op is either + // not bufferizable or excluded in the OpFilter. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } - // Stop iterating in either one of these cases: - // * The current op is not bufferizable or excluded in the filter. - // * There are no OpOperands to follow. - if (!bufferizableOp || aliases.getNumAliases() == 0) { - if (alwaysIncludeLeaves) + AliasingOpOperandList aliases = getAliasingOpOperands(opResult); + if (aliases.getNumAliases() == 0) { + // The traversal ends naturally if there are no more OpOperands that + // could be followed. + if (config.alwaysIncludeLeaves) result.insert(value); continue; } for (AliasingOpOperand a : aliases) { - if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) { + if (config.followEquivalentOnly && + a.relation != BufferRelation::Equivalent) { // Stop iterating if `followEquivalentOnly` is set but the alias is not // equivalent. - if (alwaysIncludeLeaves) + if (config.alwaysIncludeLeaves) result.insert(value); } else { workingSet.insert(a.opOperand->get()); } + + if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) { + // Stop iterating if `followInPlaceOnly` is set but the alias is + // out-of-place. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } + + workingSet.insert(a.opOperand->get()); } } @@ -548,9 +562,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( // Find the values that define the contents of the given value. llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const { + TraversalConfig config; + config.alwaysIncludeLeaves = false; return findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, - /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false); + value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config); } AnalysisState::AnalysisState(const BufferizationOptions &options) @@ -927,12 +942,12 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite( return false; return state.bufferizesToMemoryWrite(v); }; + TraversalConfig config; + config.alwaysIncludeLeaves = false; for (AliasingOpOperand alias : opOperands) { if (!state .findValueInReverseUseDefChain(alias.opOperand->get(), - isMemoryWriteInsideOp, - /*followEquivalentOnly=*/false, - /*alwaysIncludeLeaves=*/false) + isMemoryWriteInsideOp, config) .empty()) return true; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 58475d225ce8..76d424867af6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -132,10 +132,13 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. + TraversalConfig config; + config.followEquivalentOnly = true; + config.alwaysIncludeLeaves = false; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( operand.get(), /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, - /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false); + config); for (Value v : emptyTensors) { Operation *emptyTensorOp = v.getDefiningOp(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 6da512699cc7..a9f05b21282d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -775,11 +775,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand, // Find the values that define the contents of the given value. const llvm::SetVector<Value> & OneShotAnalysisState::findDefinitionsCached(Value value) { - if (!cachedDefinitions.count(value)) { - cachedDefinitions[value] = findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, - /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false); - } + if (!cachedDefinitions.count(value)) + cachedDefinitions[value] = findDefinitions(value); return cachedDefinitions[value]; } |