From ae8cb6437294ca99ba203607c0dd522db4dbf6b6 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 15 May 2023 15:39:35 +0200 Subject: [mlir][scf][bufferize] Fix bug in WhileOp analysis verification Block arguments and yielded values are not equivalent if there are not enough block arguments. This fixes #59442. Differential Revision: https://reviews.llvm.org/D145575 --- .../Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 12 ++++++++---- .../Transforms/one-shot-module-bufferize-invalid.mlir | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) (limited to 'mlir') diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index ad395a9ac457..4b0d0e40740f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -902,10 +902,12 @@ struct WhileOpInterface auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { + Block *block = conditionOp->getBlock(); if (!isa(it.value().getType())) continue; - if (!state.areEquivalentBufferizedValues( - it.value(), conditionOp->getBlock()->getArgument(it.index()))) + if (it.index() >= block->getNumArguments() || + !state.areEquivalentBufferizedValues(it.value(), + block->getArgument(it.index()))) return conditionOp->emitError() << "Condition arg #" << it.index() << " is not equivalent to the corresponding iter bbArg"; @@ -913,10 +915,12 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { + Block *block = yieldOp->getBlock(); if (!isa(it.value().getType())) continue; - if (!state.areEquivalentBufferizedValues( - it.value(), yieldOp->getBlock()->getArgument(it.index()))) + if (it.index() >= block->getNumArguments() || + !state.areEquivalentBufferizedValues(it.value(), + block->getArgument(it.index()))) return yieldOp->emitError() << "Yield operand #" << it.index() << " is not equivalent to the corresponding iter bbArg"; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir index 189ef6be0dff..a2d47f08adba 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -324,3 +324,17 @@ func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> { // This function may write to buffer(%ptr). func.func private @maybe_writing_func(%ptr : tensor<*xf32>) + +// ----- + +func.func @regression_scf_while() { + %false = arith.constant false + %8 = bufferization.alloc_tensor() : tensor<10x10xf32> + scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () { + scf.condition(%false) + } do { + // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}} + scf.yield %8 : tensor<10x10xf32> + } + return +} -- cgit v1.2.1