author    Matthias Springer <springerm@google.com>    2023-04-06 13:20:41 +0900
committer Matthias Springer <springerm@google.com>    2023-04-06 13:22:10 +0900
commit    2443d946f9156dcc1352a6c9abffdc05f0b52d69 (patch)
tree      134ca9defe5f2a6cc58bb732bf234faae7704c11 /mlir/examples/toy/Ch5
parent    211f1d2bb8b753f7068d17b3939d1c59b60e838c (diff)
[mlir] Use RankedTensorType when rank is required
`RankedTensorOf` and `TensorRankOf` (in TableGen files) now generate code that uses `RankedTensorType` instead of `TensorType`. This gives us more accurate type information (e.g., when calling `op.getType()`).
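For illustration only, here is a minimal, hedged sketch of the effect on user code, using a hypothetical op `MyOp` whose result is declared with a `TensorRankOf` constraint in TableGen; the op name and function are assumptions, not part of this change:

```cpp
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical op MyOp, declared in TableGen with a TensorRankOf result
// constraint. Before this change its generated getType() accessor returned
// mlir::TensorType; now it returns mlir::RankedTensorType.
void inspectResult(MyOp op) {
  mlir::RankedTensorType type = op.getType();      // no explicit cast needed anymore
  int64_t rank = type.getRank();                   // rank is known to exist
  llvm::ArrayRef<int64_t> shape = type.getShape(); // shape query is also safe
  (void)rank;
  (void)shape;
}
```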
Also restrict tensor.expand_shape/tensor.collapse_shape/tensor.pad to ranked tensors. Only cast ops should deal with unranked tensors.
Also improve a few places in the code base (e.g., the Toy tutorial) where a ranked tensor is assumed (e.g., because `getRank` is called) but a `TensorType` is currently used: cast to `RankedTensorType` directly, so that a violated assumption triggers an assertion right at the cast.
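The pattern is sketched below (a minimal, hedged example mirroring the diff further down; `rankOf` and `value` are illustrative names, not code from this change):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Sketch of the pattern applied by this change: when the surrounding code
// assumes a ranked tensor anyway, cast to RankedTensorType up front so a
// violated assumption asserts at the cast instead of later in getRank().
int64_t rankOf(Value value) {
  // Before: value.getType().cast<TensorType>() followed by getRank(), which
  // would only assert once getRank() was reached.
  auto type = value.getType().cast<RankedTensorType>(); // asserts here if unranked
  return type.getRank();
}
```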
Differential Revision: https://reviews.llvm.org/D147149
Diffstat (limited to 'mlir/examples/toy/Ch5')
mlir/examples/toy/Ch5/mlir/Dialect.cpp            | 2 +-
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp | 9 ++++-----
2 files changed, 5 insertions(+), 6 deletions(-)
```diff
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index a959969c0449..98c8eb5dd798 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index c52f5bd2c59b..a40353e3fd8b 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -30,9 +30,8 @@ using namespace mlir;
 // ToyToAffine RewritePatterns
 //===----------------------------------------------------------------------===//
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
   return MemRefType::get(type.getShape(), type.getElementType());
 }
 
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<TensorType>();
+    auto tensorType = op.getType().cast<RankedTensorType>();
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
```