[flang-commits] [flang] ccd92ec - [flang][openmp] Changes for invoking scan Op (#123254)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 5 06:55:35 PST 2025


Author: Anchu Rajendran S
Date: 2025-02-05T06:55:32-08:00
New Revision: ccd92ec4c6ceb09e75ed40c96c1da7d03b9c45d5

URL: https://github.com/llvm/llvm-project/commit/ccd92ec4c6ceb09e75ed40c96c1da7d03b9c45d5
DIFF: https://github.com/llvm/llvm-project/commit/ccd92ec4c6ceb09e75ed40c96c1da7d03b9c45d5.diff

LOG: [flang][openmp] Changes for invoking scan Op (#123254)

Added: 
    flang/test/Lower/OpenMP/scan.f90

Modified: 
    flang/lib/Lower/OpenMP/ClauseProcessor.cpp
    flang/lib/Lower/OpenMP/ClauseProcessor.h
    flang/lib/Lower/OpenMP/Clauses.cpp
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/lib/Lower/OpenMP/ReductionProcessor.cpp
    flang/lib/Lower/OpenMP/ReductionProcessor.h
    flang/test/Lower/OpenMP/Todo/reduction-task.f90
    mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Removed: 
    flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
    flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..febc6adcf9d6ff 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -344,6 +344,20 @@ bool ClauseProcessor::processDistSchedule(
   return false;
 }
 
+/// Appends the address of each list item of an `exclusive` clause to
+/// \p result. Returns false, leaving \p result untouched, if there is none.
+bool ClauseProcessor::processExclusive(
+    mlir::Location currentLocation,
+    mlir::omp::ExclusiveClauseOps &result) const {
+  const auto *clause = findUniqueClause<omp::clause::Exclusive>();
+  if (!clause)
+    return false;
+  for (const Object &clauseObject : clause->v)
+    result.exclusiveVars.push_back(
+        converter.getSymbolAddress(*clauseObject.sym()));
+  return true;
+}
+
 bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
                                     mlir::omp::FilterClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +394,20 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
   return false;
 }
 
+/// Appends the address of each list item of an `inclusive` clause to
+/// \p result. Returns false, leaving \p result untouched, if there is none.
+bool ClauseProcessor::processInclusive(
+    mlir::Location currentLocation,
+    mlir::omp::InclusiveClauseOps &result) const {
+  const auto *clause = findUniqueClause<omp::clause::Inclusive>();
+  if (!clause)
+    return false;
+  for (const Object &clauseObject : clause->v)
+    result.inclusiveVars.push_back(
+        converter.getSymbolAddress(*clauseObject.sym()));
+  return true;
+}
+
 bool ClauseProcessor::processMergeable(
     mlir::omp::MergeableClauseOps &result) const {
   return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1163,9 @@ bool ClauseProcessor::processReduction(
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
         llvm::SmallVector<const semantics::Symbol *> reductionSyms;
         ReductionProcessor rp;
-        rp.addDeclareReduction(currentLocation, converter, clause,
-                               reductionVars, reduceVarByRef,
-                               reductionDeclSymbols, reductionSyms);
-
+        rp.processReductionArguments(
+            currentLocation, converter, clause, reductionVars, reduceVarByRef,
+            reductionDeclSymbols, reductionSyms, result.reductionMod);
         // Copy local lists into the output.
         llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
         llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7b047d4a7567ad..e05f66c7666844 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -64,6 +64,8 @@ class ClauseProcessor {
   bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
   bool processDistSchedule(lower::StatementContext &stmtCtx,
                            mlir::omp::DistScheduleClauseOps &result) const;
+  bool processExclusive(mlir::Location currentLocation,
+                        mlir::omp::ExclusiveClauseOps &result) const;
   bool processFilter(lower::StatementContext &stmtCtx,
                      mlir::omp::FilterClauseOps &result) const;
   bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
       mlir::omp::HasDeviceAddrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
   bool processHint(mlir::omp::HintClauseOps &result) const;
+  bool processInclusive(mlir::Location currentLocation,
+                        mlir::omp::InclusiveClauseOps &result) const;
   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
   bool processNumTeams(lower::StatementContext &stmtCtx,

diff  --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index db6486abc7ea1e..5664d8ab2a5d88 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -736,8 +736,8 @@ Enter make(const parser::OmpClause::Enter &inp,
 
 Exclusive make(const parser::OmpClause::Exclusive &inp,
                semantics::SemanticsContext &semaCtx) {
-  // inp -> empty
-  llvm_unreachable("Empty: exclusive");
+  // inp.v -> parser::OmpObjectList (the list items named in the clause)
+  return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
 }
 
 Fail make(const parser::OmpClause::Fail &inp,
@@ -846,8 +846,8 @@ If make(const parser::OmpClause::If &inp,
 
 Inclusive make(const parser::OmpClause::Inclusive &inp,
                semantics::SemanticsContext &semaCtx) {
-  // inp -> empty
-  llvm_unreachable("Empty: inclusive");
+  // inp.v -> parser::OmpObjectList (the list items named in the clause)
+  return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
 }
 
 Indirect make(const parser::OmpClause::Indirect &inp,

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5e1f3b0208869b..0ae04ed9baf450 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1584,6 +1584,16 @@ static void genParallelClauses(
   cp.processReduction(loc, clauseOps, reductionSyms);
 }
 
+/// Collects the `inclusive`/`exclusive` list items of a `scan` directive.
+static void genScanClauses(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    const List<Clause> &clauses, mlir::Location loc,
+    mlir::omp::ScanOperands &clauseOps) {
+  ClauseProcessor processor(converter, semaCtx, clauses);
+  processor.processInclusive(loc, clauseOps);
+  processor.processExclusive(loc, clauseOps);
+}
+
 static void genSectionsClauses(
     lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
     const List<Clause> &clauses, mlir::Location loc,
@@ -1981,6 +1990,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return parallelOp;
 }
 
+/// Lowers a `scan` directive to `omp.scan` at \p loc (also used for clauses).
+static mlir::omp::ScanOp
+genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+          semantics::SemanticsContext &semaCtx, mlir::Location loc,
+          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+  mlir::omp::ScanOperands clauseOps;
+  genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+  return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(loc, clauseOps);
+}
+
 /// This breaks the normal prototype of the gen*Op functions: adding the
 /// sectionBlocks argument so that the enclosed section constructs can be
 /// lowered here with correct reduction symbol remapping.
@@ -2990,7 +3009,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+    genScanOp(converter, symTable, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");

diff  --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 4a811f1bdfdf53..f83079eb68688d 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -31,6 +31,10 @@ static llvm::cl::opt<bool> forceByrefReduction(
     llvm::cl::desc("Pass all reduction arguments by reference"),
     llvm::cl::Hidden);
 
+// The lowered clause's modifier enum (not mlir::omp::ReductionModifier).
+using ReductionModifier =
+    Fortran::lower::omp::clause::Reduction::ReductionModifier;
+
 namespace Fortran {
 namespace lower {
 namespace omp {
@@ -518,18 +521,40 @@ static bool doReductionByRef(mlir::Value reductionVar) {
   return false;
 }
 
-void ReductionProcessor::addDeclareReduction(
+/// Maps a lowered `reduction` clause modifier onto the OpenMP dialect enum.
+static mlir::omp::ReductionModifier
+translateReductionModifier(ReductionModifier mod) {
+  switch (mod) {
+  case ReductionModifier::Default:
+    return mlir::omp::ReductionModifier::defaultmod;
+  case ReductionModifier::Inscan:
+    return mlir::omp::ReductionModifier::inscan;
+  case ReductionModifier::Task:
+    return mlir::omp::ReductionModifier::task;
+  }
+  // Reached only if `mod` holds a value not listed above; the return keeps
+  // compilers that do not see the switch as exhaustive from warning.
+  return mlir::omp::ReductionModifier::defaultmod;
+}
+
+void ReductionProcessor::processReductionArguments(
     mlir::Location currentLocation, lower::AbstractConverter &converter,
     const omp::clause::Reduction &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+    mlir::omp::ReductionModifierAttr &reductionMod) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
-          reduction.t))
-    TODO(currentLocation, "Reduction modifiers are not supported");
+  auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
+  if (mod.has_value()) {
+    if (mod.value() == ReductionModifier::Task)
+      TODO(currentLocation, "Reduction modifier `task` is not supported");
+    else
+      reductionMod = mlir::omp::ReductionModifierAttr::get(
+          firOpBuilder.getContext(), translateReductionModifier(mod.value()));
+  }
 
   mlir::omp::DeclareReductionOp decl;
   const auto &redOperatorList{

diff  --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index d7d9b067e0bac6..11baa839c74b49 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,6 +19,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/type.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Types.h"
 
@@ -120,13 +121,17 @@ class ReductionProcessor {
 
-  /// Creates a reduction declaration and associates it with an OpenMP block
-  /// directive.
-  static void addDeclareReduction(
+  /// Processes the arguments of a `reduction` clause: creates the reduction
+  /// declarations and fills \p reductionVars, \p reduceVarByRef,
+  /// \p reductionDeclSymbols and \p reductionSymbols. If the clause has a
+  /// modifier it is stored in \p reductionMod (otherwise \p reductionMod is
+  /// left unchanged); the `task` modifier is not supported yet.
+  static void processReductionArguments(
       mlir::Location currentLocation, lower::AbstractConverter &converter,
       const omp::clause::Reduction &reduction,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
       llvm::SmallVectorImpl<bool> &reduceVarByRef,
       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+      mlir::omp::ReductionModifierAttr &reductionMod);
 };
 
 template <typename FloatOp, typename IntegerOp>

diff  --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
deleted file mode 100644
index 152d91a16f80fe..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-subroutine reduction_inscan()
-  integer :: i,j
-  i = 0
-
-  !$omp do reduction(inscan, +:i)
-  do j=1,10
-     !$omp scan inclusive(i)
-     i = i + 1
-  end do
-  !$omp end do
-end subroutine reduction_inscan

diff  --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
deleted file mode 100644
index 82625ed8c5f31c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-
-subroutine foo()
-  integer :: i, j
-  j = 0
-  !$omp do reduction (inscan, *: j)
-  do i = 1, 10
-    !$omp scan inclusive(j)
-    j = j + 1
-  end do
-end subroutine

diff  --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
index 6707f65e1a4cc3..b8bfc37d1758f9 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90
+++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
@@ -1,7 +1,7 @@
 ! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
 ! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
 
-! CHECK: not yet implemented: Reduction modifiers are not supported
+! CHECK: not yet implemented: Reduction modifier `task` is not supported
 subroutine reduction_task()
   integer :: i
   i = 0

diff  --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90
new file mode 100644
index 00000000000000..97b672ec41f204
--- /dev/null
+++ b/flang/test/Lower/OpenMP/scan.f90
@@ -0,0 +1,36 @@
+! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_1:.*]] : {{.*}}) {
+! CHECK: %[[RED_DECL_1:.*]]:2 = hlfir.declare %[[RED_ARG_1]]
+! CHECK: omp.scan inclusive(%[[RED_DECL_1]]#1 : {{.*}})
+
+subroutine inclusive_scan(a, b, n) ! scan inclusive(x) on an inscan reduction
+ implicit none
+ integer a(:), b(:)
+ integer x, k, n
+
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+   x = x + a(k)
+   !$omp scan inclusive(x)
+   b(k) = x
+ end do
+end subroutine inclusive_scan
+
+
+! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_2:.*]] : {{.*}}) {
+! CHECK: %[[RED_DECL_2:.*]]:2 = hlfir.declare %[[RED_ARG_2]]
+! CHECK: omp.scan exclusive(%[[RED_DECL_2]]#1 : {{.*}})
+subroutine exclusive_scan(a, b, n) ! scan exclusive(x) on an inscan reduction
+ implicit none
+ integer a(:), b(:)
+ integer x, k, n
+
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+   x = x + a(k)
+   !$omp scan exclusive(x)
+   b(k) = x
+ end do
+end subroutine exclusive_scan

diff  --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 0caf3ad1ccf017..12e3c07669839d 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
   target.addDynamicallyLegalOp<
       omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
       omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
-      omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
+      omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
       omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
       omp::YieldOp>([&](Operation *op) {
     return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -274,6 +274,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionLessOpConversion<omp::CancelOp>,
       RegionLessOpConversion<omp::CriticalDeclareOp>,
       RegionLessOpConversion<omp::OrderedOp>,
+      RegionLessOpConversion<omp::ScanOp>,
       RegionLessOpConversion<omp::TargetEnterDataOp>,
       RegionLessOpConversion<omp::TargetExitDataOp>,
       RegionLessOpConversion<omp::TargetUpdateOp>,


        


More information about the flang-commits mailing list