[flang-commits] [flang] 3b390a1 - [flang][OpenMP] Support for Collapse

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Thu May 19 07:40:05 PDT 2022


Author: Mats Petersson
Date: 2022-05-19T15:39:48+01:00
New Revision: 3b390a1682232a0d6921692f72fac65ec4374597

URL: https://github.com/llvm/llvm-project/commit/3b390a1682232a0d6921692f72fac65ec4374597
DIFF: https://github.com/llvm/llvm-project/commit/3b390a1682232a0d6921692f72fac65ec4374597.diff

LOG: [flang][OpenMP] Support for Collapse

Convert the Fortran parse tree into MLIR for the collapse clause.

Includes a simple Fortran-to-FIR lowering test, with auto-generated
check lines (some of which have been edited by hand).

Reviewed By: kiranchandramohan, shraiysh, peixin

Differential Revision: https://reviews.llvm.org/D125302

Added: 
    flang/test/Lower/OpenMP/omp-wsloop-collapse.f90

Modified: 
    flang/include/flang/Lower/OpenMP.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index df44c6752acc6..c666376c2030f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -13,10 +13,14 @@
 #ifndef FORTRAN_LOWER_OPENMP_H
 #define FORTRAN_LOWER_OPENMP_H
 
+#include <cinttypes>
+
 namespace Fortran {
 namespace parser {
 struct OpenMPConstruct;
 struct OpenMPDeclarativeConstruct;
+struct OmpEndLoopDirective;
+struct OmpClauseList;
 } // namespace parser
 
 namespace lower {
@@ -31,6 +35,7 @@ void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
                         const parser::OpenMPConstruct &);
 void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
                                    const parser::OpenMPDeclarativeConstruct &);
+int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
 
 } // namespace lower
 } // namespace Fortran

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a5d456d82518b..e41f525e8217a 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1401,15 +1401,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
-    Fortran::lower::genOpenMPConstruct(*this, getEval(), omp);
+    genOpenMPConstruct(*this, getEval(), omp);
+
+    const Fortran::parser::OpenMPLoopConstruct *ompLoop =
+        std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
 
     // If loop is part of an OpenMP Construct then the OpenMP dialect
     // workshare loop operation has already been created. Only the
     // body needs to be created here and the do_loop can be skipped.
-    Fortran::lower::pft::Evaluation *curEval =
-        std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)
-            ? &getEval().getFirstNestedEvaluation()
-            : &getEval();
+    // Skip the number of collapsed loops, which is 1 when there is a
+    // no collapse requested.
+
+    Fortran::lower::pft::Evaluation *curEval = &getEval();
+    if (ompLoop) {
+      const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
+          std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
+      int64_t collapseValue =
+          Fortran::lower::getCollapseValue(wsLoopOpClauseList);
+
+      curEval = &curEval->getFirstNestedEvaluation();
+      for (int64_t i = 1; i < collapseValue; i++) {
+        curEval = &*std::next(curEval->getNestedEvaluations().begin());
+      }
+    }
 
     for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
       genFIR(e);

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 06f546d453270..571cf5c85a87c 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -25,6 +25,18 @@
 
 using namespace mlir;
 
+int64_t Fortran::lower::getCollapseValue( // Returns n from the first COLLAPSE(n) clause in clauseList, or 1 if none.
+    const Fortran::parser::OmpClauseList &clauseList) {
+  for (const auto &clause : clauseList.v) { // Clauses are scanned in order; the first COLLAPSE found wins.
+    if (const auto &collapseClause =
+            std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
+      const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); // NOTE(review): result is dereferenced without a null check.
+      return Fortran::evaluate::ToInt64(*expr).value(); // Assumes semantics folded the argument to a constant; an empty optional is not handled here.
+    }
+  }
+  return 1; // No COLLAPSE clause: callers treat this as a single, non-collapsed loop.
+}
+
 static const Fortran::parser::Name *
 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
@@ -108,22 +120,42 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
   }
 }
 
+/// Create the body (block) for an OpenMP Operation.
+///
+/// \param [in]    op - the operation the body belongs to.
+/// \param [inout] converter - converter to use for the clauses.
+/// \param [in]    loc - location in source code.
+/// \param [in]    clauses - list of clauses to process.
+/// \param [in]    args - block arguments (induction variable[s]) for the
+///                       region.
+/// \param [in]    outerCombined - is this an outer operation - prevents
+///                                privatization.
 template <typename Op>
 static void
 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
                mlir::Location &loc,
                const Fortran::parser::OmpClauseList *clauses = nullptr,
-               const Fortran::semantics::Symbol *arg = nullptr,
+               const SmallVector<const Fortran::semantics::Symbol *> &args = {},
                bool outerCombined = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  // If an argument for the region is provided then create the block with that
-  // argument. Also update the symbol's address with the mlir argument value.
-  // e.g. For loops the argument is the induction variable. And all further
+  // If arguments for the region are provided then create the block with those
+  // arguments. Also update the symbol's address with the mlir argument values.
+  // e.g. For loops the arguments are the induction variable. And all further
   // uses of the induction variable should use this mlir value.
-  if (arg) {
-    firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)},
-                             {loc});
-    converter.bindSymbol(*arg, op.getRegion().front().getArgument(0));
+  if (args.size()) {
+    SmallVector<Type> tiv;
+    SmallVector<Location> locs;
+    int argIndex = 0;
+    for (auto &arg : args) {
+      tiv.push_back(converter.genType(*arg));
+      locs.push_back(loc);
+    }
+    firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
+    for (auto &arg : args) {
+      fir::ExtendedValue exval = op.getRegion().front().getArgument(argIndex);
+      converter.bindSymbol(*arg, exval);
+      argIndex++;
+    }
   } else {
     firOpBuilder.createBlock(&op.getRegion());
   }
@@ -394,38 +426,44 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     TODO(converter.getCurrentLocation(), "Combined worksharing loop construct");
   }
 
-  Fortran::lower::pft::Evaluation *doConstructEval =
-      &eval.getFirstNestedEvaluation();
-
-  Fortran::lower::pft::Evaluation *doLoop =
-      &doConstructEval->getFirstNestedEvaluation();
-  auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
-  assert(doStmt && "Expected do loop to be in the nested evaluation");
-  const auto &loopControl =
-      std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
-  const Fortran::parser::LoopControl::Bounds *bounds =
-      std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
-  assert(bounds && "Expected bounds for worksharing do loop");
-  Fortran::semantics::Symbol *iv = nullptr;
-  Fortran::lower::StatementContext stmtCtx;
-  lowerBound.push_back(fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
-  upperBound.push_back(fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
-  if (bounds->step) {
-    step.push_back(fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
-  } else { // If `step` is not present, assume it as `1`.
-    step.push_back(firOpBuilder.createIntegerConstant(
-        currentLocation, firOpBuilder.getIntegerType(32), 1));
-  }
-  iv = bounds->name.thing.symbol;
+  int64_t collapseValue = Fortran::lower::getCollapseValue(wsLoopOpClauseList);
+
+  // Collect the loops to collapse.
+  auto *doConstructEval = &eval.getFirstNestedEvaluation();
+
+  SmallVector<const Fortran::semantics::Symbol *> iv;
+  do {
+    auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
+    auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+    assert(doStmt && "Expected do loop to be in the nested evaluation");
+    const auto &loopControl =
+        std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+    const Fortran::parser::LoopControl::Bounds *bounds =
+        std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+    assert(bounds && "Expected bounds for worksharing do loop");
+    Fortran::lower::StatementContext stmtCtx;
+    lowerBound.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
+    upperBound.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
+    if (bounds->step) {
+      step.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+    } else { // If `step` is not present, assume it as `1`.
+      step.push_back(firOpBuilder.createIntegerConstant(
+          currentLocation, firOpBuilder.getIntegerType(32), 1));
+    }
+    iv.push_back(bounds->name.thing.symbol);
+
+    collapseValue--;
+    doConstructEval =
+        &*std::next(doConstructEval->getNestedEvaluations().begin());
+  } while (collapseValue > 0);
 
   // FIXME: Add support for following clauses:
   // 1. linear
   // 2. order
-  // 3. collapse
-  // 4. schedule (with chunk)
+  // 3. schedule (with chunk)
   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
       currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
       reductionVars, /*reductions=*/nullptr,
@@ -451,6 +489,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
       } else {
         wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0));
       }
+    } else if (const auto &collapseClause =
+                   std::get_if<Fortran::parser::OmpClause::Collapse>(
+                       &clause.u)) {
+      const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
+      const std::optional<std::int64_t> collapseValue =
+          Fortran::evaluate::ToInt64(*expr);
+      wsLoopOp.collapse_valAttr(firOpBuilder.getI64IntegerAttr(*collapseValue));
     } else if (const auto &scheduleClause =
                    std::get_if<Fortran::parser::OmpClause::Schedule>(
                        &clause.u)) {

diff  --git a/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90 b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90
new file mode 100644
index 0000000000000..8197909a9dd43
--- /dev/null
+++ b/flang/test/Lower/OpenMP/omp-wsloop-collapse.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive(Worksharing) with collapse.
+
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+
+program wsloop_collapse ! one perfectly nested i/j/k loop nest under a collapse(3) worksharing loop
+  integer :: i, j, k
+  integer :: a, b, c
+  integer :: x
+! CHECK:         %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+! CHECK:         %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFEc"}
+! CHECK:         %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+! CHECK:         %[[VAL_4:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFEj"}
+! CHECK:         %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFEk"}
+! CHECK:         %[[VAL_6:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+  a=3
+! CHECK:         %[[VAL_7:.*]] = arith.constant 3 : i32
+! CHECK:         fir.store %[[VAL_7]] to %[[VAL_0]] : !fir.ref<i32>
+  b=2
+! CHECK:         %[[VAL_8:.*]] = arith.constant 2 : i32
+! CHECK:         fir.store %[[VAL_8]] to %[[VAL_1]] : !fir.ref<i32>
+  c=5
+! CHECK:         %[[VAL_9:.*]] = arith.constant 5 : i32
+! CHECK:         fir.store %[[VAL_9]] to %[[VAL_2]] : !fir.ref<i32>
+  x=0
+! CHECK:         %[[VAL_10:.*]] = arith.constant 0 : i32
+! CHECK:         fir.store %[[VAL_10]] to %[[VAL_6]] : !fir.ref<i32>
+
+  !$omp do collapse(3)
+! CHECK:           %[[VAL_20:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_21:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+! CHECK:           %[[VAL_22:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_23:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_24:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+! CHECK:           %[[VAL_25:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_26:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_27:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
+! CHECK:           %[[VAL_28:.*]] = arith.constant 1 : i32
+  do i = 1, a ! the loaded values of a, b and c become the three upper bounds of the single omp.wsloop
+     do j= 1, b
+        do k = 1, c
+! CHECK:           omp.wsloop collapse(3) for (%[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]]) : i32 = (%[[VAL_20]], %[[VAL_23]], %[[VAL_26]]) to (%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]) inclusive step (%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]) {
+! CHECK:             %[[VAL_12:.*]] = fir.load %[[VAL_6]] : !fir.ref<i32>
+! CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : i32
+! CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
+! CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : i32
+! CHECK:             fir.store %[[VAL_15]] to %[[VAL_6]] : !fir.ref<i32>
+! CHECK:             omp.yield
+! CHECK:           }
+           x = x + i + j + k ! uses all three induction variables, i.e. all three omp.wsloop block arguments
+        end do
+     end do
+  end do
+  !$omp end do
+! CHECK:         return
+! CHECK:       }
+end program wsloop_collapse


        


More information about the flang-commits mailing list