diff options
author | Jean Perier <jperier@nvidia.com> | 2023-05-09 09:21:09 +0200 |
---|---|---|
committer | Jean Perier <jperier@nvidia.com> | 2023-05-09 09:21:27 +0200 |
commit | 54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25 (patch) | |
tree | 8ca864ee51019096d67928c1990c7fd87ce1b107 /flang | |
parent | b87e65531c58df55cfae4c06c7a68f84539aa779 (diff) | |
download | llvm-54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25.tar.gz |
[flang][hlfir] Lower WHERE to HLFIR
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere
operations.
Differential Revision: https://reviews.llvm.org/D149950
Diffstat (limited to 'flang')
-rw-r--r-- | flang/lib/Lower/Bridge.cpp | 108 | ||||
-rw-r--r-- | flang/test/Lower/HLFIR/where.f90 | 170 |
2 files changed, 264 insertions, 14 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index fe86fe8cb2dd..acf3768dfdd8 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3154,7 +3154,7 @@ private: // Gather some information about the assignment that will impact how it is // lowered. const bool isWholeAllocatableAssignment = - !userDefinedAssignment && + !userDefinedAssignment && !isInsideHlfirWhere() && Fortran::lower::isWholeAllocatable(assign.lhs); std::optional<Fortran::evaluate::DynamicType> lhsType = assign.lhs.GetType(); @@ -3243,8 +3243,6 @@ private: void genAssignment(const Fortran::evaluate::Assignment &assign) { mlir::Location loc = toLocation(); if (lowerToHighLevelFIR()) { - if (!implicitIterSpace.empty()) - TODO(loc, "HLFIR assignment inside WHERE"); std::visit( Fortran::common::visitors{ [&](const Fortran::evaluate::Assignment::Intrinsic &) { @@ -3452,23 +3450,47 @@ private: Fortran::lower::createArrayMergeStores(*this, explicitIterSpace); } - bool isInsideHlfirForallOrWhere() const { + // Is the insertion point of the builder directly or indirectly set + // inside any operation of type "Op"? + template <typename... Op> + bool isInsideOp() const { mlir::Block *block = builder->getInsertionBlock(); mlir::Operation *op = block ? block->getParentOp() : nullptr; while (op) { - if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op)) + if (mlir::isa<Op...>(op)) return true; op = op->getParentOp(); } return false; } + bool isInsideHlfirForallOrWhere() const { + return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>(); + } + bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); } void genFIR(const Fortran::parser::WhereConstruct &c) { - implicitIterSpace.growStack(); + mlir::Location loc = getCurrentLocation(); + hlfir::WhereOp whereOp; + + if (!lowerToHighLevelFIR()) { + implicitIterSpace.growStack(); + } else { + whereOp = builder->create<hlfir::WhereOp>(loc); + builder->createBlock(&whereOp.getMaskRegion()); + } + + // Lower the where mask. For HLFIR, this is done in the hlfir.where mask + // region. genNestedStatement( std::get< Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>( c.t)); + + // Lower WHERE body. For HLFIR, this is done in the hlfir.where body + // region. + if (whereOp) + builder->createBlock(&whereOp.getBody()); + for (const auto &body : std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t)) genFIR(body); @@ -3484,6 +3506,13 @@ private: genNestedStatement( std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>( c.t)); + + if (whereOp) { + // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or + // in the hlfir.where if it had no elsewhere. + builder->create<fir::FirEndOp>(loc); + builder->setInsertionPointAfter(whereOp); + } } void genFIR(const Fortran::parser::WhereBodyConstruct &body) { std::visit( @@ -3499,24 +3528,61 @@ private: }, body.u); } + + /// Lower a Where or Elsewhere mask into an hlfir mask region. + void lowerWhereMaskToHlfir(mlir::Location loc, + const Fortran::semantics::SomeExpr *maskExpr) { + assert(maskExpr && "mask semantic analysis failed"); + Fortran::lower::StatementContext maskContext; + hlfir::Entity mask = Fortran::lower::convertExprToHLFIR( + loc, *this, *maskExpr, localSymbols, maskContext); + mask = hlfir::loadTrivialScalar(loc, *builder, mask); + auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask); + genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext); + } void genFIR(const Fortran::parser::WhereConstructStmt &stmt) { - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get<Fortran::parser::LogicalExpr>(stmt.t))); + const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr( + std::get<Fortran::parser::LogicalExpr>(stmt.t)); + if (lowerToHighLevelFIR()) + lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr); + else + implicitIterSpace.append(maskExpr); } void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) { + mlir::Location loc = getCurrentLocation(); + hlfir::ElseWhereOp elsewhereOp; + if (lowerToHighLevelFIR()) { + elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc); + // Lower mask in the mask region. + builder->createBlock(&elsewhereOp.getMaskRegion()); + } genNestedStatement( std::get< Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>( ew.t)); + + // For HLFIR, lower the body in the hlfir.elsewhere body region. + if (elsewhereOp) + builder->createBlock(&elsewhereOp.getBody()); + for (const auto &body : std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t)) genFIR(body); } void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) { - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get<Fortran::parser::LogicalExpr>(stmt.t))); + const auto *maskExpr = Fortran::semantics::GetExpr( + std::get<Fortran::parser::LogicalExpr>(stmt.t)); + if (lowerToHighLevelFIR()) + lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr); + else + implicitIterSpace.append(maskExpr); } void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) { + if (lowerToHighLevelFIR()) { + auto elsewhereOp = + builder->create<hlfir::ElseWhereOp>(getCurrentLocation()); + builder->createBlock(&elsewhereOp.getBody()); + } genNestedStatement( std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>( ew.t)); @@ -3525,18 +3591,32 @@ private: genFIR(body); } void genFIR(const Fortran::parser::ElsewhereStmt &stmt) { - implicitIterSpace.append(nullptr); + if (!lowerToHighLevelFIR()) + implicitIterSpace.append(nullptr); } void genFIR(const Fortran::parser::EndWhereStmt &) { - implicitIterSpace.shrinkStack(); + if (!lowerToHighLevelFIR()) + implicitIterSpace.shrinkStack(); } void genFIR(const Fortran::parser::WhereStmt &stmt) { Fortran::lower::StatementContext stmtCtx; const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t); + const auto *mask = Fortran::semantics::GetExpr( + std::get<Fortran::parser::LogicalExpr>(stmt.t)); + if (lowerToHighLevelFIR()) { + mlir::Location loc = getCurrentLocation(); + auto whereOp = builder->create<hlfir::WhereOp>(loc); + builder->createBlock(&whereOp.getMaskRegion()); + lowerWhereMaskToHlfir(loc, mask); + builder->createBlock(&whereOp.getBody()); + genAssignment(*assign.typedAssignment->v); + builder->create<fir::FirEndOp>(loc); + builder->setInsertionPointAfter(whereOp); + return; + } implicitIterSpace.growStack(); - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get<Fortran::parser::LogicalExpr>(stmt.t))); + implicitIterSpace.append(mask); genAssignment(*assign.typedAssignment->v); implicitIterSpace.shrinkStack(); } diff --git a/flang/test/Lower/HLFIR/where.f90 b/flang/test/Lower/HLFIR/where.f90 new file mode 100644 index 000000000000..88e49c9d740a --- /dev/null +++ b/flang/test/Lower/HLFIR/where.f90 @@ -0,0 +1,170 @@ +! Test lowering of WHERE construct and statements to HLFIR. +! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s + +module where_defs + logical :: mask(10) + real :: x(10), y(10) + real, allocatable :: a(:), b(:) + interface + function return_temporary_mask() + logical, allocatable :: return_temporary_mask(:) + end function + function return_temporary_array() + real, allocatable :: return_temporary_array(:) + end function + end interface +end module + +subroutine simple_where() + use where_defs, only: mask, x, y + where (mask) x = y +end subroutine +! CHECK-LABEL: func.func @_QPsimple_where() { +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: } +! CHECK: return +! CHECK:} + +subroutine where_construct() + use where_defs + where (mask) + x = y + a = b + end where +end subroutine +! CHECK-LABEL: func.func @_QPwhere_construct() { +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"} +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"} +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> +! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> +! CHECK: } to { +! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> +! CHECK: hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>> +! CHECK: } +! CHECK: } +! CHECK: return +! CHECK:} + +subroutine where_cleanup() + use where_defs, only: x, return_temporary_mask, return_temporary_array + where (return_temporary_mask()) x = return_temporary_array() +end subroutine +! CHECK-LABEL: func.func @_QPwhere_cleanup() { +! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"} +! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"} +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: hlfir.where { +! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> +! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>> +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) +! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>> +! CHECK: hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup { +! CHECK: fir.freemem +! CHECK: } +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>> +! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) +! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> +! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup { +! CHECK: fir.freemem +! CHECK: } +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: } + +subroutine simple_elsewhere() + use where_defs + where (mask) + x = y + elsewhere + y = x + end where +end subroutine +! CHECK-LABEL: func.func @_QPsimple_elsewhere() { +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: hlfir.elsewhere do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: } +! CHECK: } + +subroutine elsewhere_2(mask2) + use where_defs, only : mask, x, y + logical :: mask2(:) + where (mask) + x = y + elsewhere(mask2) + y = x + elsewhere + x = foo() + end where +end subroutine +! CHECK-LABEL: func.func @_QPelsewhere_2( +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2 +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: hlfir.elsewhere mask { +! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: hlfir.elsewhere do { +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32 +! CHECK: hlfir.yield %[[VAL_16]] : f32 +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>> +! CHECK: } +! CHECK: } +! CHECK: } +! CHECK: } |