summaryrefslogtreecommitdiff
path: root/flang
diff options
context:
space:
mode:
authorJean Perier <jperier@nvidia.com>2023-05-09 09:21:09 +0200
committerJean Perier <jperier@nvidia.com>2023-05-09 09:21:27 +0200
commit54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25 (patch)
tree8ca864ee51019096d67928c1990c7fd87ce1b107 /flang
parentb87e65531c58df55cfae4c06c7a68f84539aa779 (diff)
downloadllvm-54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25.tar.gz
[flang][hlfir] Lower WHERE to HLFIR
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere operations. Differential Revision: https://reviews.llvm.org/D149950
Diffstat (limited to 'flang')
-rw-r--r--flang/lib/Lower/Bridge.cpp108
-rw-r--r--flang/test/Lower/HLFIR/where.f90170
2 files changed, 264 insertions, 14 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index fe86fe8cb2dd..acf3768dfdd8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3154,7 +3154,7 @@ private:
// Gather some information about the assignment that will impact how it is
// lowered.
const bool isWholeAllocatableAssignment =
- !userDefinedAssignment &&
+ !userDefinedAssignment && !isInsideHlfirWhere() &&
Fortran::lower::isWholeAllocatable(assign.lhs);
std::optional<Fortran::evaluate::DynamicType> lhsType =
assign.lhs.GetType();
@@ -3243,8 +3243,6 @@ private:
void genAssignment(const Fortran::evaluate::Assignment &assign) {
mlir::Location loc = toLocation();
if (lowerToHighLevelFIR()) {
- if (!implicitIterSpace.empty())
- TODO(loc, "HLFIR assignment inside WHERE");
std::visit(
Fortran::common::visitors{
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
@@ -3452,23 +3450,47 @@ private:
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
}
- bool isInsideHlfirForallOrWhere() const {
+ // Is the insertion point of the builder directly or indirectly set
+ // inside any operation of type "Op"?
+ template <typename... Op>
+ bool isInsideOp() const {
mlir::Block *block = builder->getInsertionBlock();
mlir::Operation *op = block ? block->getParentOp() : nullptr;
while (op) {
- if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
+ if (mlir::isa<Op...>(op))
return true;
op = op->getParentOp();
}
return false;
}
+ bool isInsideHlfirForallOrWhere() const {
+ return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
+ }
+ bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); }
void genFIR(const Fortran::parser::WhereConstruct &c) {
- implicitIterSpace.growStack();
+ mlir::Location loc = getCurrentLocation();
+ hlfir::WhereOp whereOp;
+
+ if (!lowerToHighLevelFIR()) {
+ implicitIterSpace.growStack();
+ } else {
+ whereOp = builder->create<hlfir::WhereOp>(loc);
+ builder->createBlock(&whereOp.getMaskRegion());
+ }
+
+ // Lower the where mask. For HLFIR, this is done in the hlfir.where mask
+ // region.
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
c.t));
+
+ // Lower WHERE body. For HLFIR, this is done in the hlfir.where body
+ // region.
+ if (whereOp)
+ builder->createBlock(&whereOp.getBody());
+
for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t))
genFIR(body);
@@ -3484,6 +3506,13 @@ private:
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
c.t));
+
+ if (whereOp) {
+ // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
+ // in the hlfir.where if it had no elsewhere.
+ builder->create<fir::FirEndOp>(loc);
+ builder->setInsertionPointAfter(whereOp);
+ }
}
void genFIR(const Fortran::parser::WhereBodyConstruct &body) {
std::visit(
@@ -3499,24 +3528,61 @@ private:
},
body.u);
}
+
+ /// Lower a Where or Elsewhere mask into an hlfir mask region.
+ void lowerWhereMaskToHlfir(mlir::Location loc,
+ const Fortran::semantics::SomeExpr *maskExpr) {
+ assert(maskExpr && "mask semantic analysis failed");
+ Fortran::lower::StatementContext maskContext;
+ hlfir::Entity mask = Fortran::lower::convertExprToHLFIR(
+ loc, *this, *maskExpr, localSymbols, maskContext);
+ mask = hlfir::loadTrivialScalar(loc, *builder, mask);
+ auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
+ genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
+ }
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR())
+ lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+ else
+ implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
+ mlir::Location loc = getCurrentLocation();
+ hlfir::ElseWhereOp elsewhereOp;
+ if (lowerToHighLevelFIR()) {
+ elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc);
+ // Lower mask in the mask region.
+ builder->createBlock(&elsewhereOp.getMaskRegion());
+ }
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
ew.t));
+
+ // For HLFIR, lower the body in the hlfir.elsewhere body region.
+ if (elsewhereOp)
+ builder->createBlock(&elsewhereOp.getBody());
+
for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t))
genFIR(body);
}
void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) {
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ const auto *maskExpr = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR())
+ lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+ else
+ implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) {
+ if (lowerToHighLevelFIR()) {
+ auto elsewhereOp =
+ builder->create<hlfir::ElseWhereOp>(getCurrentLocation());
+ builder->createBlock(&elsewhereOp.getBody());
+ }
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
ew.t));
@@ -3525,18 +3591,32 @@ private:
genFIR(body);
}
void genFIR(const Fortran::parser::ElsewhereStmt &stmt) {
- implicitIterSpace.append(nullptr);
+ if (!lowerToHighLevelFIR())
+ implicitIterSpace.append(nullptr);
}
void genFIR(const Fortran::parser::EndWhereStmt &) {
- implicitIterSpace.shrinkStack();
+ if (!lowerToHighLevelFIR())
+ implicitIterSpace.shrinkStack();
}
void genFIR(const Fortran::parser::WhereStmt &stmt) {
Fortran::lower::StatementContext stmtCtx;
const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t);
+ const auto *mask = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR()) {
+ mlir::Location loc = getCurrentLocation();
+ auto whereOp = builder->create<hlfir::WhereOp>(loc);
+ builder->createBlock(&whereOp.getMaskRegion());
+ lowerWhereMaskToHlfir(loc, mask);
+ builder->createBlock(&whereOp.getBody());
+ genAssignment(*assign.typedAssignment->v);
+ builder->create<fir::FirEndOp>(loc);
+ builder->setInsertionPointAfter(whereOp);
+ return;
+ }
implicitIterSpace.growStack();
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ implicitIterSpace.append(mask);
genAssignment(*assign.typedAssignment->v);
implicitIterSpace.shrinkStack();
}
diff --git a/flang/test/Lower/HLFIR/where.f90 b/flang/test/Lower/HLFIR/where.f90
new file mode 100644
index 000000000000..88e49c9d740a
--- /dev/null
+++ b/flang/test/Lower/HLFIR/where.f90
@@ -0,0 +1,170 @@
+! Test lowering of WHERE construct and statements to HLFIR.
+! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s
+
+module where_defs
+ logical :: mask(10)
+ real :: x(10), y(10)
+ real, allocatable :: a(:), b(:)
+ interface
+ function return_temporary_mask()
+ logical, allocatable :: return_temporary_mask(:)
+ end function
+ function return_temporary_array()
+ real, allocatable :: return_temporary_array(:)
+ end function
+ end interface
+end module
+
+subroutine simple_where()
+ use where_defs, only: mask, x, y
+ where (mask) x = y
+end subroutine
+! CHECK-LABEL: func.func @_QPsimple_where() {
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: return
+! CHECK:}
+
+subroutine where_construct()
+ use where_defs
+ where (mask)
+ x = y
+ a = b
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPwhere_construct() {
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"}
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"}
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: } to {
+! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: }
+! CHECK: }
+! CHECK: return
+! CHECK:}
+
+subroutine where_cleanup()
+ use where_defs, only: x, return_temporary_mask, return_temporary_array
+ where (return_temporary_mask()) x = return_temporary_array()
+end subroutine
+! CHECK-LABEL: func.func @_QPwhere_cleanup() {
+! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"}
+! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"}
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: hlfir.where {
+! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>)
+! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK: hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup {
+! CHECK: fir.freemem
+! CHECK: }
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup {
+! CHECK: fir.freemem
+! CHECK: }
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+
+subroutine simple_elsewhere()
+ use where_defs
+ where (mask)
+ x = y
+ elsewhere
+ y = x
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPsimple_elsewhere() {
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: }
+
+subroutine elsewhere_2(mask2)
+ use where_defs, only : mask, x, y
+ logical :: mask2(:)
+ where (mask)
+ x = y
+ elsewhere(mask2)
+ y = x
+ elsewhere
+ x = foo()
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPelsewhere_2(
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere mask {
+! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere do {
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32
+! CHECK: hlfir.yield %[[VAL_16]] : f32
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: }
+! CHECK: }