[flang-commits] [flang] 9bf5093 - [flang][openmp] Parallel reduction FIR lowering

Ethan Luis McDonough via flang-commits flang-commits at lists.llvm.org
Fri Jun 30 13:36:34 PDT 2023


Author: Ethan Luis McDonough
Date: 2023-06-30T15:36:27-05:00
New Revision: 9bf50936237ded3fc324f4686dc4c1c5a9fb16eb

URL: https://github.com/llvm/llvm-project/commit/9bf50936237ded3fc324f4686dc4c1c5a9fb16eb
DIFF: https://github.com/llvm/llvm-project/commit/9bf50936237ded3fc324f4686dc4c1c5a9fb16eb.diff

LOG: [flang][openmp] Parallel reduction FIR lowering

This patch extends the FIR lowering of OpenMP reduction clauses, previously supported only on loop constructs, to the `parallel` block construct.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D154182

Added: 
    flang/test/Lower/OpenMP/parallel-reduction-add.f90

Modified: 
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/OpenMP.cpp

Removed: 
    flang/test/Lower/OpenMP/Todo/parallel-reduction.f90


################################################################################
diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 087d32032278c3..3fee12b8b6e0ee 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2172,6 +2172,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     const Fortran::parser::OpenMPLoopConstruct *ompLoop =
         std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
+    const Fortran::parser::OpenMPBlockConstruct *ompBlock =
+        std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
 
     // If loop is part of an OpenMP Construct then the OpenMP dialect
     // workshare loop operation has already been created. Only the
@@ -2196,8 +2198,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
       genFIR(e);
 
-    if (ompLoop)
+    if (ompLoop) {
       genOpenMPReduction(*this, *loopOpClauseList);
+    } else if (ompBlock) {
+      const auto &blockStart =
+          std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
+      const auto &blockClauses =
+          std::get<Fortran::parser::OmpClauseList>(blockStart.t);
+      genOpenMPReduction(*this, blockClauses);
+    }
 
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 67b88d2033aa25..61eb1cc7d2a957 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1154,209 +1154,6 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
                                   /*isCombined=*/true);
 }
 
-static void
-genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
-       const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
-  const auto &beginBlockDirective =
-      std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
-  const auto &blockDirective =
-      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
-  const auto &endBlockDirective =
-      std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.genLocation(blockDirective.source);
-
-  Fortran::lower::StatementContext stmtCtx;
-  llvm::ArrayRef<mlir::Type> argTy;
-  mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
-      priorityClauseOperand;
-  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
-  SmallVector<Value> allocateOperands, allocatorOperands, dependOperands;
-  SmallVector<Attribute> dependTypeOperands;
-  mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
-
-  const auto &opClauseList =
-      std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
-  for (const auto &clause : opClauseList.v) {
-    mlir::Location clauseLocation = converter.genLocation(clause.source);
-    if (const auto &ifClause =
-            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-      ifClauseOperand =
-          getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
-    } else if (const auto &numThreadsClause =
-                   std::get_if<Fortran::parser::OmpClause::NumThreads>(
-                       &clause.u)) {
-      // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
-      numThreadsClauseOperand = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
-    } else if (const auto &procBindClause =
-                   std::get_if<Fortran::parser::OmpClause::ProcBind>(
-                       &clause.u)) {
-      procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
-    } else if (const auto &allocateClause =
-                   std::get_if<Fortran::parser::OmpClause::Allocate>(
-                       &clause.u)) {
-      genAllocateClause(converter, allocateClause->v, allocatorOperands,
-                        allocateOperands);
-    } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
-               std::get_if<Fortran::parser::OmpClause::Firstprivate>(
-                   &clause.u) ||
-               std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
-      // Privatisation and copyin clauses are handled elsewhere.
-      continue;
-    } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
-      // Shared is the default behavior in the IR, so no handling is required.
-      continue;
-    } else if (const auto &defaultClause =
-                   std::get_if<Fortran::parser::OmpClause::Default>(
-                       &clause.u)) {
-      if ((defaultClause->v.v ==
-           Fortran::parser::OmpDefaultClause::Type::Shared) ||
-          (defaultClause->v.v ==
-           Fortran::parser::OmpDefaultClause::Type::None)) {
-        // Default clause with shared or none do not require any handling since
-        // Shared is the default behavior in the IR and None is only required
-        // for semantic checks.
-        continue;
-      }
-    } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
-      // Nothing needs to be done for threads clause.
-      continue;
-    } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
-      // Map clause is exclusive to Target Data directives. It is handled
-      // as part of the TargetOp creation.
-      continue;
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
-                   &clause.u)) {
-      // UseDevicePtr clause is exclusive to Target Data directives. It is
-      // handled as part of the TargetOp creation.
-      continue;
-    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
-                   &clause.u)) {
-      // UseDeviceAddr clause is exclusive to Target Data directives. It is
-      // handled as part of the TargetOp creation.
-      continue;
-    } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
-                   &clause.u)) {
-      // Handled as part of TargetOp creation.
-      continue;
-    } else if (const auto &finalClause =
-                   std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
-      mlir::Value finalVal = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
-      finalClauseOperand = firOpBuilder.createConvert(
-          currentLocation, firOpBuilder.getI1Type(), finalVal);
-    } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
-      untiedAttr = firOpBuilder.getUnitAttr();
-    } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
-      mergeableAttr = firOpBuilder.getUnitAttr();
-    } else if (const auto &priorityClause =
-                   std::get_if<Fortran::parser::OmpClause::Priority>(
-                       &clause.u)) {
-      priorityClauseOperand = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
-    } else if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
-      TODO(currentLocation,
-           "Reduction in OpenMP " +
-               llvm::omp::getOpenMPDirectiveName(blockDirective.v) +
-               " construct");
-    } else if (const auto &dependClause =
-                   std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
-      const std::list<Fortran::parser::Designator> &depVal =
-          std::get<std::list<Fortran::parser::Designator>>(
-              std::get<Fortran::parser::OmpDependClause::InOut>(
-                  dependClause->v.u)
-                  .t);
-      omp::ClauseTaskDependAttr dependTypeOperand =
-          genDependKindAttr(firOpBuilder, dependClause);
-      dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
-                                dependTypeOperand);
-      for (const Fortran::parser::Designator &ompObject : depVal) {
-        Fortran::semantics::Symbol *sym = nullptr;
-        std::visit(
-            Fortran::common::visitors{
-                [&](const Fortran::parser::DataRef &designator) {
-                  if (const Fortran::parser::Name *name =
-                          std::get_if<Fortran::parser::Name>(&designator.u)) {
-                    sym = name->symbol;
-                  } else if (std::get_if<Fortran::common::Indirection<
-                                 Fortran::parser::ArrayElement>>(
-                                 &designator.u)) {
-                    TODO(converter.getCurrentLocation(),
-                         "array sections not supported for task depend");
-                  }
-                },
-                [&](const Fortran::parser::Substring &designator) {
-                  TODO(converter.getCurrentLocation(),
-                       "substring not supported for task depend");
-                }},
-            (ompObject).u);
-        const mlir::Value variable = converter.getSymbolAddress(*sym);
-        dependOperands.push_back(((variable)));
-      }
-    } else {
-      TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
-    }
-  }
-
-  for (const auto &clause :
-       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
-    if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
-      nowaitAttr = firOpBuilder.getUnitAttr();
-  }
-
-  if (blockDirective.v == llvm::omp::OMPD_parallel) {
-    // Create and insert the operation.
-    auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
-        currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-        allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
-        /*reductions=*/nullptr, procBindKindAttr);
-    createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
-                                    eval, &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_master) {
-    auto masterOp =
-        firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
-    createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_single) {
-    auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
-        currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
-    createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
-                                  &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
-    auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
-        currentLocation, /*simd=*/false);
-    createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
-                                         eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_task) {
-    auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
-        currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
-        mergeableAttr, /*in_reduction_vars=*/ValueRange(),
-        /*in_reductions=*/nullptr, priorityClauseOperand,
-        dependTypeOperands.empty()
-            ? nullptr
-            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
-                                   dependTypeOperands),
-        dependOperands, allocateOperands, allocatorOperands);
-    createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
-    // TODO: Add task_reduction support
-    auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
-        currentLocation, /*task_reduction_vars=*/ValueRange(),
-        /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
-    createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
-                   &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_target) {
-    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
-                   &eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
-    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
-                   &eval);
-  } else {
-    TODO(currentLocation, "Unhandled block directive");
-  }
-}
-
 /// This function returns the identity value of the operator \p reductionOpName.
 /// For example:
 ///    0 + x = x,
@@ -1691,6 +1488,97 @@ static std::string getReductionName(
   return getReductionName(reductionName, ty);
 }
 
+/// Creates a reduction declaration and associates it with an
+/// OpenMP block directive
+static void
+addReductionDecl(mlir::Location currentLocation,
+                 Fortran::lower::AbstractConverter &converter,
+                 const Fortran::parser::OmpReductionClause &reduction,
+                 SmallVector<Value> &reductionVars,
+                 SmallVector<Attribute> &reductionDeclSymbols) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  omp::ReductionDeclareOp decl;
+  const auto &redOperator{
+      std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+  const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+  if (const auto &redDefinedOp =
+          std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+    const auto &intrinsicOp{
+        std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+            redDefinedOp->u)};
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      break;
+
+    default:
+      TODO(currentLocation,
+           "Reduction of some intrinsic operators is not supported");
+      break;
+    }
+    for (const auto &ompObject : objectList.v) {
+      if (const auto *name{
+              Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+        if (const auto *symbol{name->symbol}) {
+          mlir::Value symVal = converter.getSymbolAddress(*symbol);
+          mlir::Type redType =
+              symVal.getType().cast<fir::ReferenceType>().getEleTy();
+          reductionVars.push_back(symVal);
+          if (redType.isa<fir::LogicalType>())
+            decl = createReductionDecl(
+                firOpBuilder,
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+                intrinsicOp, redType, currentLocation);
+          else if (redType.isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(firOpBuilder,
+                                       getReductionName(intrinsicOp, redType),
+                                       intrinsicOp, redType, currentLocation);
+          } else {
+            TODO(currentLocation, "Reduction of some types is not supported");
+          }
+          reductionDeclSymbols.push_back(
+              SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+        }
+      }
+    }
+  } else if (auto reductionIntrinsic =
+                 std::get_if<Fortran::parser::ProcedureDesignator>(
+                     &redOperator.u)) {
+    if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+            reductionIntrinsic)}) {
+      if ((name->source != "max") && (name->source != "min") &&
+          (name->source != "ior") && (name->source != "ieor") &&
+          (name->source != "iand")) {
+        TODO(currentLocation,
+             "Reduction of intrinsic procedures is not supported");
+      }
+      std::string intrinsicOp = name->ToString();
+      for (const auto &ompObject : objectList.v) {
+        if (const auto *name{
+                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+          if (const auto *symbol{name->symbol}) {
+            mlir::Value symVal = converter.getSymbolAddress(*symbol);
+            mlir::Type redType =
+                symVal.getType().cast<fir::ReferenceType>().getEleTy();
+            reductionVars.push_back(symVal);
+            assert(redType.isIntOrIndexOrFloat() &&
+                   "Unsupported reduction type");
+            decl = createReductionDecl(
+                firOpBuilder, getReductionName(intrinsicOp, redType),
+                *reductionIntrinsic, redType, currentLocation);
+            reductionDeclSymbols.push_back(SymbolRefAttr::get(
+                firOpBuilder.getContext(), decl.getSymName()));
+          }
+        }
+      }
+    }
+  }
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -1786,88 +1674,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     } else if (const auto &reductionClause =
                    std::get_if<Fortran::parser::OmpClause::Reduction>(
                        &clause.u)) {
-      omp::ReductionDeclareOp decl;
-      const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
-          reductionClause->v.t)};
-      const auto &objectList{
-          std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
-      if (const auto &redDefinedOp =
-              std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
-        const auto &intrinsicOp{
-            std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
-                redDefinedOp->u)};
-        switch (intrinsicOp) {
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-          break;
-
-        default:
-          TODO(currentLocation,
-               "Reduction of some intrinsic operators is not supported");
-          break;
-        }
-        for (const auto &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const auto *symbol{name->symbol}) {
-              mlir::Value symVal = converter.getSymbolAddress(*symbol);
-              mlir::Type redType =
-                  symVal.getType().cast<fir::ReferenceType>().getEleTy();
-              reductionVars.push_back(symVal);
-              if (redType.isa<fir::LogicalType>())
-                decl = createReductionDecl(
-                    firOpBuilder,
-                    getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                    intrinsicOp, redType, currentLocation);
-              else if (redType.isIntOrIndexOrFloat()) {
-                decl = createReductionDecl(
-                    firOpBuilder, getReductionName(intrinsicOp, redType),
-                    intrinsicOp, redType, currentLocation);
-              } else {
-                TODO(currentLocation,
-                     "Reduction of some types is not supported");
-              }
-              reductionDeclSymbols.push_back(SymbolRefAttr::get(
-                  firOpBuilder.getContext(), decl.getSymName()));
-            }
-          }
-        }
-      } else if (auto reductionIntrinsic =
-                     std::get_if<Fortran::parser::ProcedureDesignator>(
-                         &redOperator.u)) {
-        if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
-                reductionIntrinsic)}) {
-          if ((name->source != "max") && (name->source != "min") &&
-              (name->source != "ior") && (name->source != "ieor") &&
-              (name->source != "iand")) {
-            TODO(currentLocation,
-                 "Reduction of intrinsic procedures is not supported");
-          }
-          std::string intrinsicOp = name->ToString();
-          for (const auto &ompObject : objectList.v) {
-            if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
-                    ompObject)}) {
-              if (const auto *symbol{name->symbol}) {
-                mlir::Value symVal = converter.getSymbolAddress(*symbol);
-                mlir::Type redType =
-                    symVal.getType().cast<fir::ReferenceType>().getEleTy();
-                reductionVars.push_back(symVal);
-                assert(redType.isIntOrIndexOrFloat() &&
-                       "Unsupported reduction type");
-                decl = createReductionDecl(
-                    firOpBuilder, getReductionName(intrinsicOp, redType),
-                    *reductionIntrinsic, redType, currentLocation);
-                reductionDeclSymbols.push_back(SymbolRefAttr::get(
-                    firOpBuilder.getContext(), decl.getSymName()));
-              }
-            }
-          }
-        }
-      }
+      addReductionDecl(currentLocation, converter, reductionClause->v,
+                       reductionVars, reductionDeclSymbols);
     } else if (const auto &simdlenClause =
                    std::get_if<Fortran::parser::OmpClause::Simdlen>(
                        &clause.u)) {
@@ -2003,6 +1811,214 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
                                 &loopOpClauseList, iv, /*outer=*/false, &dsp);
 }
 
+static void
+genOMP(Fortran::lower::AbstractConverter &converter,
+       Fortran::lower::pft::Evaluation &eval,
+       const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
+  const auto &beginBlockDirective =
+      std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
+  const auto &blockDirective =
+      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+  const auto &endBlockDirective =
+      std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.genLocation(blockDirective.source);
+
+  Fortran::lower::StatementContext stmtCtx;
+  llvm::ArrayRef<mlir::Type> argTy;
+  mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
+      priorityClauseOperand;
+  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
+  SmallVector<Value> allocateOperands, allocatorOperands, dependOperands,
+      reductionVars;
+  SmallVector<Attribute> dependTypeOperands, reductionDeclSymbols;
+  mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
+
+  const auto &opClauseList =
+      std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
+  for (const auto &clause : opClauseList.v) {
+    mlir::Location clauseLocation = converter.genLocation(clause.source);
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand =
+          getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
+    } else if (const auto &numThreadsClause =
+                   std::get_if<Fortran::parser::OmpClause::NumThreads>(
+                       &clause.u)) {
+      // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
+      numThreadsClauseOperand = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+    } else if (const auto &procBindClause =
+                   std::get_if<Fortran::parser::OmpClause::ProcBind>(
+                       &clause.u)) {
+      procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
+    } else if (const auto &allocateClause =
+                   std::get_if<Fortran::parser::OmpClause::Allocate>(
+                       &clause.u)) {
+      genAllocateClause(converter, allocateClause->v, allocatorOperands,
+                        allocateOperands);
+    } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
+               std::get_if<Fortran::parser::OmpClause::Firstprivate>(
+                   &clause.u) ||
+               std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
+      // Privatisation and copyin clauses are handled elsewhere.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
+      // Shared is the default behavior in the IR, so no handling is required.
+      continue;
+    } else if (const auto &defaultClause =
+                   std::get_if<Fortran::parser::OmpClause::Default>(
+                       &clause.u)) {
+      if ((defaultClause->v.v ==
+           Fortran::parser::OmpDefaultClause::Type::Shared) ||
+          (defaultClause->v.v ==
+           Fortran::parser::OmpDefaultClause::Type::None)) {
+        // Default clause with shared or none do not require any handling since
+        // Shared is the default behavior in the IR and None is only required
+        // for semantic checks.
+        continue;
+      }
+    } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
+      // Nothing needs to be done for threads clause.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
+      // Map clause is exclusive to Target Data directives. It is handled
+      // as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      // UseDevicePtr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      // UseDeviceAddr clause is exclusive to Target Data directives. It is
+      // handled as part of the TargetOp creation.
+      continue;
+    } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
+                   &clause.u)) {
+      // Handled as part of TargetOp creation.
+      continue;
+    } else if (const auto &finalClause =
+                   std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
+      mlir::Value finalVal = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
+      finalClauseOperand = firOpBuilder.createConvert(
+          currentLocation, firOpBuilder.getI1Type(), finalVal);
+    } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
+      untiedAttr = firOpBuilder.getUnitAttr();
+    } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
+      mergeableAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &priorityClause =
+                   std::get_if<Fortran::parser::OmpClause::Priority>(
+                       &clause.u)) {
+      priorityClauseOperand = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
+    } else if (const auto &reductionClause =
+                   std::get_if<Fortran::parser::OmpClause::Reduction>(
+                       &clause.u)) {
+      addReductionDecl(currentLocation, converter, reductionClause->v,
+                       reductionVars, reductionDeclSymbols);
+    } else if (const auto &dependClause =
+                   std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
+      const std::list<Fortran::parser::Designator> &depVal =
+          std::get<std::list<Fortran::parser::Designator>>(
+              std::get<Fortran::parser::OmpDependClause::InOut>(
+                  dependClause->v.u)
+                  .t);
+      omp::ClauseTaskDependAttr dependTypeOperand =
+          genDependKindAttr(firOpBuilder, dependClause);
+      dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
+                                dependTypeOperand);
+      for (const Fortran::parser::Designator &ompObject : depVal) {
+        Fortran::semantics::Symbol *sym = nullptr;
+        std::visit(
+            Fortran::common::visitors{
+                [&](const Fortran::parser::DataRef &designator) {
+                  if (const Fortran::parser::Name *name =
+                          std::get_if<Fortran::parser::Name>(&designator.u)) {
+                    sym = name->symbol;
+                  } else if (std::get_if<Fortran::common::Indirection<
+                                 Fortran::parser::ArrayElement>>(
+                                 &designator.u)) {
+                    TODO(converter.getCurrentLocation(),
+                         "array sections not supported for task depend");
+                  }
+                },
+                [&](const Fortran::parser::Substring &designator) {
+                  TODO(converter.getCurrentLocation(),
+                       "substring not supported for task depend");
+                }},
+            (ompObject).u);
+        const mlir::Value variable = converter.getSymbolAddress(*sym);
+        dependOperands.push_back(((variable)));
+      }
+    } else {
+      TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
+    }
+  }
+
+  for (const auto &clause :
+       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
+    if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
+      nowaitAttr = firOpBuilder.getUnitAttr();
+  }
+
+  if (blockDirective.v == llvm::omp::OMPD_parallel) {
+    // Create and insert the operation.
+    auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
+        currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
+        allocateOperands, allocatorOperands, reductionVars,
+        reductionDeclSymbols.empty()
+            ? nullptr
+            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+                                   reductionDeclSymbols),
+        procBindKindAttr);
+    createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
+                                    eval, &opClauseList);
+  } else if (blockDirective.v == llvm::omp::OMPD_master) {
+    auto masterOp =
+        firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
+    createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
+  } else if (blockDirective.v == llvm::omp::OMPD_single) {
+    auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
+        currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
+    createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
+                                  &opClauseList);
+  } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
+    auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
+        currentLocation, /*simd=*/false);
+    createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
+                                         eval);
+  } else if (blockDirective.v == llvm::omp::OMPD_task) {
+    auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+        currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
+        mergeableAttr, /*in_reduction_vars=*/ValueRange(),
+        /*in_reductions=*/nullptr, priorityClauseOperand,
+        dependTypeOperands.empty()
+            ? nullptr
+            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+                                   dependTypeOperands),
+        dependOperands, allocateOperands, allocatorOperands);
+    createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
+  } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
+    // TODO: Add task_reduction support
+    auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
+        currentLocation, /*task_reduction_vars=*/ValueRange(),
+        /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+    createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
+                   &opClauseList);
+  } else if (blockDirective.v == llvm::omp::OMPD_target) {
+    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
+                   &eval);
+  } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
+    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
+                   &eval);
+  } else {
+    TODO(currentLocation, "Unhandled block directive");
+  }
+}
+
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
        Fortran::lower::pft::Evaluation &eval,

diff  --git a/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90 b/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90
deleted file mode 100644
index df13f1faefe806..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90
+++ /dev/null
@@ -1,11 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction in OpenMP parallel construct
-subroutine reduction_parallel
-  integer :: x
-  !$omp parallel reduction(+:x)
-  x = x + i
-  !$omp end parallel
-  print *, x
-end subroutine

diff  --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
new file mode 100644
index 00000000000000..cab076f817cadd
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
@@ -0,0 +1,97 @@
+! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : f32 init {
+!CHECK: ^bb0(%{{.*}}: f32):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  omp.yield(%[[C0_1]] : f32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
+!CHECK:  %[[RES:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}}: f32
+!CHECK:  omp.yield(%[[RES]] : f32)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK:  %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK:  omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
+!CHECK:    %[[I_INCR:.*]] = arith.constant 1 : i32
+!CHECK:    omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_int_add  ! one integer variable in a '+' reduction clause on a parallel construct
+    integer :: i
+    i = 0  ! assigned before the parallel construct
+
+    !$omp parallel reduction(+:i)
+    i = i + 1  ! update of the reduction variable inside the construct
+    !$omp end parallel
+
+    print *, i  ! i is read again after the construct ends
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  omp.parallel   reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK:    omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_real_add  ! one default-real variable in a '+' reduction clause on a parallel construct
+    real :: r
+    r = 0.0  ! assigned before the parallel construct
+
+    !$omp parallel reduction(+:r)
+    r = r + 1.5  ! update of the reduction variable inside the construct
+    !$omp end parallel
+
+    print *, r  ! r is read again after the construct ends
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK:    omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK:    omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine int_real_add  ! an integer and a real variable listed together in one '+' reduction clause
+    real :: r
+    integer :: i
+
+    r = 0.0  ! both variables are assigned before the parallel construct
+    i = 0
+
+    !$omp parallel reduction(+:i,r)
+    r = 1.5 + r  ! literal is the left operand here, unlike in simple_real_add
+    i = i + 3  ! updated after r, although i is listed first in the clause
+    !$omp end parallel
+
+    print *, r  ! both variables are read again after the construct ends
+    print *, i
+end subroutine


        


More information about the flang-commits mailing list