diff options
-rw-r--r-- | mlir/examples/toy/Ch2/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch3/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch4/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch5/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp | 9 | ||||
-rw-r--r-- | mlir/examples/toy/Ch6/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp | 9 | ||||
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/Dialect.cpp | 2 | ||||
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp | 9 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 12 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OpBase.td | 40 | ||||
-rw-r--r-- | mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 23 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/invalid.mlir | 4 |
14 files changed, 66 insertions, 54 deletions
diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp index a6ccbbfab48a..ef07af26ec43 100644 --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() { // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast<mlir::TensorType>(); + auto attrType = getValue().getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 913979a94950..43f8d5b1481d 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() { // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast<mlir::TensorType>(); + auto attrType = getValue().getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index f5258eb5cff1..75a517159a6d 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() { // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast<mlir::TensorType>(); + auto attrType = getValue().getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index a959969c0449..98c8eb5dd798 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() { // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast<mlir::TensorType>(); + auto attrType = getValue().getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index c52f5bd2c59b..a40353e3fd8b 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ using namespace mlir; // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value( static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast<TensorType>(); + auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast<TensorType>(); + auto tensorType = op.getType().cast<RankedTensorType>(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index a959969c0449..98c8eb5dd798 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() { // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast<mlir::TensorType>(); + auto attrType = getValue().getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index c52f5bd2c59b..a40353e3fd8b 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ using namespace mlir; // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value( static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast<TensorType>(); + auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast<TensorType>(); + auto tensorType = op.getType().cast<RankedTensorType>(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index d332411b63bb..5fcb0be36c8a 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -195,7 +195,7 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type, // Check that the rank of the attribute type matches the rank of the // constant result type. - auto attrType = attrValue.getType().cast<mlir::TensorType>(); + auto attrType = attrValue.getType().cast<mlir::RankedTensorType>(); if (attrType.getRank() != resultType.getRank()) { return op->emitOpError("return type must match the one of the attached " "value attribute: ") diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index c52f5bd2c59b..a40353e3fd8b 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ using namespace mlir; // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value( static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast<TensorType>(); + auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast<TensorType>(); + auto tensorType = op.getType().cast<RankedTensorType>(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 65b2c12b4507..d106f1285dfd 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -974,8 +974,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> : Tensor_Op<mnemonic, !listconcat(traits, [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, Pure])>, - Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, - Results<(outs AnyTensor:$result)> { + Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>, + Results<(outs AnyRankedTensor:$result)> { code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrStrName() { return "reassociation"; } @@ -1210,7 +1210,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [ }]; let arguments = (ins - AnyTensor:$source, + AnyRankedTensor:$source, Variadic<Index>:$low, Variadic<Index>:$high, DenseI64ArrayAttr:$static_low, @@ -1219,7 +1219,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [ let regions = (region SizedRegion<1>:$region); - let results = (outs AnyTensor:$result); + let results = (outs AnyRankedTensor:$result); // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands. let assemblyFormat = [{ @@ -1678,8 +1678,8 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> : "$_self">])> { code commonExtraClassDeclaration = [{ - size_t getSourceRank() { return getSource().getType().getRank(); }; - size_t getDestRank() { return getDest().getType().getRank(); }; + size_t getSourceRank() { return getSourceType().getRank(); }; + size_t getDestRank() { return getDestType().getRank(); }; RankedTensorType getSourceType() { return getSource().getType().cast<RankedTensorType>(); }; RankedTensorType getDestType() { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 554f02675363..f7f009cce317 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -240,6 +240,10 @@ def IsUnrankedMemRefTypePred def IsUnrankedTensorTypePred : CPred<"$_self.isa<::mlir::UnrankedTensorType>()">; +// Whether a type is a RankedTensorType +def IsRankedTensorTypePred + : CPred<"$_self.isa<::mlir::RankedTensorType>()">; + // Whether a type is a BaseMemRefType def IsBaseMemRefTypePred : CPred<"$_self.isa<::mlir::BaseMemRefType>()">; @@ -721,11 +725,21 @@ def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", //===----------------------------------------------------------------------===// // Tensor types. -// Unranked tensor type whose element type is from the given -// `allowedTypes` list. -class UnrankedTensorOf<list<Type> allowedTypes> - : ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred, - "unranked.tensor", "::mlir::UnrankedTensorType">; +// Unranked tensor type whose element type is from the given `allowedTypes` +// list, and which additionally satisfies an optional list of predicates. +class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], + string summary = "unranked tensor"> + : ShapedContainerType< + allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>, + summary, "::mlir::UnrankedTensorType">; + +// Ranked tensor type whose element type is from the given `allowedTypes` list, +// and which additionally satisfies an optional list of predicates. +class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], + string summary = "ranked tensor"> + : ShapedContainerType< + allowedTypes, And<!listconcat([IsRankedTensorTypePred], preds)>, + summary, "::mlir::RankedTensorType">; // Any tensor type whose element type is from the given `allowedTypes` // list, and which additionally satisfies an optional list of predicates. @@ -754,12 +768,6 @@ def F16Tensor : TensorOf<[F16]>; def F32Tensor : TensorOf<[F32]>; def F64Tensor : TensorOf<[F64]>; -class RankedTensorOf< - list<Type> allowedTypes, - list<Pred> preds = [], - string summary = "ranked tensor"> - : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>; - class Non0RankedTensorOf<list<Type> allowedTypes> : TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>], "non-0-ranked.tensor">; @@ -768,12 +776,13 @@ def AnyRankedTensor : RankedTensorOf<[AnyType]>; def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>; def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>; -def AnyNon0RankedOrUnrankedTensor: - AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>; +def AnyNon0RankedOrUnrankedTensor + : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor], + "non-0-ranked or unranked tensor", "::mlir::TensorType">; // Ranked tensor type with one of the specified types and ranks. class TensorRankOf<list<Type> allowedTypes, list<int> ranks> - : TensorOf<allowedTypes, + : RankedTensorOf<allowedTypes, [HasAnyRankOfPred<ranks>], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; @@ -784,7 +793,8 @@ class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>; class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>; class StaticShapeTensorOf<list<Type> allowedTypes> - : TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">; + : RankedTensorOf<allowedTypes, [HasStaticShapePred], + "statically shaped tensor">; def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index 776d6c7a6930..ed13ab3fd8c0 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -44,7 +44,7 @@ public: LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TensorType tensorType = extractOp.getTensor().getType().cast<TensorType>(); + auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>(); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 7ee9325e5f8e..ccea6dd854af 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -369,11 +369,16 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> { auto extractOperand = tensorCast.getOperand().getDefiningOp<ExtractSliceOp>(); + // Cannot fold cast to unranked tensor. + auto rankedResultType = tensorCast.getType().dyn_cast<RankedTensorType>(); + if (!rankedResultType) + return failure(); + if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || - tensorCast.getType().getShape() == tensorCast.getSource() - .getType() - .cast<RankedTensorType>() - .getShape()) + rankedResultType.getShape() == tensorCast.getSource() + .getType() + .cast<RankedTensorType>() + .getShape()) return failure(); SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes(); @@ -383,15 +388,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> { for (size_t i = 0, e = sizes.size(); i < e; i++) { if (dimMask && dimMask->count(i)) continue; - int64_t dim = tensorCast.getType().getShape()[dimIndex++]; + int64_t dim = rankedResultType.getShape()[dimIndex++]; if (ShapedType::isDynamic(dim)) continue; sizes[i] = rewriter.getIndexAttr(dim); } rewriter.replaceOpWithNewOp<ExtractSliceOp>( - tensorCast, tensorCast.getType().cast<RankedTensorType>(), - extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes, + tensorCast, rankedResultType, extractOperand.getSource(), + extractOperand.getMixedOffsets(), sizes, extractOperand.getMixedStrides()); return success(); } @@ -1500,7 +1505,7 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> { return failure(); // Skip static dims. These are folded to constant ops. - TensorType resultType = expandShapeOp.getResultType(); + RankedTensorType resultType = expandShapeOp.getResultType(); if (!resultType.isDynamicDim(*dim)) return failure(); @@ -1544,7 +1549,7 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> { return failure(); // Skip static dims. These are folded to constant ops. - TensorType resultType = collapseShapeOp.getResultType(); + RankedTensorType resultType = collapseShapeOp.getResultType(); if (!resultType.isDynamicDim(*dim)) return failure(); diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index f74bd9456b66..61f03f19de33 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -2,7 +2,7 @@ // Asking the dimension of a 0-D shape doesn't make sense. func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) { - tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}} + tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be non-0-ranked or unranked tensor, but got 'tensor<f32>'}} return } @@ -33,7 +33,7 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) { // ----- func.func @tensor.from_elements_wrong_result_type() { - // expected-error@+2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}} + // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}} %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<*xi32> return |