[flang-commits] [flang] [mlir] Changes for invoking scan Op (PR #123254)

Anchu Rajendran S via flang-commits flang-commits at lists.llvm.org
Thu Jan 16 15:19:23 PST 2025


https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/123254

>From 0da4166bad559a4a214d5810da3696eddd115426 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 16 Jan 2025 17:13:56 -0600
Subject: [PATCH] Changes for invoking scan Op

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    | 39 +++++++++++++++++--
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |  4 ++
 flang/lib/Lower/OpenMP/Clauses.cpp            |  8 ++--
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 23 ++++++++++-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 36 +++++++++++++++--
 flang/lib/Lower/OpenMP/ReductionProcessor.h   |  4 +-
 .../Lower/OpenMP/Todo/reduction-inscan.f90    | 15 -------
 .../Lower/OpenMP/Todo/reduction-modifiers.f90 | 14 -------
 .../test/Lower/OpenMP/Todo/reduction-task.f90 |  2 +-
 flang/test/Lower/OpenMP/scan.f90              | 34 ++++++++++++++++
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  |  3 +-
 11 files changed, 137 insertions(+), 45 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
 delete mode 100644 flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
 create mode 100644 flang/test/Lower/OpenMP/scan.f90

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index fb8e007c7af574..619f4a57205a05 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -344,6 +344,22 @@ bool ClauseProcessor::processDistSchedule(
   return false;
 }
 
+bool ClauseProcessor::processExclusive(
+    mlir::Location currentLocation,
+    mlir::omp::ExclusiveClauseOps &result) const {
+  return findRepeatableClause<omp::clause::Exclusive>(
+      [&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
+        for (const Object &object : clause.v) {
+          semantics::Symbol *sym = object.sym();
+          mlir::Value symVal = converter.getSymbolAddress(*sym);
+          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
+            symVal = declOp.getBase();
+          }
+          result.exclusiveVars.push_back(symVal);
+        }
+      });
+}
+
 bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
                                     mlir::omp::FilterClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +396,22 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
   return false;
 }
 
+bool ClauseProcessor::processInclusive(
+    mlir::Location currentLocation,
+    mlir::omp::InclusiveClauseOps &result) const {
+  return findRepeatableClause<omp::clause::Inclusive>(
+      [&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
+        for (const Object &object : clause.v) {
+          const semantics::Symbol *symbol = object.sym();
+          mlir::Value symVal = converter.getSymbolAddress(*symbol);
+          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
+            symVal = declOp.getBase();
+          }
+          result.inclusiveVars.push_back(symVal);
+        }
+      });
+}
+
 bool ClauseProcessor::processMergeable(
     mlir::omp::MergeableClauseOps &result) const {
   return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1167,9 @@ bool ClauseProcessor::processReduction(
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
         llvm::SmallVector<const semantics::Symbol *> reductionSyms;
         ReductionProcessor rp;
-        rp.addDeclareReduction(currentLocation, converter, clause,
-                               reductionVars, reduceVarByRef,
-                               reductionDeclSymbols, reductionSyms);
-
+        rp.addDeclareReduction(
+            currentLocation, converter, clause, reductionVars, reduceVarByRef,
+            reductionDeclSymbols, reductionSyms, &result.reductionMod);
         // Copy local lists into the output.
         llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
         llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7b047d4a7567ad..e05f66c7666844 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -64,6 +64,8 @@ class ClauseProcessor {
   bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
   bool processDistSchedule(lower::StatementContext &stmtCtx,
                            mlir::omp::DistScheduleClauseOps &result) const;
+  bool processExclusive(mlir::Location currentLocation,
+                        mlir::omp::ExclusiveClauseOps &result) const;
   bool processFilter(lower::StatementContext &stmtCtx,
                      mlir::omp::FilterClauseOps &result) const;
   bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
       mlir::omp::HasDeviceAddrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
   bool processHint(mlir::omp::HintClauseOps &result) const;
+  bool processInclusive(mlir::Location currentLocation,
+                        mlir::omp::InclusiveClauseOps &result) const;
   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
   bool processNumTeams(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b424e209d56da9..a26bdcdf343e13 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,
 
 Exclusive make(const parser::OmpClause::Exclusive &inp,
                semantics::SemanticsContext &semaCtx) {
-  // inp -> empty
-  llvm_unreachable("Empty: exclusive");
+  // inp.v -> parser::OmpObjectList
+  return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
 }
 
 Fail make(const parser::OmpClause::Fail &inp,
@@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,
 
 Inclusive make(const parser::OmpClause::Inclusive &inp,
                semantics::SemanticsContext &semaCtx) {
-  // inp -> empty
-  llvm_unreachable("Empty: inclusive");
+  // inp.v -> parser::OmpObjectList
+  return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
 }
 
 Indirect make(const parser::OmpClause::Indirect &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 158f76250572ea..1e9926024b0504 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1577,6 +1577,15 @@ static void genParallelClauses(
   cp.processReduction(loc, clauseOps, reductionSyms);
 }
 
+static void genScanClauses(lower::AbstractConverter &converter,
+                           semantics::SemanticsContext &semaCtx,
+                           const List<Clause> &clauses, mlir::Location loc,
+                           mlir::omp::ScanOperands &clauseOps) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processInclusive(loc, clauseOps);
+  cp.processExclusive(loc, clauseOps);
+}
+
 static void genSectionsClauses(
     lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
     const List<Clause> &clauses, mlir::Location loc,
@@ -1974,6 +1983,18 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return parallelOp;
 }
 
+static mlir::omp::ScanOp
+genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+          semantics::SemanticsContext &semaCtx, mlir::Location loc,
+          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+
+  mlir::omp::ScanOperands clauseOps;
+  genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+
+  return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
+      converter.getCurrentLocation(), clauseOps);
+}
+
 /// This breaks the normal prototype of the gen*Op functions: adding the
 /// sectionBlocks argument so that the enclosed section constructs can be
 /// lowered here with correct reduction symbol remapping.
@@ -2976,7 +2997,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+    genScanOp(converter, symTable, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 2cd21107a916e4..b11c8af168da67 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -25,6 +25,7 @@
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/Support/CommandLine.h"
+#include <string>
 
 static llvm::cl::opt<bool> forceByrefReduction(
     "force-byref-reduction",
@@ -514,18 +515,38 @@ static bool doReductionByRef(mlir::Value reductionVar) {
   return false;
 }
 
+mlir::omp::ReductionModifier
+translateReductionModifier(const omp::clause::Reduction::ReductionModifier &m) {
+  switch (m) {
+  case omp::clause::Reduction::ReductionModifier::Default:
+    return mlir::omp::ReductionModifier::defaultmod;
+  case omp::clause::Reduction::ReductionModifier::Inscan:
+    return mlir::omp::ReductionModifier::inscan;
+  case omp::clause::Reduction::ReductionModifier::Task:
+    return mlir::omp::ReductionModifier::task;
+  }
+  return mlir::omp::ReductionModifier::defaultmod;
+}
+
 void ReductionProcessor::addDeclareReduction(
     mlir::Location currentLocation, lower::AbstractConverter &converter,
     const omp::clause::Reduction &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+    mlir::omp::ReductionModifierAttr *reductionMod) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
-          reduction.t))
-    TODO(currentLocation, "Reduction modifiers are not supported");
+  auto mod = std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
+      reduction.t);
+  if (mod.has_value() &&
+      (mod.value() != omp::clause::Reduction::ReductionModifier::Inscan)) {
+    std::string modStr = "default";
+    if (mod.value() == omp::clause::Reduction::ReductionModifier::Task)
+      modStr = "task";
+    TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
+  }
 
   mlir::omp::DeclareReductionOp decl;
   const auto &redOperatorList{
@@ -649,6 +670,13 @@ void ReductionProcessor::addDeclareReduction(
                                   currentLocation, isByRef);
     reductionDeclSymbols.push_back(
         mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+    auto redMod =
+        std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
+            reduction.t);
+    if (redMod.has_value())
+      *reductionMod = mlir::omp::ReductionModifierAttr::get(
+          firOpBuilder.getContext(),
+          translateReductionModifier(redMod.value()));
   }
 }
 
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 5f4d742b62cb10..49e5584088bc61 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,6 +19,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/type.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Types.h"
 
@@ -126,7 +127,8 @@ class ReductionProcessor {
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
       llvm::SmallVectorImpl<bool> &reduceVarByRef,
       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+      mlir::omp::ReductionModifierAttr *reductionMod);
 };
 
 template <typename FloatOp, typename IntegerOp>
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
deleted file mode 100644
index 152d91a16f80fe..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-subroutine reduction_inscan()
-  integer :: i,j
-  i = 0
-
-  !$omp do reduction(inscan, +:i)
-  do j=1,10
-     !$omp scan inclusive(i)
-     i = i + 1
-  end do
-  !$omp end do
-end subroutine reduction_inscan
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
deleted file mode 100644
index 82625ed8c5f31c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-
-subroutine foo()
-  integer :: i, j
-  j = 0
-  !$omp do reduction (inscan, *: j)
-  do i = 1, 10
-    !$omp scan inclusive(j)
-    j = j + 1
-  end do
-end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
index 6707f65e1a4cc3..b746872e9e7edf 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90
+++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
@@ -1,7 +1,7 @@
 ! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
 ! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
 
-! CHECK: not yet implemented: Reduction modifiers are not supported
+! CHECK: not yet implemented: Reduction modifier task is not supported
 subroutine reduction_task()
   integer :: i
   i = 0
diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90
new file mode 100644
index 00000000000000..9cf2174a7f3314
--- /dev/null
+++ b/flang/test/Lower/OpenMP/scan.f90
@@ -0,0 +1,34 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine inclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+   x = x + a(k)
+   !CHECK: omp.scan inclusive({{.*}})
+   !$omp scan inclusive(x)
+   b(k) = x
+ end do
+end subroutine inclusive_scan
+
+
+subroutine exclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+   x = x + a(k)
+   !CHECK: omp.scan exclusive({{.*}})
+   !$omp scan exclusive(x)
+   b(k) = x
+ end do
+end subroutine exclusive_scan
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 5d0003911bca87..056b7f989128a4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
   target.addDynamicallyLegalOp<
       omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
       omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
-      omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
+      omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp,
       omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
       omp::YieldOp>([&](Operation *op) {
     return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionLessOpConversion<omp::CancelOp>,
       RegionLessOpConversion<omp::CriticalDeclareOp>,
       RegionLessOpConversion<omp::OrderedOp>,
+      RegionLessOpConversion<omp::ScanOp>,
       RegionLessOpConversion<omp::TargetEnterDataOp>,
       RegionLessOpConversion<omp::TargetExitDataOp>,
       RegionLessOpConversion<omp::TargetUpdateOp>,



More information about the flang-commits mailing list