author     Alexander Belyaev <pifon@google.com>  2021-12-14 09:35:14 +0100
committer  Alexander Belyaev <pifon@google.com>  2021-12-14 10:15:55 +0100
commit     15f8f3e20aa92349b0cb559d657f7648987edb06 (patch)
tree       89d9cfafc4202967b6ac1f713bf85d5008b4ac63
parent     ef5be2bb16e51c2f6fff622a43cc71268acc6ddc (diff)
download   llvm-15f8f3e20aa92349b0cb559d657f7648987edb06.tar.gz
[mlir] Split std.rank into tensor.rank and memref.rank.
Move `std.rank` similarly to how `std.dim` was moved to TensorOps and MemRefOps.

Differential Revision: https://reviews.llvm.org/D115665
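A minimal sketch of what the split means for IR that used `std.rank` (the function and value names below are illustrative, not part of this patch): the op now lives in the dialect of its operand type, so tensor operands use `tensor.rank` and memref operands use `memref.rank`, both still returning `index`.

```mlir
// Before this change (std dialect):
//   %r = rank %t : tensor<*xf32>
//   %s = rank %m : memref<*xf32>

// After this change, pick the op matching the operand's dialect:
func @rank_examples(%t: tensor<*xf32>, %m: memref<*xf32>) -> (index, index) {
  %tr = tensor.rank %t : tensor<*xf32>   // rank of a (possibly unranked) tensor
  %mr = memref.rank %m : memref<*xf32>   // rank of a (possibly unranked) memref
  return %tr, %mr : index, index
}
```

When the operand type is ranked (e.g. `tensor<2x1x4xi32>` or `memref<?x?xf32>`), both ops constant-fold to an `index` constant, as exercised by the canonicalization tests in this patch.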
-rw-r--r--  mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td            | 25
-rw-r--r--  mlir/include/mlir/Dialect/StandardOps/IR/Ops.td             | 26
-rw-r--r--  mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td            | 31
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp           | 23
-rw-r--r--  mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp     |  2
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp       | 23
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp                    | 13
-rw-r--r--  mlir/lib/Dialect/StandardOps/IR/Ops.cpp                     | 14
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp                    | 13
-rw-r--r--  mlir/lib/Transforms/BufferOptimizations.cpp                 | 16
-rw-r--r--  mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir       | 22
-rw-r--r--  mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir |  2
-rw-r--r--  mlir/test/Conversion/StandardToLLVM/rank.mlir               | 23
-rw-r--r--  mlir/test/Dialect/MemRef/canonicalize.mlir                  | 16
-rw-r--r--  mlir/test/Dialect/MemRef/invalid.mlir                       |  8
-rw-r--r--  mlir/test/Dialect/MemRef/ops.mlir                           | 11
-rw-r--r--  mlir/test/Dialect/Tensor/canonicalize.mlir                  | 21
-rw-r--r--  mlir/test/Dialect/Tensor/invalid.mlir                       |  8
-rw-r--r--  mlir/test/Dialect/Tensor/ops.mlir                           | 11
-rw-r--r--  mlir/test/IR/core-ops.mlir                                  |  6
-rw-r--r--  mlir/test/IR/invalid-ops.mlir                               |  8
-rw-r--r--  mlir/test/Transforms/constant-fold.mlir                     | 26
-rw-r--r--  mlir/test/Transforms/promote-buffers-to-stack.mlir          | 12
23 files changed, 210 insertions, 150 deletions
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c6b7e28fe0aa..e529a50dae93 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -999,6 +999,31 @@ def MemRef_ReinterpretCastOp:
}
//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
+ let summary = "rank operation";
+ let description = [{
+ The `memref.rank` operation takes a memref operand and returns its rank.
+
+ Example:
+
+ ```mlir
+ %0 = memref.rank %arg0 : memref<*xf32>
+ %1 = memref.rank %arg1 : memref<?x?xf32>
+ ```
+ }];
+
+ let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
+ let results = (outs Index);
+
+ let verifier = ?;
+ let hasFolder = 1;
+ let assemblyFormat = "$memref attr-dict `:` type($memref)";
+}
+
+//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 23b9df282af0..2e50971db9e7 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -659,32 +659,6 @@ def ConstantOp : Std_Op<"constant",
}
//===----------------------------------------------------------------------===//
-// RankOp
-//===----------------------------------------------------------------------===//
-
-def RankOp : Std_Op<"rank", [NoSideEffect]> {
- let summary = "rank operation";
- let description = [{
- The `rank` operation takes a memref/tensor operand and returns its rank.
-
- Example:
-
- ```mlir
- %1 = rank %arg0 : tensor<*xf32>
- %2 = rank %arg1 : memref<*xf32>
- ```
- }];
-
- let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
- "any memref or tensor type">:$memrefOrTensor);
- let results = (outs Index);
- let verifier = ?;
-
- let hasFolder = 1;
- let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
-}
-
-//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3b1bfeeca6c1..21331fc649cd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -68,9 +68,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
let summary = "dimension index operation";
let description = [{
- The `dim` operation takes a tensor and a dimension operand of type `index`.
- It returns the size of the requested dimension of the given tensor.
- If the dimension index is out of bounds, the behavior is undefined.
+ The `tensor.dim` operation takes a tensor and a dimension operand of type
+ `index`. It returns the size of the requested dimension of the given
+ tensor. If the dimension index is out of bounds, the behavior is undefined.
The specified tensor type is that of the first operand.
@@ -559,6 +559,31 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
}
//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
+ let summary = "rank operation";
+ let description = [{
+ The `tensor.rank` operation takes a tensor operand and returns its rank.
+
+ Example:
+
+ ```mlir
+ %0 = tensor.rank %arg0 : tensor<*xf32>
+ %1 = tensor.rank %arg1 : tensor<?x?xf32>
+ ```
+ }];
+
+ let arguments = (ins AnyTensor:$tensor);
+ let results = (outs Index);
+
+ let verifier = ?;
+ let hasFolder = 1;
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
+//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 521b3fcab0c6..28981dd87ecc 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -596,6 +596,28 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
}
};
+struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
+ using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type operandType = op.memref().getType();
+ if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
+ UnrankedMemRefDescriptor desc(adaptor.memref());
+ rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
+ return success();
+ }
+ if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+ rewriter.replaceOp(
+ op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
+ return success();
+ }
+ return failure();
+ }
+};
+
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
@@ -1549,6 +1571,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
PrefetchOpLowering,
+ RankOpLowering,
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
StoreOpLowering,
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e1e24faa4d2a..5a1af7b33132 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -577,7 +577,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Lower to `tensor.generate` otherwise.
auto *ctx = rewriter.getContext();
- Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
+ Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 200834a2d1bc..f588521ac6ef 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -566,28 +566,6 @@ struct UnrealizedConversionCastOpLowering
}
};
-struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
- using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(RankOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Type operandType = op.getMemrefOrTensor().getType();
- if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
- UnrankedMemRefDescriptor desc(adaptor.getMemrefOrTensor());
- rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
- return success();
- }
- if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
- rewriter.replaceOp(
- op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
- return success();
- }
- return failure();
- }
-};
-
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
@@ -987,7 +965,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
CondBranchOpLowering,
ConstantOpLowering,
GenericAtomicRMWOpLowering,
- RankOpLowering,
ReturnOpLowering,
SelectOpLowering,
SplatOpLowering,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 4badc0b31ddb..1916ffe36dd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1073,6 +1073,19 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
}
//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+ // Constant fold rank when the rank of the operand is known.
+ auto type = getOperand().getType();
+ auto shapedType = type.dyn_cast<ShapedType>();
+ if (shapedType && shapedType.hasRank())
+ return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
+ return IntegerAttr();
+}
+
+//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1a43c0937d03..1d045b291215 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -900,20 +900,6 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
}
//===----------------------------------------------------------------------===//
-// RankOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
- // Constant fold rank when the rank of the operand is known.
- auto type = getOperand().getType();
- if (auto shapedType = type.dyn_cast<ShapedType>())
- if (shapedType.hasRank())
- return IntegerAttr::get(IndexType::get(getContext()),
- shapedType.getRank());
- return IntegerAttr();
-}
-
-//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index edddfb86e553..ecdd966a3c35 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -610,6 +610,19 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+ // Constant fold rank when the rank of the operand is known.
+ auto type = getOperand().getType();
+ auto shapedType = type.dyn_cast<ShapedType>();
+ if (shapedType && shapedType.hasRank())
+ return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
+ return IntegerAttr();
+}
+
+//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp
index 64a005dfb55b..27e00a14c0d4 100644
--- a/mlir/lib/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Transforms/BufferOptimizations.cpp
@@ -37,14 +37,16 @@ static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
if (!type || !alloc.getDefiningOp<memref::AllocOp>())
return false;
if (!type.hasStaticShape()) {
- // Check if the dynamic shape dimension of the alloc is produced by RankOp.
- // If this is the case, it is likely to be small. Furthermore, the dimension
- // is limited to the maximum rank of the allocated memref to avoid large
- // values by multiplying several small values.
+ // Check if the dynamic shape dimension of the alloc is produced by
+ // `memref.rank`. If this is the case, it is likely to be small.
+ // Furthermore, the dimension is limited to the maximum rank of the
+ // allocated memref to avoid large values by multiplying several small
+ // values.
if (type.getRank() <= maxRankOfAllocatedMemRef) {
- return llvm::all_of(
- alloc.getDefiningOp()->getOperands(),
- [&](Value operand) { return operand.getDefiningOp<RankOp>(); });
+ return llvm::all_of(alloc.getDefiningOp()->getOperands(),
+ [&](Value operand) {
+ return operand.getDefiningOp<memref::RankOp>();
+ });
}
return false;
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index a26638a34151..009106f95e8a 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1,7 +1,6 @@
// RUN: mlir-opt -convert-memref-to-llvm %s -split-input-file | FileCheck %s
// RUN: mlir-opt -convert-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
-
// CHECK-LABEL: func @view(
// CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
@@ -835,3 +834,24 @@ func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
+
+// -----
+
+// CHECK-LABEL: func @rank_of_unranked
+// CHECK32-LABEL: func @rank_of_unranked
+func @rank_of_unranked(%unranked: memref<*xi32>) {
+ %rank = memref.rank %unranked : memref<*xi32>
+ return
+}
+// CHECK: %[[UNRANKED_DESC:.*]] = builtin.unrealized_conversion_cast
+// CHECK-NEXT: llvm.extractvalue %[[UNRANKED_DESC]][0] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)>
+
+// CHECK-LABEL: func @rank_of_ranked
+// CHECK32-LABEL: func @rank_of_ranked
+func @rank_of_ranked(%ranked: memref<?xi32>) {
+ %rank = memref.rank %ranked : memref<?xi32>
+ return
+}
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK32: llvm.mlir.constant(1 : index) : i32
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 015cb2fcaaf4..ea0ef33862ce 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -203,7 +203,7 @@ func @shape_of(%arg : tensor<*xf32>) {
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
- // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+ // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32>
// CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] {
// CHECK: ^bb0(%[[I:.*]]: index):
// CHECK: %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
diff --git a/mlir/test/Conversion/StandardToLLVM/rank.mlir b/mlir/test/Conversion/StandardToLLVM/rank.mlir
deleted file mode 100644
index 7c0a03aa8df3..000000000000
--- a/mlir/test/Conversion/StandardToLLVM/rank.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
-// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
-
-// CHECK-LABEL: func @rank_of_unranked
-// CHECK32-LABEL: func @rank_of_unranked
-func @rank_of_unranked(%unranked: memref<*xi32>) {
- %rank = rank %unranked : memref<*xi32>
- return
-}
-// CHECK-NEXT: llvm.mlir.undef
-// CHECK-NEXT: llvm.insertvalue
-// CHECK-NEXT: llvm.insertvalue
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)>
-
-// CHECK-LABEL: func @rank_of_ranked
-// CHECK32-LABEL: func @rank_of_ranked
-func @rank_of_ranked(%ranked: memref<?xi32>) {
- %rank = rank %ranked : memref<?xi32>
- return
-}
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK32: llvm.mlir.constant(1 : index) : i32
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 251658fac765..80282c21afab 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -185,10 +185,10 @@ func @dim_of_alloca(%size: index) -> index {
// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
-// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+// CHECK-NEXT: %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32>
// CHECK-NEXT: return %[[RANK]] : index
func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
- %0 = rank %arg0 : memref<*xf32>
+ %0 = memref.rank %arg0 : memref<*xf32>
%1 = memref.alloca(%0) : memref<?xindex>
%c0 = arith.constant 0 : index
%2 = memref.dim %1, %c0 : memref<?xindex>
@@ -438,3 +438,15 @@ func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
// CHECK: %[[RESULT:.+]] = memref.subview
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+ // Fold a rank into a constant
+ // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
+ %rank_0 = memref.rank %arg0 : memref<?x?xf32>
+
+ // CHECK-NEXT: return [[C2]]
+ return %rank_0 : index
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 6014687e6e9d..55c5a821fb3d 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -844,3 +844,11 @@ func @test_alloc_memref_map_rank_mismatch() {
%0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
return
}
+
+// -----
+
+func @rank(%0: f32) {
+ // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}}
+ "memref.rank"(%0): (f32)->index
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index f716c5de2174..4ff2f8b5517b 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -207,3 +207,14 @@ func @collapse_shape_to_dynamic
// CHECK: func @collapse_shape_to_dynamic
// CHECK: memref.collapse_shape
// CHECK-SAME: [0], [1], [2, 3, 4]
+
+// -----
+
+func @rank(%t : memref<4x4x?xf32>) {
+ // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
+ %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index
+
+ // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
+ %1 = memref.rank %t : memref<4x4x?xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc9abe439b8a..ec9601e26993 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -183,7 +183,7 @@ func @extract_oob_from_tensor.from_elements(%element : index) -> index {
// CHECK-LABEL: func @extract_from_tensor.generate
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
- %size = rank %tensor : tensor<*xf32>
+ %size = tensor.rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
%0 = tensor.generate %size {
^bb0(%arg0: index):
@@ -200,7 +200,7 @@ func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index
// CHECK-LABEL: func @extract_from_tensor.generate_2d
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
- %size = rank %tensor : tensor<*xf32>
+ %size = tensor.rank %tensor : tensor<*xf32>
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
// CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
@@ -221,7 +221,7 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso
// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
// CHECK-SAME: %[[IDX:.*]]: index
func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
- %size = rank %tensor : tensor<*xf32>
+ %size = tensor.rank %tensor : tensor<*xf32>
// CHECK: %[[DTENSOR:.*]] = tensor.generate
%0 = tensor.generate %size {
^bb0(%arg0: index):
@@ -900,3 +900,18 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: tensor.expand_shape
// CHECK: return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @fold_rank
+func @fold_rank() -> (index) {
+ %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
+ : tensor<2x1x4xi32>
+
+ // Fold a rank into a constant
+ // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
+ %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>
+
+ // CHECK-NEXT: return [[C3]]
+ return %rank_0 : index
+}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 8b40ec80e02d..564526f16370 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -292,3 +292,11 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
: tensor<?x4x5xf32> into tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func @rank(%0: f32) {
+ // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
+ "tensor.rank"(%0): (f32)->index
+ return
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 63afc1f382b3..8d50d1518421 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -160,3 +160,14 @@ func @legal_collapsing_reshape_dynamic_tensor
// CHECK: func @legal_collapsing_reshape_dynamic_tensor
// CHECK: tensor.collapse_shape
// CHECK-SAME: [0], [1], [2, 3, 4]
+
+// -----
+
+func @rank(%t : tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
+ %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index
+
+ // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
+ %1 = tensor.rank %t : tensor<4x4x?xf32>
+ return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index fe2d7207d3d0..b83f530eeacc 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -99,12 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
%70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
- // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32>
- %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index
-
- // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32>
- %72 = rank %t : tensor<4x4x?xf32>
-
// CHECK: = constant unit
%73 = constant unit
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 13cfd16daf9a..49f29f09bf49 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1,13 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
-func @rank(f32) {
-^bb(%0: f32):
- "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any memref or tensor type}}
-
- return
-}
-
-// -----
func @affine_apply_no_map() {
^bb0:
%i = arith.constant 0 : index
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 5406a8588ce4..2e720eae3439 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -754,32 +754,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1,
// -----
-// CHECK-LABEL: func @fold_rank
-func @fold_rank() -> (index) {
- %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
-
- // Fold a rank into a constant
- // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
- %rank_0 = rank %const_0 : tensor<2x1x4xi32>
-
- // CHECK-NEXT: return [[C3]]
- return %rank_0 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @fold_rank_memref
-func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
- // Fold a rank into a constant
- // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
- %rank_0 = rank %arg0 : memref<?x?xf32>
-
- // CHECK-NEXT: return [[C2]]
- return %rank_0 : index
-}
-
-// -----
-
// CHECK-LABEL: func @nested_isolated_region
func @nested_isolated_region() {
// CHECK-NEXT: func @isolated_op
diff --git a/mlir/test/Transforms/promote-buffers-to-stack.mlir b/mlir/test/Transforms/promote-buffers-to-stack.mlir
index c78f8a71dbb7..2b6cd3185fa1 100644
--- a/mlir/test/Transforms/promote-buffers-to-stack.mlir
+++ b/mlir/test/Transforms/promote-buffers-to-stack.mlir
@@ -77,25 +77,25 @@ func @condBranchDynamicType(
// -----
// CHECK-LABEL: func @dynamicRanked
-func @dynamicRanked(%tensor: tensor<*xf32>) {
- %0 = rank %tensor : tensor<*xf32>
+func @dynamicRanked(%memref: memref<*xf32>) {
+ %0 = memref.rank %memref : memref<*xf32>
%1 = memref.alloc(%0) : memref<?xindex>
return
}
-// CHECK-NEXT: %[[RANK:.*]] = rank
+// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
// -----
// CHECK-LABEL: func @dynamicRanked2D
-func @dynamicRanked2D(%tensor: tensor<*xf32>) {
- %0 = rank %tensor : tensor<*xf32>
+func @dynamicRanked2D(%memref: memref<*xf32>) {
+ %0 = memref.rank %memref : memref<*xf32>
%1 = memref.alloc(%0, %0) : memref<?x?xindex>
return
}
-// CHECK-NEXT: %[[RANK:.*]] = rank
+// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
// RANK-NEXT: %[[ALLOC:.*]] = memref.alloca(%[[RANK]], %[[RANK]])
// DEFINDEX-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[RANK]], %[[RANK]])