diff options
Diffstat (limited to 'flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp')
-rw-r--r-- | flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index fbf7670413df..1cf3929c1c04 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -213,6 +213,38 @@ using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>; using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>; +struct AnyOpConversion : public HlfirIntrinsicConversion<hlfir::AnyOp> { + using HlfirIntrinsicConversion<hlfir::AnyOp>::HlfirIntrinsicConversion; + + mlir::LogicalResult + matchAndRewrite(hlfir::AnyOp any, + mlir::PatternRewriter &rewriter) const override { + fir::KindMapping kindMapping{rewriter.getContext()}; + fir::FirOpBuilder builder{rewriter, kindMapping}; + const mlir::Location &loc = any->getLoc(); + + mlir::Type i32 = builder.getI32Type(); + mlir::Type logicalType = fir::LogicalType::get( + builder.getContext(), builder.getKindMap().defaultLogicalKind()); + llvm::SmallVector<IntrinsicArgument, 2> inArgs; + inArgs.push_back({any.getMask(), logicalType}); + inArgs.push_back({any.getDim(), i32}); + + auto *argLowering = fir::getIntrinsicArgumentLowering("any"); + llvm::SmallVector<fir::ExtendedValue, 2> args = + this->lowerArguments(any, inArgs, rewriter, argLowering); + + mlir::Type resultType = hlfir::getFortranElementType(any.getType()); + + auto [resultExv, mustBeFreed] = + fir::genIntrinsicCall(builder, loc, "any", resultType, args); + + this->processReturnValue(any, resultExv, mustBeFreed, builder, rewriter); + + return mlir::success(); + } +}; + struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> { using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion; @@ -321,16 +353,15 @@ public: mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns - .insert<MatmulOpConversion, MatmulTransposeOpConversion, - SumOpConversion, ProductOpConversion, TransposeOpConversion>( - context); + patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion, + AnyOpConversion, SumOpConversion, ProductOpConversion, + TransposeOpConversion>(context); mlir::ConversionTarget target(*context); target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect, mlir::func::FuncDialect, fir::FIROpsDialect, hlfir::hlfirDialect>(); target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp, - hlfir::ProductOp, hlfir::TransposeOp>(); + hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp>(); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); if (mlir::failed( |