summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexander Belyaev <pifon@google.com>2021-12-14 19:58:40 +0100
committerAlexander Belyaev <pifon@google.com>2021-12-14 20:04:57 +0100
commita82a19c137ad0b966847241c40546b3e145a17b5 (patch)
tree9bc7efced044b18b0bb50a7ed3a0957dfa4e6b33
parent74d1fc742af0a5a766dbfa9f7a1a715301e05d3f (diff)
downloadllvm-a82a19c137ad0b966847241c40546b3e145a17b5.tar.gz
[mlir] Add a missing pattern to bufferize tensor.rank.
Differential Revision: https://reviews.llvm.org/D115745
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp69
-rw-r--r--mlir/test/Dialect/Tensor/bufferize.mlir9
2 files changed, 41 insertions, 37 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index d02328e4230d..0fd5b2d75d67 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -24,8 +24,7 @@
using namespace mlir;
namespace {
-class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-public:
+struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -36,11 +35,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-public:
+struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
@@ -50,11 +46,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-public:
+struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
@@ -64,10 +57,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeFromElementsOp
+struct BufferizeFromElementsOp
: public OpConversionPattern<tensor::FromElementsOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -88,11 +79,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-public:
+struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -150,44 +138,51 @@ public:
return success();
}
};
-} // namespace
-void mlir::populateTensorBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
- BufferizeFromElementsOp, BufferizeGenerateOp>(
- typeConverter, patterns.getContext());
-}
+struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
+ adaptor.tensor());
+ return success();
+ }
+};
-namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
bufferization::BufferizeTypeConverter typeConverter;
- RewritePatternSet patterns(context);
- ConversionTarget target(*context);
-
- bufferization::populateBufferizeMaterializationLegality(target);
- populateTensorBufferizePatterns(typeConverter, patterns);
- target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
- tensor::FromElementsOp, tensor::GenerateOp>();
- target.addLegalDialect<memref::MemRefDialect>();
+ ConversionTarget target(*context);
+ target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
StandardOpsDialect>(
[&](Operation *op) { return typeConverter.isLegal(op); });
- target.addLegalOp<CallOp>();
- target.addLegalOp<ReturnOp>();
- target.addLegalDialect<scf::SCFDialect>();
+ target.addLegalOp<CallOp, ReturnOp>();
+ target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+ tensor::FromElementsOp, tensor::GenerateOp>();
+ bufferization::populateBufferizeMaterializationLegality(target);
+ RewritePatternSet patterns(context);
+ populateTensorBufferizePatterns(typeConverter, patterns);
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
};
+
} // namespace
+void mlir::populateTensorBufferizePatterns(
+ bufferization::BufferizeTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
+ BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
+ typeConverter, patterns.getContext());
+}
+
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
return std::make_unique<TensorBufferizePass>();
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 91642f06d0f2..5b3bb149d618 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
return %0 : index
}
+// CHECK-LABEL: func @rank(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index {
+// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
+// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
+func @rank(%arg0: tensor<*xf32>) -> index {
+ %0 = tensor.rank %arg0 : tensor<*xf32>
+ return %0 : index
+}
+
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]