summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-05-15 15:39:35 +0200
committerMatthias Springer <me@m-sp.org>2023-05-15 15:42:56 +0200
commitae8cb6437294ca99ba203607c0dd522db4dbf6b6 (patch)
tree1fd5afb66bcab6eef67c4ffbe00b085ee4b167f5 /mlir
parentbb9d1b551a4407660293c2bf3f2343ba70ed8e68 (diff)
downloadllvm-ae8cb6437294ca99ba203607c0dd522db4dbf6b6.tar.gz
[mlir][scf][bufferize] Fix bug in WhileOp analysis verification
Block arguments and yielded values are not equivalent if there are not enough block arguments. This fixes #59442. Differential Revision: https://reviews.llvm.org/D145575
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp12
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir14
2 files changed, 22 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index ad395a9ac457..4b0d0e40740f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -902,10 +902,12 @@ struct WhileOpInterface
auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+ Block *block = conditionOp->getBlock();
if (!isa<TensorType>(it.value().getType()))
continue;
- if (!state.areEquivalentBufferizedValues(
- it.value(), conditionOp->getBlock()->getArgument(it.index())))
+ if (it.index() >= block->getNumArguments() ||
+ !state.areEquivalentBufferizedValues(it.value(),
+ block->getArgument(it.index())))
return conditionOp->emitError()
<< "Condition arg #" << it.index()
<< " is not equivalent to the corresponding iter bbArg";
@@ -913,10 +915,12 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+ Block *block = yieldOp->getBlock();
if (!isa<TensorType>(it.value().getType()))
continue;
- if (!state.areEquivalentBufferizedValues(
- it.value(), yieldOp->getBlock()->getArgument(it.index())))
+ if (it.index() >= block->getNumArguments() ||
+ !state.areEquivalentBufferizedValues(it.value(),
+ block->getArgument(it.index())))
return yieldOp->emitError()
<< "Yield operand #" << it.index()
<< " is not equivalent to the corresponding iter bbArg";
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 189ef6be0dff..a2d47f08adba 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -324,3 +324,17 @@ func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {
// This function may write to buffer(%ptr).
func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
+
+// -----
+
+func.func @regression_scf_while() {
+ %false = arith.constant false
+ %8 = bufferization.alloc_tensor() : tensor<10x10xf32>
+ scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () {
+ scf.condition(%false)
+ } do {
+ // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+ scf.yield %8 : tensor<10x10xf32>
+ }
+ return
+}