diff options
author | Benoit Jacob <benoitjacob@google.com> | 2023-05-11 15:44:12 +0000 |
---|---|---|
committer | Benoit Jacob <benoitjacob@google.com> | 2023-05-12 15:35:24 +0000 |
commit | 2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf (patch) | |
tree | 5420b6025d0f59421fee7bc66982f4fa074ba910 /mlir/lib | |
parent | 2c52a1892505aeefd7735beafda2a410cde2c380 (diff) | |
download | llvm-2bd6077d7fb61a54ebfbcf2fa210e07e8005c4bf.tar.gz |
DestinationPassingStyle: allow additional non-tensor results
Also some simplifications:
* `outputBufferOperands` was unused.
* The condition that the number of operands equals the number of inputs
plus the number of inits seemed vacuously true (?).
Differential Revision: https://reviews.llvm.org/D150376
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Interfaces/DestinationStyleOpInterface.cpp | 35 |
1 files changed, 19 insertions, 16 deletions
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp index a9bab23f1a72..f344ea656b24 100644 --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -22,35 +22,38 @@ OpOperandVector::operator SmallVector<Value>() { return result; } +namespace { +size_t getNumTensorResults(Operation *op) { + size_t numTensorResults = 0; + for (auto t : op->getResultTypes()) { + if (isa<TensorType>(t)) { + ++numTensorResults; + } + } + return numTensorResults; +} +} // namespace + LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { DestinationStyleOpInterface dstStyleOp = cast<DestinationStyleOpInterface>(op); - SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands; + SmallVector<OpOperand *> outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); - if (isa<MemRefType>(type)) { - outputBufferOperands.push_back(operand); - } else if (isa<RankedTensorType>(type)) { + if (isa<RankedTensorType>(type)) { outputTensorOperands.push_back(operand); - } else { + } else if (!isa<MemRefType>(type)) { return op->emitOpError("expected that operand #") << operand->getOperandNumber() << " is a ranked tensor or a ranked memref"; } } - // Expect at least one output operand. - int64_t numInputs = dstStyleOp.getNumDpsInputs(); - int64_t numInits = dstStyleOp.getNumDpsInits(); - if (numInits == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != outputTensorOperands.size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() + // Verify the number of tensor results matches the number of output tensors. + if (getNumTensorResults(op) != outputTensorOperands.size()) + return op->emitOpError("expected the number of tensor results (") + << getNumTensorResults(op) << ") to be equal to the number of output tensors (" << outputTensorOperands.size() << ")"; |