diff options
author | Alexander Belyaev <pifon@google.com> | 2021-12-14 19:58:40 +0100 |
---|---|---|
committer | Alexander Belyaev <pifon@google.com> | 2021-12-14 20:04:57 +0100 |
commit | a82a19c137ad0b966847241c40546b3e145a17b5 (patch) | |
tree | 9bc7efced044b18b0bb50a7ed3a0957dfa4e6b33 | |
parent | 74d1fc742af0a5a766dbfa9f7a1a715301e05d3f (diff) | |
download | llvm-a82a19c137ad0b966847241c40546b3e145a17b5.tar.gz |
[mlir] Add a missing pattern to bufferize tensor.rank.
Differential Revision: https://reviews.llvm.org/D115745
-rw-r--r-- | mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp | 69 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/bufferize.mlir | 9 |
2 files changed, 41 insertions, 37 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp index d02328e4230d..0fd5b2d75d67 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -24,8 +24,7 @@ using namespace mlir; namespace { -class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { -public: +struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, @@ -36,11 +35,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { -public: +struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, @@ -50,11 +46,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { -public: +struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, @@ -64,10 +57,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeFromElementsOp +struct BufferizeFromElementsOp : public OpConversionPattern<tensor::FromElementsOp> { public: using OpConversionPattern::OpConversionPattern; @@ -88,11 +79,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { -public: +struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -150,44 +138,51 @@ public: return success(); } }; -} // namespace -void mlir::populateTensorBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, - BufferizeFromElementsOp, BufferizeGenerateOp>( - typeConverter, patterns.getContext()); -} +struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(), + adaptor.tensor()); + return success(); + } +}; -namespace { struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { void runOnFunction() override { auto *context = &getContext(); bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - bufferization::populateBufferizeMaterializationLegality(target); - populateTensorBufferizePatterns(typeConverter, patterns); - target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, - tensor::FromElementsOp, tensor::GenerateOp>(); - target.addLegalDialect<memref::MemRefDialect>(); + ConversionTarget target(*context); + target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>(); target.addDynamicallyLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>( [&](Operation *op) { return typeConverter.isLegal(op); }); - target.addLegalOp<CallOp>(); - target.addLegalOp<ReturnOp>(); - target.addLegalDialect<scf::SCFDialect>(); + target.addLegalOp<CallOp, ReturnOp>(); + target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, + tensor::FromElementsOp, tensor::GenerateOp>(); + bufferization::populateBufferizeMaterializationLegality(target); + RewritePatternSet patterns(context); + populateTensorBufferizePatterns(typeConverter, patterns); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; + } // namespace +void mlir::populateTensorBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, + BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>( + typeConverter, patterns.getContext()); +} + std::unique_ptr<Pass> mlir::createTensorBufferizePass() { return std::make_unique<TensorBufferizePass>(); } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 91642f06d0f2..5b3bb149d618 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index { return %0 : index } +// CHECK-LABEL: func @rank( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index { +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] +// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32> +func @rank(%arg0: tensor<*xf32>) -> index { + %0 = tensor.rank %arg0 : tensor<*xf32> + return %0 : index +} + // CHECK-LABEL: func @tensor.cast( // CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> { // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] |