[flang-commits] [flang] [flang][OpenMP] Add `reduction` clause support to `loop` directive (PR #128849)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Thu Feb 27 05:12:44 PST 2025


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/128849

>From 5539c20db413553ad2fcab5625c272aa3d624f6f Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 26 Feb 2025 02:11:02 -0600
Subject: [PATCH] [flang][OpenMP] Add `reduction` clause support to `loop`
 directive

Extends `loop` directive transformation by adding support for the
`reduction` clause.
---
 .../OpenMP/GenericLoopConversion.cpp          | 65 +++++++++++++++----
 flang/test/Lower/OpenMP/loop-directive.f90    | 21 +++++-
 .../generic-loop-rewriting-todo.mlir          | 16 +----
 3 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
index bf94166edc079..b0014a3aced6b 100644
--- a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -15,6 +15,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 #include <memory>
+#include <optional>
+#include <type_traits>
 
 namespace flangomp {
 #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ class GenericLoopConversionPattern
       if (teamsLoopCanBeParallelFor(loopOp))
         rewriteToDistributeParallelDo(loopOp, rewriter);
       else
-        rewriteToDistrbute(loopOp, rewriter);
+        rewriteToDistribute(loopOp, rewriter);
       break;
     }
 
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
     if (loopOp.getOrder())
       return todo("order");
 
-    if (!loopOp.getReductionVars().empty())
-      return todo("reduction");
-
     return mlir::success();
   }
 
@@ -168,7 +167,7 @@ class GenericLoopConversionPattern
     case ClauseBindKind::Parallel:
       return rewriteToWsloop(loopOp, rewriter);
     case ClauseBindKind::Teams:
-      return rewriteToDistrbute(loopOp, rewriter);
+      return rewriteToDistribute(loopOp, rewriter);
     case ClauseBindKind::Thread:
       return rewriteToSimdLoop(loopOp, rewriter);
     }
@@ -211,8 +210,9 @@ class GenericLoopConversionPattern
         loopOp, rewriter);
   }
 
-  void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
-                          mlir::ConversionPatternRewriter &rewriter) const {
+  void rewriteToDistribute(mlir::omp::LoopOp loopOp,
+                           mlir::ConversionPatternRewriter &rewriter) const {
+    assert(loopOp.getReductionVars().empty());
     rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
                              mlir::omp::DistributeOperands>(loopOp, rewriter);
   }
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
     Fortran::common::openmp::EntryBlockArgs args;
     args.priv.vars = clauseOps.privateVars;
 
+    if constexpr (!std::is_same_v<OpOperandsTy,
+                                  mlir::omp::DistributeOperands>) {
+      populateReductionClauseOps(loopOp, clauseOps);
+      args.reduction.vars = clauseOps.reductionVars;
+    }
+
     auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
     mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
 
@@ -275,8 +281,7 @@ class GenericLoopConversionPattern
 
     auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
                                                              parallelClauseOps);
-    mlir::Block *parallelBlock =
-        genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
+    genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
     parallelOp.setComposite(true);
     rewriter.setInsertionPoint(
         rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
@@ -288,20 +293,54 @@ class GenericLoopConversionPattern
     rewriter.createBlock(&distributeOp.getRegion());
 
     mlir::omp::WsloopOperands wsloopClauseOps;
+    populateReductionClauseOps(loopOp, wsloopClauseOps);
+    Fortran::common::openmp::EntryBlockArgs wsloopArgs;
+    wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+
     auto wsloopOp =
         rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
     wsloopOp.setComposite(true);
-    rewriter.createBlock(&wsloopOp.getRegion());
+    genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());
 
     mlir::IRMapping mapper;
-    mlir::Block &loopBlock = *loopOp.getRegion().begin();
 
-    for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
-             loopBlock.getArguments(), parallelBlock->getArguments()))
+    auto loopBlockInterface =
+        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
+    auto parallelBlockInterface =
+        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
+    auto wsloopBlockInterface =
+        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
+
+    for (auto [loopOpArg, parallelOpArg] :
+         llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
+                         parallelBlockInterface.getPrivateBlockArgs()))
       mapper.map(loopOpArg, parallelOpArg);
 
+    for (auto [loopOpArg, wsloopOpArg] :
+         llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
+                         wsloopBlockInterface.getReductionBlockArgs()))
+      mapper.map(loopOpArg, wsloopOpArg);
+
     rewriter.clone(*loopOp.begin(), mapper);
   }
+
+  void
+  populateReductionClauseOps(mlir::omp::LoopOp loopOp,
+                             mlir::omp::ReductionClauseOps &clauseOps) const {
+    clauseOps.reductionMod = loopOp.getReductionModAttr();
+    clauseOps.reductionVars = loopOp.getReductionVars();
+
+    std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
+    if (reductionSyms)
+      clauseOps.reductionSyms.assign(reductionSyms->begin(),
+                                     reductionSyms->end());
+
+    std::optional<llvm::ArrayRef<bool>> reductionByref =
+        loopOp.getReductionByref();
+    if (reductionByref)
+      clauseOps.reductionByref.assign(reductionByref->begin(),
+                                      reductionByref->end());
+  }
 };
 
 class GenericLoopConversionPass
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
index ffa4a6ff24f24..3fccf3502fce9 100644
--- a/flang/test/Lower/OpenMP/loop-directive.f90
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -75,7 +75,7 @@ subroutine test_order()
 subroutine test_reduction()
   integer :: i, dummy = 1
 
-  ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
+  ! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
   ! CHECK-SAME:  (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
   ! CHECK-NEXT:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
   ! CHECK:          %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
@@ -294,3 +294,22 @@ subroutine teams_loop_cannot_be_parallel_for_4
     !$omp end parallel
   END DO
 end subroutine
+
+! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
+subroutine loop_parallel_bind_reduction
+  implicit none
+  integer :: x, i
+
+  ! CHECK: omp.wsloop
+  ! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
+  ! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
+  ! CHECK-NEXT: omp.loop_nest {{.*}} {
+  ! CHECK-NEXT:   hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
+  ! CHECK-NEXT:   hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
+  ! CHECK:      }
+  ! CHECK: }
+  !$omp loop bind(parallel) reduction(+: x)
+  do i = 0, 10
+    x = x + i
+  end do
+end subroutine
diff --git a/flang/test/Transforms/generic-loop-rewriting-todo.mlir b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
index e992296c9a837..64094d61eb9a3 100644
--- a/flang/test/Transforms/generic-loop-rewriting-todo.mlir
+++ b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
@@ -1,24 +1,12 @@
 // RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s
-
-omp.declare_reduction @add_reduction_i32 : i32 init {
-  ^bb0(%arg0: i32):
-    %c0_i32 = arith.constant 0 : i32
-    omp.yield(%c0_i32 : i32)
-  } combiner {
-  ^bb0(%arg0: i32, %arg1: i32):
-    %0 = arith.addi %arg0, %arg1 : i32
-    omp.yield(%0 : i32)
-  }
-
 func.func @_QPloop_order() {
   omp.teams {
     %c0 = arith.constant 0 : i32
     %c10 = arith.constant 10 : i32
     %c1 = arith.constant 1 : i32
-    %sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}
 
-    // expected-error@below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
-    omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
+    // expected-error@below {{not yet implemented: Unhandled clause order in omp.loop operation}}
+    omp.loop order(reproducible:concurrent) {
       omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
         omp.yield
       }



More information about the flang-commits mailing list