[flang-commits] [flang] 3b390a1 - [flang][OpenMP] Support for Collapse
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Thu May 19 07:40:05 PDT 2022
Author: Mats Petersson
Date: 2022-05-19T15:39:48+01:00
New Revision: 3b390a1682232a0d6921692f72fac65ec4374597
URL: https://github.com/llvm/llvm-project/commit/3b390a1682232a0d6921692f72fac65ec4374597
DIFF: https://github.com/llvm/llvm-project/commit/3b390a1682232a0d6921692f72fac65ec4374597.diff
LOG: [flang][OpenMP] Support for Collapse
Convert Fortran parse-tree into MLIR for collapse-clause.
Includes simple Fortran to FIR test, with auto-generated
check-lines (some of which have been edited by hand).
Reviewed By: kiranchandramohan, shraiysh, peixin
Differential Revision: https://reviews.llvm.org/D125302
Added:
flang/test/Lower/OpenMP/omp-wsloop-collapse.f90
Modified:
flang/include/flang/Lower/OpenMP.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/OpenMP.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index df44c6752acc6..c666376c2030f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -13,10 +13,14 @@
#ifndef FORTRAN_LOWER_OPENMP_H
#define FORTRAN_LOWER_OPENMP_H
+#include <cinttypes>
+
namespace Fortran {
namespace parser {
struct OpenMPConstruct;
struct OpenMPDeclarativeConstruct;
+struct OmpEndLoopDirective;
+struct OmpClauseList;
} // namespace parser
namespace lower {
@@ -31,6 +35,7 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
const parser::OpenMPConstruct &);
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
const parser::OpenMPDeclarativeConstruct &);
+int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a5d456d82518b..e41f525e8217a 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1401,15 +1401,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
- Fortran::lower::genOpenMPConstruct(*this, getEval(), omp);
+ genOpenMPConstruct(*this, getEval(), omp);
+
+ const Fortran::parser::OpenMPLoopConstruct *ompLoop =
+ std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
// If loop is part of an OpenMP Construct then the OpenMP dialect
// workshare loop operation has already been created. Only the
// body needs to be created here and the do_loop can be skipped.
- Fortran::lower::pft::Evaluation *curEval =
- std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)
- ? &getEval().getFirstNestedEvaluation()
- : &getEval();
+ // Skip the number of collapsed loops, which is 1 when no collapse
+ // is requested.
+
+ Fortran::lower::pft::Evaluation *curEval = &getEval();
+ if (ompLoop) {
+ const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
+ std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
+ int64_t collapseValue =
+ Fortran::lower::getCollapseValue(wsLoopOpClauseList);
+
+ curEval = &curEval->getFirstNestedEvaluation();
+ for (int64_t i = 1; i < collapseValue; i++) {
+ curEval = &*std::next(curEval->getNestedEvaluations().begin());
+ }
+ }
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 06f546d453270..571cf5c85a87c 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -25,6 +25,18 @@
using namespace mlir;
+/// Return the loop count given by the COLLAPSE clause in \p clauseList, or 1
+/// when the list contains no COLLAPSE clause (a single loop is associated
+/// with the construct in that case).
+int64_t Fortran::lower::getCollapseValue(
+    const Fortran::parser::OmpClauseList &clauseList) {
+  for (const auto &clause : clauseList.v) {
+    if (const auto *collapseClause =
+            std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
+      // NOTE: assumes semantic analysis has already folded the COLLAPSE
+      // argument to an integer constant; both the GetExpr result and the
+      // optional returned by ToInt64 are used without a check here.
+      const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
+      return Fortran::evaluate::ToInt64(*expr).value();
+    }
+  }
+  return 1;
+}
+
static const Fortran::parser::Name *
getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
@@ -108,22 +120,42 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
}
}
+/// Create the body (block) for an OpenMP Operation.
+///
+/// \param [in] op - the operation the body belongs to.
+/// \param [inout] converter - converter to use for the clauses.
+/// \param [in] loc - location in source code.
+/// \param [in] clauses - list of clauses to process.
+/// \param [in] args - block arguments (induction variable[s]) for the
+/// region.
+/// \param [in] outerCombined - is this an outer operation - prevents
+/// privatization.
template <typename Op>
static void
createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
mlir::Location &loc,
const Fortran::parser::OmpClauseList *clauses = nullptr,
- const Fortran::semantics::Symbol *arg = nullptr,
+ const SmallVector<const Fortran::semantics::Symbol *> &args = {},
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- // If an argument for the region is provided then create the block with that
- // argument. Also update the symbol's address with the mlir argument value.
- // e.g. For loops the argument is the induction variable. And all further
+ // If arguments for the region are provided then create the block with those
+ // arguments. Also update the symbol's address with the mlir argument values.
+ // e.g. For loops the arguments are the induction variables. And all further
// uses of the induction variable should use this mlir value.
- if (arg) {
- firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)},
- {loc});
- converter.bindSymbol(*arg, op.getRegion().front().getArgument(0));
+ if (args.size()) {
+ SmallVector<Type> tiv;
+ SmallVector<Location> locs;
+ int argIndex = 0;
+ for (auto &arg : args) {
+ tiv.push_back(converter.genType(*arg));
+ locs.push_back(loc);
+ }
+ firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
+ for (auto &arg : args) {
+ fir::ExtendedValue exval = op.getRegion().front().getArgument(argIndex);
+ converter.bindSymbol(*arg, exval);
+ argIndex++;
+ }
} else {
firOpBuilder.createBlock(&op.getRegion());
}
@@ -394,38 +426,44 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
TODO(converter.getCurrentLocation(), "Combined worksharing loop construct");
}
- Fortran::lower::pft::Evaluation *doConstructEval =
- &eval.getFirstNestedEvaluation();
-
- Fortran::lower::pft::Evaluation *doLoop =
- &doConstructEval->getFirstNestedEvaluation();
- auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- const auto &loopControl =
- std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
- const Fortran::parser::LoopControl::Bounds *bounds =
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for worksharing do loop");
- Fortran::semantics::Symbol *iv = nullptr;
- Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
- if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
- } else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
- currentLocation, firOpBuilder.getIntegerType(32), 1));
- }
- iv = bounds->name.thing.symbol;
+ int64_t collapseValue = Fortran::lower::getCollapseValue(wsLoopOpClauseList);
+
+ // Collect the loops to collapse.
+ auto *doConstructEval = &eval.getFirstNestedEvaluation();
+
+ SmallVector<const Fortran::semantics::Symbol *> iv;
+ do {
+ auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
+ auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ const auto &loopControl =
+ std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+ const Fortran::parser::LoopControl::Bounds *bounds =
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for worksharing do loop");
+ Fortran::lower::StatementContext stmtCtx;
+ lowerBound.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
+ upperBound.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
+ if (bounds->step) {
+ step.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+ } else { // If `step` is not present, assume it as `1`.
+ step.push_back(firOpBuilder.createIntegerConstant(
+ currentLocation, firOpBuilder.getIntegerType(32), 1));
+ }
+ iv.push_back(bounds->name.thing.symbol);
+
+ collapseValue--;
+ doConstructEval =
+ &*std::next(doConstructEval->getNestedEvaluations().begin());
+ } while (collapseValue > 0);
// FIXME: Add support for following clauses:
// 1. linear
// 2. order
- // 3. collapse
- // 4. schedule (with chunk)
+ // 3. schedule (with chunk)
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
reductionVars, /*reductions=*/nullptr,
@@ -451,6 +489,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
} else {
wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0));
}
+ } else if (const auto &collapseClause =
+ std::get_if<Fortran::parser::OmpClause::Collapse>(
+ &clause.u)) {
+ const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
+ const std::optional<std::int64_t> collapseValue =
+ Fortran::evaluate::ToInt64(*expr);
+ wsLoopOp.collapse_valAttr(firOpBuilder.getI64IntegerAttr(*collapseValue));
} else if (const auto &scheduleClause =
std::get_if<Fortran::parser::OmpClause::Schedule>(
&clause.u)) {
diff --git a/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90 b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90
new file mode 100644
index 0000000000000..8197909a9dd43
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive(Worksharing) with collapse.
+
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+
+program wsloop_collapse
+ integer :: i, j, k
+ integer :: a, b, c
+ integer :: x
+! CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+! CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFEc"}
+! CHECK: %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+! CHECK: %[[VAL_4:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFEj"}
+! CHECK: %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFEk"}
+! CHECK: %[[VAL_6:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+ a=3
+! CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32
+! CHECK: fir.store %[[VAL_7]] to %[[VAL_0]] : !fir.ref<i32>
+ b=2
+! CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32
+! CHECK: fir.store %[[VAL_8]] to %[[VAL_1]] : !fir.ref<i32>
+ c=5
+! CHECK: %[[VAL_9:.*]] = arith.constant 5 : i32
+! CHECK: fir.store %[[VAL_9]] to %[[VAL_2]] : !fir.ref<i32>
+ x=0
+! CHECK: %[[VAL_10:.*]] = arith.constant 0 : i32
+! CHECK: fir.store %[[VAL_10]] to %[[VAL_6]] : !fir.ref<i32>
+
+ !$omp do collapse(3)
+! CHECK: %[[VAL_20:.*]] = arith.constant 1 : i32
+! CHECK: %[[VAL_21:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+! CHECK: %[[VAL_22:.*]] = arith.constant 1 : i32
+! CHECK: %[[VAL_23:.*]] = arith.constant 1 : i32
+! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+! CHECK: %[[VAL_25:.*]] = arith.constant 1 : i32
+! CHECK: %[[VAL_26:.*]] = arith.constant 1 : i32
+! CHECK: %[[VAL_27:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
+! CHECK: %[[VAL_28:.*]] = arith.constant 1 : i32
+ do i = 1, a
+ do j = 1, b
+ do k = 1, c
+! The collapsed wsloop's block arguments are the i, j and k induction
+! variables; they get their own capture names (IV_*) so they do not reuse
+! the VAL_* names bound to the constants above.
+! CHECK: omp.wsloop collapse(3) for (%[[IV_I:.*]], %[[IV_J:.*]], %[[IV_K:.*]]) : i32 = (%[[VAL_20]], %[[VAL_23]], %[[VAL_26]]) to (%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]) inclusive step (%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]) {
+! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_6]] : !fir.ref<i32>
+! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[IV_I]] : i32
+! CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[IV_J]] : i32
+! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[IV_K]] : i32
+! CHECK: fir.store %[[VAL_15]] to %[[VAL_6]] : !fir.ref<i32>
+! CHECK: omp.yield
+! CHECK: }
+ x = x + i + j + k
+ end do
+ end do
+ end do
+ !$omp end do
+! CHECK: return
+! CHECK: }
+end program wsloop_collapse
More information about the flang-commits
mailing list