summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2021-11-18 16:10:10 +0900
committerMatthias Springer <springerm@google.com>2021-11-18 16:11:24 +0900
commit26e90423f4b81de7d4a6011134308c3e454964c0 (patch)
treef04636ab594f811078d0fb229a8c17a2399770aa
parent0c7890c844fdc7adb6d0cf58403e3fdd7407915d (diff)
downloadllvm-26e90423f4b81de7d4a6011134308c3e454964c0.tar.gz
[mlir][linalg][bufferize][NFC] Decouple ComprehensiveBufferize from tensor dialect
Add a new BufferizableOpInterface method `isNotConflicting` that can be used to implement custom analysis rules. Differential Revision: https://reviews.llvm.org/D113961
-rw-r--r--mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td23
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp224
2 files changed, 147 insertions, 100 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 4742b51623e1..757eca50eb5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -215,6 +215,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
return false;
}]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return `true` if the `uRead` and `uWrite` do not constitute a RaW
+ conflict. If they are conflicting or if it is unknown whether they are
+ conflicting, return `false`. This method will never be called with
+ OpOperands that do not have a tensor type. At least one of the two
+ given OpOperands belongs to this operation.
+
+ This method can be implemented to specify custom RaW analysis rules.
+ If this method returns `true` the given OpOperands are not considered
+ to be conflicting and do not force out-of-place bufferization. (There
+ may still be other conflicts that do.)
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isNotConflicting",
+ /*args=*/(ins "OpOperand *":$uRead,
+ "OpOperand *":$uWrite,
+ "const BufferizationAliasInfo &":$aliasInfo),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
>
];
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 697b894f8990..96fc066e7553 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -281,24 +281,6 @@ static std::string printValueInfo(Value value, bool prefix) {
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
- ExtractSliceOp st, InsertSliceOp sti) {
- if (!st || !sti)
- return false;
- if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
- return false;
- if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo) {
@@ -368,21 +350,6 @@ static bool aliasesInPlaceWrite(Value value,
return foundInplaceWrite;
}
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
- Value value, InsertSliceOp insertOp) {
- auto condition = [&](Value val) {
- if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
- return true;
- return false;
- };
-
- return llvm::all_of(findValueInReverseUseDefChain(value, condition),
- condition);
-}
-
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
@@ -450,6 +417,21 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (uConflictingWrite == uRead)
continue;
+ // No conflict if the op interface says so.
+ if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
+ if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+ aliasInfo))
+ continue;
+
+ if (conflictingWritingOp != readingOp)
+ if (auto bufferizableOp =
+ dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
+ if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+ aliasInfo))
+ continue;
+
+ // Special rules for branches.
+ // TODO: Use an interface.
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
continue;
@@ -478,73 +460,6 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- continue;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- continue;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // lastWrite = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- aliasInfo.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.source()) &&
- hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
- insertSliceOp))
- continue;
-
// All requirements are met. Conflict found!
LDBG("CONFLICT CONFIRMED!\n\n");
return true;
@@ -2321,6 +2236,24 @@ struct ExtractOpInterface
}
};
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduce to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+ ExtractSliceOp st, InsertSliceOp sti) {
+ if (!st || !sti)
+ return false;
+ if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+ return false;
+ if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+ return false;
+ return true;
+}
+
/// Return true if the source of a `insertSliceOp` bufferizes to an
/// equivalent ExtractSliceOp that bufferizes inplace.
static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
@@ -2345,6 +2278,21 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
return foundOp;
}
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+ Value value, InsertSliceOp insertOp) {
+ auto condition = [&](Value val) {
+ if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+ if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+ return true;
+ return false;
+ };
+
+ return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+ condition);
+}
+
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
@@ -2371,6 +2319,82 @@ struct InsertSliceOpInterface
return BufferRelation::Equivalent;
}
+ bool isNotConflicting(Operation *op, OpOperand *uRead,
+ OpOperand *uConflictingWrite,
+ const BufferizationAliasInfo &aliasInfo) const {
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+ // uRead is an InsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+
+ // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+ if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+ insertSliceOp))
+ // Case 1: The main insight is that InsertSliceOp reads only part of
+ // the destination tensor. The overwritten area is not read. If
+ // uConflictingWrite writes into exactly the memory location that is
+ // being read by uRead, this is not a conflict.
+ //
+ // In the above example:
+ // uRead = OpOperand 1 (%t) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+ //
+ // The read of %t does not conflict with the write of the FillOp
+ // (same aliases!) because the area that the FillOp operates on is
+ // exactly the one that is *not* read via %t.
+ return true;
+
+ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+ uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+ // Case 2: The read of the source tensor and the write to the dest
+ // tensor via an InsertSliceOp is not a conflict if the read is
+ // reading exactly that part of an equivalent tensor that the
+ // InsertSliceOp is writing.
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ return true;
+ }
+
+ // If uConflictingWrite is an InsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+ // %3 = vector.transfer_read %1, %cst
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of vector.transfer_read
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ // lastWrite = %1
+ //
+ // This is not a conflict because the InsertSliceOp overwrites the
+ // memory segment of %1 with the exact same data. (Effectively, there
+ // is no memory write here.)
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+ insertSliceOp.source()) &&
+ hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+ insertSliceOp))
+ return true;
+
+ return false;
+ }
+
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is