summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/examples/toy/Ch2/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch3/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch4/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch5/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp9
-rw-r--r--mlir/examples/toy/Ch6/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp9
-rw-r--r--mlir/examples/toy/Ch7/mlir/Dialect.cpp2
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp9
-rw-r--r--mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td12
-rw-r--r--mlir/include/mlir/IR/OpBase.td40
-rw-r--r--mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp23
-rw-r--r--mlir/test/Dialect/Tensor/invalid.mlir4
14 files changed, 66 insertions, 54 deletions
diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index a6ccbbfab48a..ef07af26ec43 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 913979a94950..43f8d5b1481d 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index f5258eb5cff1..75a517159a6d 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index a959969c0449..98c8eb5dd798 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index c52f5bd2c59b..a40353e3fd8b 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -30,9 +30,8 @@ using namespace mlir;
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index a959969c0449..98c8eb5dd798 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index c52f5bd2c59b..a40353e3fd8b 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -30,9 +30,8 @@ using namespace mlir;
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index d332411b63bb..5fcb0be36c8a 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -195,7 +195,7 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
// Check that the rank of the attribute type matches the rank of the
// constant result type.
- auto attrType = attrValue.getType().cast<mlir::TensorType>();
+ auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op->emitOpError("return type must match the one of the attached "
"value attribute: ")
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index c52f5bd2c59b..a40353e3fd8b 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -30,9 +30,8 @@ using namespace mlir;
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 65b2c12b4507..d106f1285dfd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -974,8 +974,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Tensor_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure])>,
- Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyTensor:$result)> {
+ Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
+ Results<(outs AnyRankedTensor:$result)> {
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1210,7 +1210,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
}];
let arguments = (ins
- AnyTensor:$source,
+ AnyRankedTensor:$source,
Variadic<Index>:$low,
Variadic<Index>:$high,
DenseI64ArrayAttr:$static_low,
@@ -1219,7 +1219,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
let regions = (region SizedRegion<1>:$region);
- let results = (outs AnyTensor:$result);
+ let results = (outs AnyRankedTensor:$result);
// TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
let assemblyFormat = [{
@@ -1678,8 +1678,8 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
"$_self">])> {
code commonExtraClassDeclaration = [{
- size_t getSourceRank() { return getSource().getType().getRank(); };
- size_t getDestRank() { return getDest().getType().getRank(); };
+ size_t getSourceRank() { return getSourceType().getRank(); };
+ size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>(); };
RankedTensorType getDestType() {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 554f02675363..f7f009cce317 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -240,6 +240,10 @@ def IsUnrankedMemRefTypePred
def IsUnrankedTensorTypePred
: CPred<"$_self.isa<::mlir::UnrankedTensorType>()">;
+// Whether a type is a RankedTensorType
+def IsRankedTensorTypePred
+ : CPred<"$_self.isa<::mlir::RankedTensorType>()">;
+
// Whether a type is a BaseMemRefType
def IsBaseMemRefTypePred
: CPred<"$_self.isa<::mlir::BaseMemRefType>()">;
@@ -721,11 +725,21 @@ def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
//===----------------------------------------------------------------------===//
// Tensor types.
-// Unranked tensor type whose element type is from the given
-// `allowedTypes` list.
-class UnrankedTensorOf<list<Type> allowedTypes>
- : ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred,
- "unranked.tensor", "::mlir::UnrankedTensorType">;
+// Unranked tensor type whose element type is from the given `allowedTypes`
+// list, and which additionally satisfies an optional list of predicates.
+class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "unranked tensor">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>,
+ summary, "::mlir::UnrankedTensorType">;
+
+// Ranked tensor type whose element type is from the given `allowedTypes` list,
+// and which additionally satisfies an optional list of predicates.
+class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "ranked tensor">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([IsRankedTensorTypePred], preds)>,
+ summary, "::mlir::RankedTensorType">;
// Any tensor type whose element type is from the given `allowedTypes`
// list, and which additionally satisfies an optional list of predicates.
@@ -754,12 +768,6 @@ def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
-class RankedTensorOf<
- list<Type> allowedTypes,
- list<Pred> preds = [],
- string summary = "ranked tensor">
- : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
-
class Non0RankedTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
"non-0-ranked.tensor">;
@@ -768,12 +776,13 @@ def AnyRankedTensor : RankedTensorOf<[AnyType]>;
def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>;
def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>;
-def AnyNon0RankedOrUnrankedTensor:
- AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>;
+def AnyNon0RankedOrUnrankedTensor
+ : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
+ "non-0-ranked or unranked tensor", "::mlir::TensorType">;
// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
- : TensorOf<allowedTypes,
+ : RankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
@@ -784,7 +793,8 @@ class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
class StaticShapeTensorOf<list<Type> allowedTypes>
- : TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">;
+ : RankedTensorOf<allowedTypes, [HasStaticShapePred],
+ "statically shaped tensor">;
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index 776d6c7a6930..ed13ab3fd8c0 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -44,7 +44,7 @@ public:
LogicalResult
matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TensorType tensorType = extractOp.getTensor().getType().cast<TensorType>();
+ auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>();
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7ee9325e5f8e..ccea6dd854af 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -369,11 +369,16 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
auto extractOperand =
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
+ // Cannot fold cast to unranked tensor.
+ auto rankedResultType = tensorCast.getType().dyn_cast<RankedTensorType>();
+ if (!rankedResultType)
+ return failure();
+
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
- tensorCast.getType().getShape() == tensorCast.getSource()
- .getType()
- .cast<RankedTensorType>()
- .getShape())
+ rankedResultType.getShape() == tensorCast.getSource()
+ .getType()
+ .cast<RankedTensorType>()
+ .getShape())
return failure();
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
@@ -383,15 +388,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
for (size_t i = 0, e = sizes.size(); i < e; i++) {
if (dimMask && dimMask->count(i))
continue;
- int64_t dim = tensorCast.getType().getShape()[dimIndex++];
+ int64_t dim = rankedResultType.getShape()[dimIndex++];
if (ShapedType::isDynamic(dim))
continue;
sizes[i] = rewriter.getIndexAttr(dim);
}
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
- tensorCast, tensorCast.getType().cast<RankedTensorType>(),
- extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
+ tensorCast, rankedResultType, extractOperand.getSource(),
+ extractOperand.getMixedOffsets(), sizes,
extractOperand.getMixedStrides());
return success();
}
@@ -1500,7 +1505,7 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
return failure();
// Skip static dims. These are folded to constant ops.
- TensorType resultType = expandShapeOp.getResultType();
+ RankedTensorType resultType = expandShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
return failure();
@@ -1544,7 +1549,7 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
return failure();
// Skip static dims. These are folded to constant ops.
- TensorType resultType = collapseShapeOp.getResultType();
+ RankedTensorType resultType = collapseShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
return failure();
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f74bd9456b66..61f03f19de33 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -2,7 +2,7 @@
// Asking the dimension of a 0-D shape doesn't make sense.
func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
- tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
+ tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be non-0-ranked or unranked tensor, but got 'tensor<f32>'}}
return
}
@@ -33,7 +33,7 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// -----
func.func @tensor.from_elements_wrong_result_type() {
- // expected-error@+2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}}
+ // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
%c0 = arith.constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<*xi32>
return