diff options
author | Matthias Springer <me@m-sp.org> | 2023-05-15 15:39:35 +0200 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2023-05-15 15:42:56 +0200 |
commit | ae8cb6437294ca99ba203607c0dd522db4dbf6b6 (patch) | |
tree | 1fd5afb66bcab6eef67c4ffbe00b085ee4b167f5 /mlir | |
parent | bb9d1b551a4407660293c2bf3f2343ba70ed8e68 (diff) | |
download | llvm-ae8cb6437294ca99ba203607c0dd522db4dbf6b6.tar.gz |
[mlir][scf][bufferize] Fix bug in WhileOp analysis verification
Block arguments and yielded values are not equivalent if there are not enough block arguments. This fixes #59442.
Differential Revision: https://reviews.llvm.org/D145575
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 12 | ||||
-rw-r--r-- | mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir | 14 |
2 files changed, 22 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index ad395a9ac457..4b0d0e40740f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -902,10 +902,12 @@ struct WhileOpInterface auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { + Block *block = conditionOp->getBlock(); if (!isa<TensorType>(it.value().getType())) continue; - if (!state.areEquivalentBufferizedValues( - it.value(), conditionOp->getBlock()->getArgument(it.index()))) + if (it.index() >= block->getNumArguments() || + !state.areEquivalentBufferizedValues(it.value(), + block->getArgument(it.index()))) return conditionOp->emitError() << "Condition arg #" << it.index() << " is not equivalent to the corresponding iter bbArg"; @@ -913,10 +915,12 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { + Block *block = yieldOp->getBlock(); if (!isa<TensorType>(it.value().getType())) continue; - if (!state.areEquivalentBufferizedValues( - it.value(), yieldOp->getBlock()->getArgument(it.index()))) + if (it.index() >= block->getNumArguments() || + !state.areEquivalentBufferizedValues(it.value(), + block->getArgument(it.index()))) return yieldOp->emitError() << "Yield operand #" << it.index() << " is not equivalent to the corresponding iter bbArg"; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir index 189ef6be0dff..a2d47f08adba 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -324,3 +324,17 @@ func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> { // This function may write to buffer(%ptr). func.func private @maybe_writing_func(%ptr : tensor<*xf32>) + +// ----- + +func.func @regression_scf_while() { + %false = arith.constant false + %8 = bufferization.alloc_tensor() : tensor<10x10xf32> + scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () { + scf.condition(%false) + } do { + // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}} + scf.yield %8 : tensor<10x10xf32> + } + return +} |