diff options
Diffstat (limited to 'mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp')
-rw-r--r-- | mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp | 69 |
1 files changed, 32 insertions, 37 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp index d02328e4230d..0fd5b2d75d67 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -24,8 +24,7 @@ using namespace mlir; namespace { -class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { -public: +struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, @@ -36,11 +35,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { -public: +struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, @@ -50,11 +46,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { -public: +struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, @@ -64,10 +57,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeFromElementsOp +struct BufferizeFromElementsOp : public OpConversionPattern<tensor::FromElementsOp> { public: using OpConversionPattern::OpConversionPattern; @@ -88,11 +79,8 @@ public: return success(); } }; -} // namespace -namespace { -class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { -public: +struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -150,44 +138,51 @@ public: return success(); } }; -} // namespace -void mlir::populateTensorBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, - BufferizeFromElementsOp, BufferizeGenerateOp>( - typeConverter, patterns.getContext()); -} +struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(), + adaptor.tensor()); + return success(); + } +}; -namespace { struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { void runOnFunction() override { auto *context = &getContext(); bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - bufferization::populateBufferizeMaterializationLegality(target); - populateTensorBufferizePatterns(typeConverter, patterns); - target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, - tensor::FromElementsOp, tensor::GenerateOp>(); - target.addLegalDialect<memref::MemRefDialect>(); + ConversionTarget target(*context); + target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>(); target.addDynamicallyLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>( [&](Operation *op) { return typeConverter.isLegal(op); }); - target.addLegalOp<CallOp>(); - target.addLegalOp<ReturnOp>(); - target.addLegalDialect<scf::SCFDialect>(); + target.addLegalOp<CallOp, ReturnOp>(); + target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, + tensor::FromElementsOp, tensor::GenerateOp>(); + bufferization::populateBufferizeMaterializationLegality(target); + RewritePatternSet patterns(context); + populateTensorBufferizePatterns(typeConverter, patterns); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; + } // namespace +void mlir::populateTensorBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, + BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>( + typeConverter, patterns.getContext()); +} + std::unique_ptr<Pass> mlir::createTensorBufferizePass() { return std::make_unique<TensorBufferizePass>(); } |