summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp')
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp69
1 files changed, 32 insertions, 37 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index d02328e4230d..0fd5b2d75d67 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -24,8 +24,7 @@
using namespace mlir;
namespace {
-class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-public:
+struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -36,11 +35,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-public:
+struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
@@ -50,11 +46,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-public:
+struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
@@ -64,10 +57,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeFromElementsOp
+struct BufferizeFromElementsOp
: public OpConversionPattern<tensor::FromElementsOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -88,11 +79,8 @@ public:
return success();
}
};
-} // namespace
-namespace {
-class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-public:
+struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -150,44 +138,51 @@ public:
return success();
}
};
-} // namespace
-void mlir::populateTensorBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
- BufferizeFromElementsOp, BufferizeGenerateOp>(
- typeConverter, patterns.getContext());
-}
+struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
+ adaptor.tensor());
+ return success();
+ }
+};
-namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
bufferization::BufferizeTypeConverter typeConverter;
- RewritePatternSet patterns(context);
- ConversionTarget target(*context);
-
- bufferization::populateBufferizeMaterializationLegality(target);
- populateTensorBufferizePatterns(typeConverter, patterns);
- target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
- tensor::FromElementsOp, tensor::GenerateOp>();
- target.addLegalDialect<memref::MemRefDialect>();
+ ConversionTarget target(*context);
+ target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
StandardOpsDialect>(
[&](Operation *op) { return typeConverter.isLegal(op); });
- target.addLegalOp<CallOp>();
- target.addLegalOp<ReturnOp>();
- target.addLegalDialect<scf::SCFDialect>();
+ target.addLegalOp<CallOp, ReturnOp>();
+ target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+ tensor::FromElementsOp, tensor::GenerateOp>();
+ bufferization::populateBufferizeMaterializationLegality(target);
+ RewritePatternSet patterns(context);
+ populateTensorBufferizePatterns(typeConverter, patterns);
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
};
+
} // namespace
+void mlir::populateTensorBufferizePatterns(
+ bufferization::BufferizeTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
+ BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
+ typeConverter, patterns.getContext());
+}
+
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
return std::make_unique<TensorBufferizePass>();
}