summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2023-05-11 15:44:12 +0000
committerBenoit Jacob <benoitjacob@google.com>2023-05-12 15:35:24 +0000
commit2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf (patch)
tree5420b6025d0f59421fee7bc66982f4fa074ba910 /mlir/lib
parent2c52a1892505aeefd7735beafda2a410cde2c380 (diff)
downloadllvm-2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf.tar.gz
DestinationPassingStyle: allow additional non-tensor results
Also some simplifications: * `outputBufferOperands` was unused. * The condition that the number of operands equals the number of inputs plus the number of inits seemed vacuously true (?). Differential Revision: https://reviews.llvm.org/D150376
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Interfaces/DestinationStyleOpInterface.cpp35
1 files changed, 19 insertions, 16 deletions
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index a9bab23f1a72..f344ea656b24 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -22,35 +22,38 @@ OpOperandVector::operator SmallVector<Value>() {
return result;
}
+namespace {
+size_t getNumTensorResults(Operation *op) {
+ size_t numTensorResults = 0;
+ for (auto t : op->getResultTypes()) {
+ if (isa<TensorType>(t)) {
+ ++numTensorResults;
+ }
+ }
+ return numTensorResults;
+}
+} // namespace
+
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
- SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
+ SmallVector<OpOperand *> outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
Type type = operand->get().getType();
- if (isa<MemRefType>(type)) {
- outputBufferOperands.push_back(operand);
- } else if (isa<RankedTensorType>(type)) {
+ if (isa<RankedTensorType>(type)) {
outputTensorOperands.push_back(operand);
- } else {
+ } else if (!isa<MemRefType>(type)) {
return op->emitOpError("expected that operand #")
<< operand->getOperandNumber()
<< " is a ranked tensor or a ranked memref";
}
}
- // Expect at least one output operand.
- int64_t numInputs = dstStyleOp.getNumDpsInputs();
- int64_t numInits = dstStyleOp.getNumDpsInits();
- if (numInits == 0)
- return op->emitOpError("expected at least one output operand");
- if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits)))
- return failure();
- // Verify the number of results matches the number of output tensors.
- if (op->getNumResults() != outputTensorOperands.size())
- return op->emitOpError("expected the number of results (")
- << op->getNumResults()
+ // Verify the number of tensor results matches the number of output tensors.
+ if (getNumTensorResults(op) != outputTensorOperands.size())
+ return op->emitOpError("expected the number of tensor results (")
+ << getNumTensorResults(op)
<< ") to be equal to the number of output tensors ("
<< outputTensorOperands.size() << ")";