diff options
author | Alexander Belyaev <pifon@google.com> | 2021-12-14 09:35:14 +0100 |
---|---|---|
committer | Alexander Belyaev <pifon@google.com> | 2021-12-14 10:15:55 +0100 |
commit | 15f8f3e20aa92349b0cb559d657f7648987edb06 (patch) | |
tree | 89d9cfafc4202967b6ac1f713bf85d5008b4ac63 | |
parent | ef5be2bb16e51c2f6fff622a43cc71268acc6ddc (diff) | |
download | llvm-15f8f3e20aa92349b0cb559d657f7648987edb06.tar.gz |
[mlir] Split std.rank into tensor.rank and memref.rank.
Move `std.rank` similarly to how `std.dim` was moved to TensorOps and MemRefOps.
Differential Revision: https://reviews.llvm.org/D115665
23 files changed, 210 insertions, 150 deletions
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index c6b7e28fe0aa..e529a50dae93 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -999,6 +999,31 @@ def MemRef_ReinterpretCastOp: } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The `memref.rank` operation takes a memref operand and returns its rank. + + Example: + + ```mlir + %0 = memref.rank %arg0 : memref<*xf32> + %1 = memref.rank %arg1 : memref<?x?xf32> + ``` + }]; + + let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); + let results = (outs Index); + + let verifier = ?; + let hasFolder = 1; + let assemblyFormat = "$memref attr-dict `:` type($memref)"; +} + +//===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 23b9df282af0..2e50971db9e7 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -659,32 +659,6 @@ def ConstantOp : Std_Op<"constant", } //===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -def RankOp : Std_Op<"rank", [NoSideEffect]> { - let summary = "rank operation"; - let description = [{ - The `rank` operation takes a memref/tensor operand and returns its rank. - - Example: - - ```mlir - %1 = rank %arg0 : tensor<*xf32> - %2 = rank %arg1 : memref<*xf32> - ``` - }]; - - let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor], - "any memref or tensor type">:$memrefOrTensor); - let results = (outs Index); - let verifier = ?; - - let hasFolder = 1; - let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; -} - -//===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 3b1bfeeca6c1..21331fc649cd 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -68,9 +68,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> { let summary = "dimension index operation"; let description = [{ - The `dim` operation takes a tensor and a dimension operand of type `index`. - It returns the size of the requested dimension of the given tensor. - If the dimension index is out of bounds, the behavior is undefined. + The `tensor.dim` operation takes a tensor and a dimension operand of type + `index`. It returns the size of the requested dimension of the given + tensor. If the dimension index is out of bounds, the behavior is undefined. The specified tensor type is that of the first operand. @@ -559,6 +559,31 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides< } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> { + let summary = "rank operation"; + let description = [{ + The `tensor.rank` operation takes a tensor operand and returns its rank. + + Example: + + ```mlir + %0 = tensor.rank %arg0 : tensor<*xf32> + %1 = tensor.rank %arg1 : tensor<?x?xf32> + ``` + }]; + + let arguments = (ins AnyTensor:$tensor); + let results = (outs Index); + + let verifier = ?; + let hasFolder = 1; + let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; +} + +//===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 521b3fcab0c6..28981dd87ecc 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -596,6 +596,28 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { } }; +struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { + using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Type operandType = op.memref().getType(); + if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) { + UnrankedMemRefDescriptor desc(adaptor.memref()); + rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); + return success(); + } + if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) { + rewriter.replaceOp( + op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); + return success(); + } + return failure(); + } +}; + struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern; @@ -1549,6 +1571,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, PrefetchOpLowering, + RankOpLowering, ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, StoreOpLowering, diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index e1e24faa4d2a..5a1af7b33132 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -577,7 +577,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); - Value rank = rewriter.create<mlir::RankOp>(loc, tensor); + Value rank = rewriter.create<tensor::RankOp>(loc, tensor); rewriter.replaceOpWithNewOp<tensor::GenerateOp>( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 200834a2d1bc..f588521ac6ef 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -566,28 +566,6 @@ struct UnrealizedConversionCastOpLowering } }; -struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> { - using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(RankOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type operandType = op.getMemrefOrTensor().getType(); - if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) { - UnrankedMemRefDescriptor desc(adaptor.getMemrefOrTensor()); - rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); - return success(); - } - if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) { - rewriter.replaceOp( - op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); - return success(); - } - return failure(); - } -}; - // Common base for load and store operations on MemRefs. Restricts the match // to supported MemRef types. Provides functionality to emit code accessing a // specific element of the underlying data buffer. @@ -987,7 +965,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, CondBranchOpLowering, ConstantOpLowering, GenericAtomicRMWOpLowering, - RankOpLowering, ReturnOpLowering, SelectOpLowering, SplatOpLowering, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 4badc0b31ddb..1916ffe36dd6 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1073,6 +1073,19 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { + // Constant fold rank when the rank of the operand is known. + auto type = getOperand().getType(); + auto shapedType = type.dyn_cast<ShapedType>(); + if (shapedType && shapedType.hasRank()) + return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); + return IntegerAttr(); +} + +//===----------------------------------------------------------------------===// // ReinterpretCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 1a43c0937d03..1d045b291215 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -900,20 +900,6 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) { } //===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { - // Constant fold rank when the rank of the operand is known. - auto type = getOperand().getType(); - if (auto shapedType = type.dyn_cast<ShapedType>()) - if (shapedType.hasRank()) - return IntegerAttr::get(IndexType::get(getContext()), - shapedType.getRank()); - return IntegerAttr(); -} - -//===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index edddfb86e553..ecdd966a3c35 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -610,6 +610,19 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { + // Constant fold rank when the rank of the operand is known. + auto type = getOperand().getType(); + auto shapedType = type.dyn_cast<ShapedType>(); + if (shapedType && shapedType.hasRank()) + return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); + return IntegerAttr(); +} + +//===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp index 64a005dfb55b..27e00a14c0d4 100644 --- a/mlir/lib/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Transforms/BufferOptimizations.cpp @@ -37,14 +37,16 @@ static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, if (!type || !alloc.getDefiningOp<memref::AllocOp>()) return false; if (!type.hasStaticShape()) { - // Check if the dynamic shape dimension of the alloc is produced by RankOp. - // If this is the case, it is likely to be small. Furthermore, the dimension - // is limited to the maximum rank of the allocated memref to avoid large - // values by multiplying several small values. + // Check if the dynamic shape dimension of the alloc is produced by + // `memref.rank`. If this is the case, it is likely to be small. + // Furthermore, the dimension is limited to the maximum rank of the + // allocated memref to avoid large values by multiplying several small + // values. if (type.getRank() <= maxRankOfAllocatedMemRef) { - return llvm::all_of( - alloc.getDefiningOp()->getOperands(), - [&](Value operand) { return operand.getDefiningOp<RankOp>(); }); + return llvm::all_of(alloc.getDefiningOp()->getOperands(), + [&](Value operand) { + return operand.getDefiningOp<memref::RankOp>(); + }); } return false; } diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index a26638a34151..009106f95e8a 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt -convert-memref-to-llvm %s -split-input-file | FileCheck %s // RUN: mlir-opt -convert-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s - // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index func @view(%arg0 : index, %arg1 : index, %arg2 : index) { @@ -835,3 +834,24 @@ func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { // CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.mul %{{.*}}, %{{.*}} : i64 + +// ----- + +// CHECK-LABEL: func @rank_of_unranked +// CHECK32-LABEL: func @rank_of_unranked +func @rank_of_unranked(%unranked: memref<*xi32>) { + %rank = memref.rank %unranked : memref<*xi32> + return +} +// CHECK: %[[UNRANKED_DESC:.*]] = builtin.unrealized_conversion_cast +// CHECK-NEXT: llvm.extractvalue %[[UNRANKED_DESC]][0] : !llvm.struct<(i64, ptr<i8>)> +// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)> + +// CHECK-LABEL: func @rank_of_ranked +// CHECK32-LABEL: func @rank_of_ranked +func @rank_of_ranked(%ranked: memref<?xi32>) { + %rank = memref.rank %ranked : memref<?xi32> + return +} +// CHECK: llvm.mlir.constant(1 : index) : i64 +// CHECK32: llvm.mlir.constant(1 : index) : i32 diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 015cb2fcaaf4..ea0ef33862ce 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -203,7 +203,7 @@ func @shape_of(%arg : tensor<*xf32>) { // CHECK-LABEL: @shape_of_unranked // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) func @shape_of_unranked(%arg : tensor<*xf32>) { - // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32> + // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32> // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] { // CHECK: ^bb0(%[[I:.*]]: index): // CHECK: %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32> diff --git a/mlir/test/Conversion/StandardToLLVM/rank.mlir b/mlir/test/Conversion/StandardToLLVM/rank.mlir deleted file mode 100644 index 7c0a03aa8df3..000000000000 --- a/mlir/test/Conversion/StandardToLLVM/rank.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s -// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s - -// CHECK-LABEL: func @rank_of_unranked -// CHECK32-LABEL: func @rank_of_unranked -func @rank_of_unranked(%unranked: memref<*xi32>) { - %rank = rank %unranked : memref<*xi32> - return -} -// CHECK-NEXT: llvm.mlir.undef -// CHECK-NEXT: llvm.insertvalue -// CHECK-NEXT: llvm.insertvalue -// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, ptr<i8>)> -// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)> - -// CHECK-LABEL: func @rank_of_ranked -// CHECK32-LABEL: func @rank_of_ranked -func @rank_of_ranked(%ranked: memref<?xi32>) { - %rank = rank %ranked : memref<?xi32> - return -} -// CHECK: llvm.mlir.constant(1 : index) : i64 -// CHECK32: llvm.mlir.constant(1 : index) : i32 diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 251658fac765..80282c21afab 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -185,10 +185,10 @@ func @dim_of_alloca(%size: index) -> index { // Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v) // CHECK-LABEL: func @dim_of_alloca_with_dynamic_size( // CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32> -// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32> +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32> // CHECK-NEXT: return %[[RANK]] : index func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index { - %0 = rank %arg0 : memref<*xf32> + %0 = memref.rank %arg0 : memref<*xf32> %1 = memref.alloca(%0) : memref<?xindex> %c0 = arith.constant 0 : index %2 = memref.dim %1, %c0 : memref<?xindex> @@ -438,3 +438,15 @@ func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index) // CHECK: %[[RESULT:.+]] = memref.subview // CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}> // CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: func @fold_rank_memref +func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) { + // Fold a rank into a constant + // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index + %rank_0 = memref.rank %arg0 : memref<?x?xf32> + + // CHECK-NEXT: return [[C2]] + return %rank_0 : index +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 6014687e6e9d..55c5a821fb3d 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -844,3 +844,11 @@ func @test_alloc_memref_map_rank_mismatch() { %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> return } + +// ----- + +func @rank(%0: f32) { + // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}} + "memref.rank"(%0): (f32)->index + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index f716c5de2174..4ff2f8b5517b 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -207,3 +207,14 @@ func @collapse_shape_to_dynamic // CHECK: func @collapse_shape_to_dynamic // CHECK: memref.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +// ----- + +func @rank(%t : memref<4x4x?xf32>) { + // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32> + %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index + + // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32> + %1 = memref.rank %t : memref<4x4x?xf32> + return +} diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index fc9abe439b8a..ec9601e26993 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -183,7 +183,7 @@ func @extract_oob_from_tensor.from_elements(%element : index) -> index { // CHECK-LABEL: func @extract_from_tensor.generate // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]] %0 = tensor.generate %size { ^bb0(%arg0: index): @@ -200,7 +200,7 @@ func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index // CHECK-LABEL: func @extract_from_tensor.generate_2d // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]] // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]] // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]] @@ -221,7 +221,7 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects // CHECK-SAME: %[[IDX:.*]]: index func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index { - %size = rank %tensor : tensor<*xf32> + %size = tensor.rank %tensor : tensor<*xf32> // CHECK: %[[DTENSOR:.*]] = tensor.generate %0 = tensor.generate %size { ^bb0(%arg0: index): @@ -900,3 +900,18 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64> // CHECK-NOT: tensor.expand_shape // CHECK: return %[[CST]] + +// ----- + +// CHECK-LABEL: func @fold_rank +func @fold_rank() -> (index) { + %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> + : tensor<2x1x4xi32> + + // Fold a ank into a constant + // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index + %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32> + + // CHECK-NEXT: return [[C3]] + return %rank_0 : index +} diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 8b40ec80e02d..564526f16370 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -292,3 +292,11 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) : tensor<?x4x5xf32> into tensor<?x?xf32> return %0 : tensor<?x?xf32> } + +// ----- + +func @rank(%0: f32) { + // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}} + "tensor.rank"(%0): (f32)->index + return +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 63afc1f382b3..8d50d1518421 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -160,3 +160,14 @@ func @legal_collapsing_reshape_dynamic_tensor // CHECK: func @legal_collapsing_reshape_dynamic_tensor // CHECK: tensor.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +// ----- + +func @rank(%t : tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32> + %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index + + // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32> + %1 = tensor.rank %t : tensor<4x4x?xf32> + return +} diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index fe2d7207d3d0..b83f530eeacc 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -99,12 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> - // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32> - %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index - - // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32> - %72 = rank %t : tensor<4x4x?xf32> - // CHECK: = constant unit %73 = constant unit diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 13cfd16daf9a..49f29f09bf49 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1,13 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics -func @rank(f32) { -^bb(%0: f32): - "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any memref or tensor type}} - - return -} - -// ----- func @affine_apply_no_map() { ^bb0: %i = arith.constant 0 : index diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 5406a8588ce4..2e720eae3439 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -754,32 +754,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, // ----- -// CHECK-LABEL: func @fold_rank -func @fold_rank() -> (index) { - %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32> - - // Fold a rank into a constant - // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index - %rank_0 = rank %const_0 : tensor<2x1x4xi32> - - // CHECK-NEXT: return [[C3]] - return %rank_0 : index -} - -// ----- - -// CHECK-LABEL: func @fold_rank_memref -func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) { - // Fold a rank into a constant - // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index - %rank_0 = rank %arg0 : memref<?x?xf32> - - // CHECK-NEXT: return [[C2]] - return %rank_0 : index -} - -// ----- - // CHECK-LABEL: func @nested_isolated_region func @nested_isolated_region() { // CHECK-NEXT: func @isolated_op diff --git a/mlir/test/Transforms/promote-buffers-to-stack.mlir b/mlir/test/Transforms/promote-buffers-to-stack.mlir index c78f8a71dbb7..2b6cd3185fa1 100644 --- a/mlir/test/Transforms/promote-buffers-to-stack.mlir +++ b/mlir/test/Transforms/promote-buffers-to-stack.mlir @@ -77,25 +77,25 @@ func @condBranchDynamicType( // ----- // CHECK-LABEL: func @dynamicRanked -func @dynamicRanked(%tensor: tensor<*xf32>) { - %0 = rank %tensor : tensor<*xf32> +func @dynamicRanked(%memref: memref<*xf32>) { + %0 = memref.rank %memref : memref<*xf32> %1 = memref.alloc(%0) : memref<?xindex> return } -// CHECK-NEXT: %[[RANK:.*]] = rank +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32> // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]]) // ----- // CHECK-LABEL: func @dynamicRanked2D -func @dynamicRanked2D(%tensor: tensor<*xf32>) { - %0 = rank %tensor : tensor<*xf32> +func @dynamicRanked2D(%memref: memref<*xf32>) { + %0 = memref.rank %memref : memref<*xf32> %1 = memref.alloc(%0, %0) : memref<?x?xindex> return } -// CHECK-NEXT: %[[RANK:.*]] = rank +// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32> // RANK-NEXT: %[[ALLOC:.*]] = memref.alloca(%[[RANK]], %[[RANK]]) // DEFINDEX-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[RANK]], %[[RANK]]) |