[flang-commits] [flang] [flang][OpenMP] Add `reduction` clause support to `loop` directive (PR #128849)
Kareem Ergawy via flang-commits
flang-commits at lists.llvm.org
Thu Feb 27 05:12:44 PST 2025
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/128849
>From 5539c20db413553ad2fcab5625c272aa3d624f6f Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 26 Feb 2025 02:11:02 -0600
Subject: [PATCH] [flang][OpenMP] Add `reduction` clause support to `loop` directive
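Extends the `loop` directive transformation by adding support for the
`reduction` clause. The clause's operands are forwarded to the generated
`simd` or `wsloop` op; a `loop` that is rewritten to a standalone
`distribute` is asserted to carry no reduction variables.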
---
.../OpenMP/GenericLoopConversion.cpp | 65 +++++++++++++++----
flang/test/Lower/OpenMP/loop-directive.f90 | 21 +++++-
.../generic-loop-rewriting-todo.mlir | 16 +----
3 files changed, 74 insertions(+), 28 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
index bf94166edc079..b0014a3aced6b 100644
--- a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -15,6 +15,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
+#include <optional>
+#include <type_traits>
namespace flangomp {
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ class GenericLoopConversionPattern
if (teamsLoopCanBeParallelFor(loopOp))
rewriteToDistributeParallelDo(loopOp, rewriter);
else
- rewriteToDistrbute(loopOp, rewriter);
+ rewriteToDistribute(loopOp, rewriter);
break;
}
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
if (loopOp.getOrder())
return todo("order");
- if (!loopOp.getReductionVars().empty())
- return todo("reduction");
-
return mlir::success();
}
@@ -168,7 +167,7 @@ class GenericLoopConversionPattern
case ClauseBindKind::Parallel:
return rewriteToWsloop(loopOp, rewriter);
case ClauseBindKind::Teams:
- return rewriteToDistrbute(loopOp, rewriter);
+ return rewriteToDistribute(loopOp, rewriter);
case ClauseBindKind::Thread:
return rewriteToSimdLoop(loopOp, rewriter);
}
@@ -211,8 +210,9 @@ class GenericLoopConversionPattern
loopOp, rewriter);
}
- void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
- mlir::ConversionPatternRewriter &rewriter) const {
+ void rewriteToDistribute(mlir::omp::LoopOp loopOp,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ assert(loopOp.getReductionVars().empty());
rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
mlir::omp::DistributeOperands>(loopOp, rewriter);
}
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
Fortran::common::openmp::EntryBlockArgs args;
args.priv.vars = clauseOps.privateVars;
+ if constexpr (!std::is_same_v<OpOperandsTy,
+ mlir::omp::DistributeOperands>) {
+ populateReductionClauseOps(loopOp, clauseOps);
+ args.reduction.vars = clauseOps.reductionVars;
+ }
+
auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
@@ -275,8 +281,7 @@ class GenericLoopConversionPattern
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
parallelClauseOps);
- mlir::Block *parallelBlock =
- genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
+ genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
parallelOp.setComposite(true);
rewriter.setInsertionPoint(
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
@@ -288,20 +293,54 @@ class GenericLoopConversionPattern
rewriter.createBlock(&distributeOp.getRegion());
mlir::omp::WsloopOperands wsloopClauseOps;
+ populateReductionClauseOps(loopOp, wsloopClauseOps);
+ Fortran::common::openmp::EntryBlockArgs wsloopArgs;
+ wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
+
auto wsloopOp =
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
wsloopOp.setComposite(true);
- rewriter.createBlock(&wsloopOp.getRegion());
+ genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());
mlir::IRMapping mapper;
- mlir::Block &loopBlock = *loopOp.getRegion().begin();
- for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
- loopBlock.getArguments(), parallelBlock->getArguments()))
+ auto loopBlockInterface =
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
+ auto parallelBlockInterface =
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
+ auto wsloopBlockInterface =
+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
+
+ for (auto [loopOpArg, parallelOpArg] :
+ llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
+ parallelBlockInterface.getPrivateBlockArgs()))
mapper.map(loopOpArg, parallelOpArg);
+ for (auto [loopOpArg, wsloopOpArg] :
+ llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
+ wsloopBlockInterface.getReductionBlockArgs()))
+ mapper.map(loopOpArg, wsloopOpArg);
+
rewriter.clone(*loopOp.begin(), mapper);
}
+
+ void
+ populateReductionClauseOps(mlir::omp::LoopOp loopOp,
+ mlir::omp::ReductionClauseOps &clauseOps) const {
+ clauseOps.reductionMod = loopOp.getReductionModAttr();
+ clauseOps.reductionVars = loopOp.getReductionVars();
+
+ std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
+ if (reductionSyms)
+ clauseOps.reductionSyms.assign(reductionSyms->begin(),
+ reductionSyms->end());
+
+ std::optional<llvm::ArrayRef<bool>> reductionByref =
+ loopOp.getReductionByref();
+ if (reductionByref)
+ clauseOps.reductionByref.assign(reductionByref->begin(),
+ reductionByref->end());
+ }
};
class GenericLoopConversionPass
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
index ffa4a6ff24f24..3fccf3502fce9 100644
--- a/flang/test/Lower/OpenMP/loop-directive.f90
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -75,7 +75,7 @@ subroutine test_order()
subroutine test_reduction()
integer :: i, dummy = 1
- ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
+ ! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
! CHECK-SAME: (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
! CHECK: %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
@@ -294,3 +294,22 @@ subroutine teams_loop_cannot_be_parallel_for_4
!$omp end parallel
END DO
end subroutine
+
+! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
+subroutine loop_parallel_bind_reduction
+ implicit none
+ integer :: x, i
+
+ ! CHECK: omp.wsloop
+ ! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
+ ! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
+ ! CHECK-NEXT: omp.loop_nest {{.*}} {
+ ! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
+ ! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
+ ! CHECK: }
+ ! CHECK: }
+ !$omp loop bind(parallel) reduction(+: x)
+ do i = 0, 10
+ x = x + i
+ end do
+end subroutine
diff --git a/flang/test/Transforms/generic-loop-rewriting-todo.mlir b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
index e992296c9a837..64094d61eb9a3 100644
--- a/flang/test/Transforms/generic-loop-rewriting-todo.mlir
+++ b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
@@ -1,24 +1,12 @@
// RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s
-
-omp.declare_reduction @add_reduction_i32 : i32 init {
- ^bb0(%arg0: i32):
- %c0_i32 = arith.constant 0 : i32
- omp.yield(%c0_i32 : i32)
- } combiner {
- ^bb0(%arg0: i32, %arg1: i32):
- %0 = arith.addi %arg0, %arg1 : i32
- omp.yield(%0 : i32)
- }
-
func.func @_QPloop_order() {
omp.teams {
%c0 = arith.constant 0 : i32
%c10 = arith.constant 10 : i32
%c1 = arith.constant 1 : i32
- %sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}
- // expected-error at below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
- omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
+ // expected-error at below {{not yet implemented: Unhandled clause order in omp.loop operation}}
+ omp.loop order(reproducible:concurrent) {
omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
omp.yield
}
More information about the flang-commits
mailing list