summaryrefslogtreecommitdiff
path: root/mlir/examples/toy/Ch5
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2023-04-06 13:20:41 +0900
committerMatthias Springer <springerm@google.com>2023-04-06 13:22:10 +0900
commit2443d946f9156dcc1352a6c9abffdc05f0b52d69 (patch)
tree134ca9defe5f2a6cc58bb732bf234faae7704c11 /mlir/examples/toy/Ch5
parent211f1d2bb8b753f7068d17b3939d1c59b60e838c (diff)
downloadllvm-2443d946f9156dcc1352a6c9abffdc05f0b52d69.tar.gz
[mlir] Use RankedTensorType when rank is required
`RankedTensorOf` and `TensorRankOf` (in Tablegen files) now generate code that uses `RankedTensorType` instead of `TensorType`. This gives us more accurate type information (e.g., when calling `op.getType()`). Also use restrict tensor.expand_shape/tensor.collapse_shape/tensor.pad to ranked tensors. Only cast ops should deal with unranked tensors. Also improves a few places in the code base (e.g., Toy tutorial) where a ranked tensor is assumed (e.g., because `getRank` is called) but a `TensorType` is currently used: cast to `RankedTensorType` directly, so that the assertion is triggered directly at the cast. Differential Revision: https://reviews.llvm.org/D147149
Diffstat (limited to 'mlir/examples/toy/Ch5')
-rw-r--r--mlir/examples/toy/Ch5/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp9
2 files changed, 5 insertions, 6 deletions
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index a959969c0449..98c8eb5dd798 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index c52f5bd2c59b..a40353e3fd8b 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -30,9 +30,8 @@ using namespace mlir;
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);