[flang-commits] [flang] 54c88fc - [flang][hlfir] Lower WHERE to HLFIR
Jean Perier via flang-commits
flang-commits at lists.llvm.org
Tue May 9 00:21:59 PDT 2023
Author: Jean Perier
Date: 2023-05-09T09:21:27+02:00
New Revision: 54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25
URL: https://github.com/llvm/llvm-project/commit/54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25
DIFF: https://github.com/llvm/llvm-project/commit/54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25.diff
LOG: [flang][hlfir] Lower WHERE to HLFIR
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere
operations.
Differential Revision: https://reviews.llvm.org/D149950
Added:
flang/test/Lower/HLFIR/where.f90
Modified:
flang/lib/Lower/Bridge.cpp
Removed:
################################################################################
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index fe86fe8cb2dd..acf3768dfdd8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3154,7 +3154,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Gather some information about the assignment that will impact how it is
// lowered.
const bool isWholeAllocatableAssignment =
- !userDefinedAssignment &&
+ !userDefinedAssignment && !isInsideHlfirWhere() &&
Fortran::lower::isWholeAllocatable(assign.lhs);
std::optional<Fortran::evaluate::DynamicType> lhsType =
assign.lhs.GetType();
@@ -3243,8 +3243,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genAssignment(const Fortran::evaluate::Assignment &assign) {
mlir::Location loc = toLocation();
if (lowerToHighLevelFIR()) {
- if (!implicitIterSpace.empty())
- TODO(loc, "HLFIR assignment inside WHERE");
std::visit(
Fortran::common::visitors{
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
@@ -3452,23 +3450,47 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
}
- bool isInsideHlfirForallOrWhere() const {
+ // Is the insertion point of the builder directly or indirectly set
+ // inside an operation whose type is one of the types in the "Op" pack?
+ template <typename... Op>
+ bool isInsideOp() const {
mlir::Block *block = builder->getInsertionBlock();
mlir::Operation *op = block ? block->getParentOp() : nullptr;
while (op) {
- if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
+ if (mlir::isa<Op...>(op))
return true;
op = op->getParentOp();
}
return false;
}
+ bool isInsideHlfirForallOrWhere() const {
+ return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
+ }
+ bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); } // assignments inside hlfir.where are not lowered as whole allocatable assignments
void genFIR(const Fortran::parser::WhereConstruct &c) {
- implicitIterSpace.growStack();
+ mlir::Location loc = getCurrentLocation();
+ hlfir::WhereOp whereOp;
+
+ if (!lowerToHighLevelFIR()) {
+ implicitIterSpace.growStack();
+ } else {
+ whereOp = builder->create<hlfir::WhereOp>(loc);
+ builder->createBlock(&whereOp.getMaskRegion());
+ }
+
+ // Lower the where mask. For HLFIR, this is done in the hlfir.where mask
+ // region.
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
c.t));
+
+ // Lower WHERE body. For HLFIR, this is done in the hlfir.where body
+ // region.
+ if (whereOp)
+ builder->createBlock(&whereOp.getBody());
+
for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t))
genFIR(body);
@@ -3484,6 +3506,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
c.t));
+
+ if (whereOp) {
+ // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
+ // in the hlfir.where if it had no elsewhere.
+ builder->create<fir::FirEndOp>(loc);
+ builder->setInsertionPointAfter(whereOp);
+ }
}
void genFIR(const Fortran::parser::WhereBodyConstruct &body) {
std::visit(
@@ -3499,24 +3528,61 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
body.u);
}
+
+ /// Lower a WHERE/ELSEWHERE mask into the mask region at the insertion point: yield it and emit its cleanups in the yield cleanup region.
+ void lowerWhereMaskToHlfir(mlir::Location loc,
+ const Fortran::semantics::SomeExpr *maskExpr) {
+ assert(maskExpr && "mask semantic analysis failed");
+ Fortran::lower::StatementContext maskContext; // collects cleanups (e.g. freeing a function-result temporary) created while evaluating the mask
+ hlfir::Entity mask = Fortran::lower::convertExprToHLFIR(
+ loc, *this, *maskExpr, localSymbols, maskContext);
+ mask = hlfir::loadTrivialScalar(loc, *builder, mask); // presumably only loads trivial scalar masks; array masks are yielded unchanged in the tests
+ auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
+ genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext); // cleanups go in the hlfir.yield cleanup region, not after the where
+ }
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR()) // the WhereConstruct lowering already set the insertion point in the hlfir.where mask region
+ lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+ else // non-HLFIR lowering records the mask on the implicit iteration space
+ implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
+ mlir::Location loc = getCurrentLocation();
+ hlfir::ElseWhereOp elsewhereOp; // stays null unless lowering to HLFIR
+ if (lowerToHighLevelFIR()) {
+ elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc);
+ // Lower the MaskedElsewhereStmt mask in the hlfir.elsewhere mask region.
+ builder->createBlock(&elsewhereOp.getMaskRegion());
+ }
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
ew.t));
+
+ // For HLFIR, lower the body in the hlfir.elsewhere body region.
+ if (elsewhereOp)
+ builder->createBlock(&elsewhereOp.getBody());
+
for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t))
genFIR(body);
}
void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) {
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ const auto *maskExpr = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR()) // the MaskedElsewhere lowering already set the insertion point in the hlfir.elsewhere mask region
+ lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+ else // non-HLFIR lowering records the mask on the implicit iteration space
+ implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) {
+ if (lowerToHighLevelFIR()) {
+ auto elsewhereOp =
+ builder->create<hlfir::ElseWhereOp>(getCurrentLocation());
+ builder->createBlock(&elsewhereOp.getBody());
+ }
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
ew.t));
@@ -3525,18 +3591,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIR(body);
}
void genFIR(const Fortran::parser::ElsewhereStmt &stmt) {
- implicitIterSpace.append(nullptr);
+ if (!lowerToHighLevelFIR()) // HLFIR: the unmasked hlfir.elsewhere is created by the Elsewhere construct lowering
+ implicitIterSpace.append(nullptr);
}
void genFIR(const Fortran::parser::EndWhereStmt &) {
- implicitIterSpace.shrinkStack();
+ if (!lowerToHighLevelFIR()) // HLFIR: the WhereConstruct lowering emits fir.end and moves past the hlfir.where
+ implicitIterSpace.shrinkStack();
}
void genFIR(const Fortran::parser::WhereStmt &stmt) {
Fortran::lower::StatementContext stmtCtx;
const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t);
+ const auto *mask = Fortran::semantics::GetExpr(
+ std::get<Fortran::parser::LogicalExpr>(stmt.t));
+ if (lowerToHighLevelFIR()) { // same shape as a WHERE construct with a single assignment and no ELSEWHERE
+ mlir::Location loc = getCurrentLocation();
+ auto whereOp = builder->create<hlfir::WhereOp>(loc);
+ builder->createBlock(&whereOp.getMaskRegion());
+ lowerWhereMaskToHlfir(loc, mask);
+ builder->createBlock(&whereOp.getBody());
+ genAssignment(*assign.typedAssignment->v);
+ builder->create<fir::FirEndOp>(loc); // terminate the hlfir.where body region
+ builder->setInsertionPointAfter(whereOp); // continue lowering after the hlfir.where
+ return;
+ }
implicitIterSpace.growStack();
- implicitIterSpace.append(Fortran::semantics::GetExpr(
- std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+ implicitIterSpace.append(mask);
genAssignment(*assign.typedAssignment->v);
implicitIterSpace.shrinkStack();
}
diff --git a/flang/test/Lower/HLFIR/where.f90 b/flang/test/Lower/HLFIR/where.f90
new file mode 100644
index 000000000000..88e49c9d740a
--- /dev/null
+++ b/flang/test/Lower/HLFIR/where.f90
@@ -0,0 +1,170 @@
+! Test lowering of WHERE constructs and WHERE statements to HLFIR.
+! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s
+
+module where_defs ! shared masks, arrays, and temporary-returning functions used by the tests below
+ logical :: mask(10)
+ real :: x(10), y(10)
+ real, allocatable :: a(:), b(:)
+ interface
+ function return_temporary_mask()
+ logical, allocatable :: return_temporary_mask(:)
+ end function
+ function return_temporary_array()
+ real, allocatable :: return_temporary_array(:)
+ end function
+ end interface
+end module
+
+subroutine simple_where() ! WHERE statement with a single masked assignment
+ use where_defs, only: mask, x, y
+ where (mask) x = y
+end subroutine
+! CHECK-LABEL: func.func @_QPsimple_where() {
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: return
+! CHECK:}
+
+subroutine where_construct() ! WHERE construct with several assignments, one of them to an allocatable
+ use where_defs
+ where (mask)
+ x = y
+ a = b
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPwhere_construct() {
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"}
+! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"}
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: } to {
+! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: }
+! CHECK: }
+! CHECK: return
+! CHECK:}
+
+subroutine where_cleanup() ! mask and right-hand side are function-result temporaries that need cleanup
+ use where_defs, only: x, return_temporary_mask, return_temporary_array
+ where (return_temporary_mask()) x = return_temporary_array()
+end subroutine
+! CHECK-LABEL: func.func @_QPwhere_cleanup() {
+! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"}
+! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"}
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: hlfir.where {
+! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>)
+! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK: hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup {
+! CHECK: fir.freemem
+! CHECK: }
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup {
+! CHECK: fir.freemem
+! CHECK: }
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+
+subroutine simple_elsewhere() ! WHERE construct with an unmasked ELSEWHERE
+ use where_defs
+ where (mask)
+ x = y
+ elsewhere
+ y = x
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPsimple_elsewhere() {
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: }
+
+subroutine elsewhere_2(mask2) ! masked ELSEWHERE followed by an unmasked ELSEWHERE
+ use where_defs, only : mask, x, y
+ logical :: mask2(:)
+ where (mask)
+ x = y
+ elsewhere(mask2)
+ y = x
+ elsewhere
+ x = foo()
+ end where
+end subroutine
+! CHECK-LABEL: func.func @_QPelsewhere_2(
+! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2
+! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK: hlfir.where {
+! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere mask {
+! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>>
+! CHECK: } do {
+! CHECK: hlfir.region_assign {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: hlfir.elsewhere do {
+! CHECK: hlfir.region_assign {
+! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32
+! CHECK: hlfir.yield %[[VAL_16]] : f32
+! CHECK: } to {
+! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK: }
+! CHECK: }
+! CHECK: }
+! CHECK: }
More information about the flang-commits
mailing list