[Mlir-commits] [mlir] [flang] [mlir][flang][openmp] Rework parallel reduction operations (PR #79308)

David Truby llvmlistbot at llvm.org
Fri Jan 26 06:42:27 PST 2024


https://github.com/DavidTruby updated https://github.com/llvm/llvm-project/pull/79308

>From c7ea791524de2998f4581620a57ae5c7ec95f079 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Wed, 24 Jan 2024 15:31:15 +0000
Subject: [PATCH 1/2] [mlir][flang][openmp] Rework parallel reduction
 operations

This patch reworks the way that parallel reduction operations function to better
match the expected semantics from the OpenMP specification. Previously specific
omp.reduction operations were used inside the region, meaning that the reduction
only applied when the correct operation was used, whereas the specification
states that any change to the variable inside the region should be taken into
account for the reduction.

The new semantics create a private reduction variable as a block argument which
should be used normally for all operations on that variable in the region; this
private variable is then combined with the others into the shared variable. This
way no special omp.reduction operations are needed inside the region.

This patch only makes the change for the `parallel` operation; the change for
the `wsloop` operation will be in a separate patch.
---
 flang/lib/Lower/OpenMP.cpp                    | 52 ++++++++++++++-----
 .../OpenMP/FIR/parallel-reduction-add.f90     | 26 ++++++----
 .../Lower/OpenMP/parallel-reduction-add.f90   | 26 +++++++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  9 ++--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 50 ++++++++++++++++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 16 ++++--
 mlir/test/Dialect/OpenMP/ops.mlir             | 39 ++++++++------
 mlir/test/Target/LLVMIR/openmp-reduction.mlir | 12 +++--
 8 files changed, 173 insertions(+), 57 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index d2215f4d1bf1ce3..5d9a27ccb90d4b6 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -621,7 +621,9 @@ class ClauseProcessor {
   bool processReduction(
       mlir::Location currentLocation,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
-      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
+      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+      llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
+          nullptr) const;
   bool processSectionsReduction(mlir::Location currentLocation) const;
   bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
   bool
@@ -1077,7 +1079,9 @@ class ReductionProcessor {
       Fortran::lower::AbstractConverter &converter,
       const Fortran::parser::OmpReductionClause &reduction,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
-      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+      llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols =
+          nullptr) {
     fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
     mlir::omp::ReductionDeclareOp decl;
     const auto &redOperator{
@@ -1106,7 +1110,9 @@ class ReductionProcessor {
       for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
         if (const auto *name{
                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+          if (Fortran::semantics::Symbol * symbol{name->symbol}) {
+            if (reductionSymbols)
+              reductionSymbols->push_back(symbol);
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
             if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
               symVal = declOp.getBase();
@@ -1138,7 +1144,9 @@ class ReductionProcessor {
         for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
           if (const auto *name{
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+            if (Fortran::semantics::Symbol * symbol{name->symbol}) {
+              if (reductionSymbols)
+                reductionSymbols->push_back(symbol);
               mlir::Value symVal = converter.getSymbolAddress(*symbol);
               if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
                 symVal = declOp.getBase();
@@ -1932,13 +1940,16 @@ bool ClauseProcessor::processMap(
 bool ClauseProcessor::processReduction(
     mlir::Location currentLocation,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
-    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
+    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+    llvm::SmallVectorImpl<Fortran::semantics::Symbol *> *reductionSymbols)
+    const {
   return findRepeatableClause<ClauseTy::Reduction>(
       [&](const ClauseTy::Reduction *reductionClause,
           const Fortran::parser::CharBlock &) {
         ReductionProcessor rp;
         rp.addReductionDecl(currentLocation, converter, reductionClause->v,
-                            reductionVars, reductionDeclSymbols);
+                            reductionVars, reductionDeclSymbols,
+                            reductionSymbols);
       });
 }
 
@@ -2502,6 +2513,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
       reductionVars;
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+  llvm::SmallVector<Fortran::semantics::Symbol *> reductionSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2511,9 +2523,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   cp.processDefault();
   cp.processAllocate(allocatorOperands, allocateOperands);
   if (!outerCombined)
-    cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
+    cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
+                        &reductionSymbols);
 
-  return genOpWithBody<mlir::omp::ParallelOp>(
+  auto op = genOpWithBody<mlir::omp::ParallelOp>(
       converter, eval, genNested, currentLocation, outerCombined, &clauseList,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
@@ -2523,6 +2536,24 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
       procBindKindAttr);
+
+  // Add reduction block arguments
+  if (!reductionVars.empty()) {
+    mlir::Block &regionBlock = op.getRegion().front();
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    for (auto [val, sym] : llvm::zip_equal(reductionVars, reductionSymbols)) {
+      auto savedIP = firOpBuilder.getInsertionPoint();
+      firOpBuilder.setInsertionPointToStart(&regionBlock);
+      auto prv = regionBlock.addArgument(val.getType(), op.getLoc());
+      converter.bindSymbol(*sym, prv);
+      val.replaceUsesWithIf(prv, [&regionBlock](mlir::OpOperand &use) {
+        return use.getOwner()->getBlock() == &regionBlock;
+      });
+      firOpBuilder.setInsertionPoint(&regionBlock, savedIP);
+    }
+  }
+
+  return op;
 }
 
 static mlir::omp::SectionOp
@@ -3483,10 +3514,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     break;
   }
 
-  if (singleDirective) {
-    genOpenMPReduction(converter, beginClauseList);
+  if (singleDirective)
     return;
-  }
 
   // Codegen for combined directives
   bool combinedDirective = false;
@@ -3522,7 +3551,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                               ")");
 
   genNestedEvaluations(converter, eval);
-  genOpenMPReduction(converter, beginClauseList);
 }
 
 static void
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
index 6580aeb13ccd1e1..4b223e822760a99 100644
--- a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
@@ -27,9 +27,11 @@
 !CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
-!CHECK:    %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK:    omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK:  omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK:    %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
 !CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
 !CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
 !CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
-!CHECK:  omp.parallel   reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
-!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK:    omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK:  omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK:    %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
 !CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
-!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK:    omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK:    %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK:    %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK:    fir.store %[[RES1]] to %[[PRV1]]
+!CHECK:    %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
 !CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK:    omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK:    %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES0]] to %[[PRV0]]
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
index 81a93aebbd26619..83edac983727894 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
@@ -28,9 +28,12 @@
 !CHECK:  %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
+!CHECK:  omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
 !CHECK:    %[[I_INCR:.*]] = arith.constant 1 : i32
-!CHECK:    omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK:    %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
+!CHECK:    hlfir.assign %[[RES]] to %[[PRV]] : i32, !fir.ref<i32>
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
 !CHECK:  %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK:  omp.parallel   reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK:    omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK:    %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK:    hlfir.assign %[[RES]] to %[[PRV]] : f32, !fir.ref<f32>
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel   reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK:    %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
-!CHECK:    omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
+!CHECK:    %[[R_LPRV:.+]] = fir.load %[[RPRV]] : !fir.ref<f32>
+!CHECK:    %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
+!CHECK:    hlfir.assign %[[RES1]] to %[[RPRV]] : f32, !fir.ref<f32>
+!CHECK:    %[[I_LPRV:.+]] = fir.load %[[IPRV]] : !fir.ref<i32>
 !CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
-!CHECK:    omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK:    %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
+!CHECK:    hlfir.assign %[[RES0]] to %[[IPRV]] : i32, !fir.ref<i32>
 !CHECK:    omp.terminator
 !CHECK:  }
 !CHECK: return
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 96c15e775a3024b..f1aace21168bca5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -200,11 +200,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
     unsigned getNumReductionVars() { return getReductionVars().size(); }
   }];
   let assemblyFormat = [{
-    oilist( `reduction` `(`
-              custom<ReductionVarList>(
-                $reduction_vars, type($reduction_vars), $reductions
-              ) `)`
-          | `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+    oilist(
+          `if` `(` $if_expr_var `:` type($if_expr_var) `)`
           | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
           | `allocate` `(`
               custom<AllocateAndAllocator>(
@@ -212,7 +209,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocators_vars, type($allocators_vars)
               ) `)`
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
-    ) $region attr-dict
+    ) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13cc16125a2733e..075c49f3a5d186b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Interfaces/FoldInterfaces.h"
 
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
@@ -427,6 +428,55 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
 // Parser, printer and verifier for ReductionVarList
 //===----------------------------------------------------------------------===//
 
+static ParseResult
+parseWsReduction(OpAsmParser &parser, Region &region,
+                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+                 SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+
+  // possibly parse reduction
+  if (failed(parser.parseOptionalKeyword("reduction")))
+    return parser.parseRegion(region);
+
+  SmallVector<SymbolRefAttr> reductionVec;
+  SmallVector<OpAsmParser::Argument> privates;
+
+  if (failed(
+          parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
+            if (parser.parseAttribute(reductionVec.emplace_back()) ||
+                parser.parseOperand(operands.emplace_back()) ||
+                parser.parseArrow() ||
+                parser.parseArgument(privates.emplace_back()) ||
+                parser.parseColonType(types.emplace_back()))
+              return failure();
+            return success();
+          })))
+    return failure();
+
+  for (std::size_t i = 0; i < privates.size(); ++i) {
+    privates[i].type = types[i];
+  }
+  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
+  reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+  return parser.parseRegion(region, privates);
+}
+
+void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
+                      ValueRange operands, TypeRange types,
+                      ArrayAttr reductionSymbols) {
+  if (reductionSymbols) {
+    p << "reduction(";
+    llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
+                                          region.front().getArguments(), types),
+                          p, [&p](auto t) {
+                            auto [sym, op, arg, type] = t;
+                            p << sym << " " << op << " -> " << arg << " : "
+                              << type;
+                          });
+    p << ") ";
+  }
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
 /// reduction-entry-list ::= reduction-entry
 ///                        | reduction-entry-list `,` reduction-entry
 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 23e101f1e45272a..71b7937671b26ef 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1018,9 +1018,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     // Allocate reduction vars
     SmallVector<llvm::Value *> privateReductionVariables;
     DenseMap<Value, llvm::Value *> reductionVariableMap;
-    allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
-                       reductionDecls, privateReductionVariables,
-                       reductionVariableMap);
+    {
+      llvm::IRBuilderBase::InsertPointGuard guard(builder);
+      builder.restoreIP(allocaIP);
+      auto args = opInst.getRegion().getArguments();
+
+      for (std::size_t i = 0; i < opInst.getNumReductionVars(); ++i) {
+        llvm::Value *var = builder.CreateAlloca(
+            moduleTranslation.convertType(reductionDecls[i].getType()));
+        moduleTranslation.mapValue(args[i], var);
+        privateReductionVariables.push_back(var);
+        reductionVariableMap.try_emplace(opInst.getReductionVars()[i], var);
+      }
+    }
 
     // Store the mapping between reduction variables and their private copies on
     // ModuleTranslation stack. It can be then recovered when translating
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ccf72ae31d439ed..0451f4a0bfa2344 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -640,11 +640,13 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
 func.func @parallel_reduction() {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr)
-  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+  // CHECK: omp.parallel reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr)
+  omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
     %1 = arith.constant 2.0 : f32
-    // CHECK: omp.reduction %{{.+}}, %{{.+}}
-    omp.reduction %1, %0 : f32, !llvm.ptr
+    %2 = llvm.load %prv : !llvm.ptr -> f32
+    // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
+    %3 = llvm.fadd %1, %2 : f32
+    llvm.store %3, %prv : f32, !llvm.ptr
     omp.terminator
   }
   return
@@ -654,13 +656,14 @@ func.func @parallel_reduction() {
 func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
-  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+  // CHECK: omp.parallel reduction(@add_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
+  omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
     // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
     omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
       %1 = arith.constant 2.0 : f32
-      // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
-      omp.reduction %1, %0 : f32, !llvm.ptr
+      %2 = llvm.load %prv : !llvm.ptr -> f32
+      // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
+      llvm.fadd %1, %2 : f32
       // CHECK: omp.yield
       omp.yield
     }
@@ -799,11 +802,14 @@ func.func @wsloop_reduction2(%lb : index, %ub : index, %step : index) {
 // CHECK-LABEL: func @parallel_reduction2
 func.func @parallel_reduction2() {
   %0 = memref.alloca() : memref<1xf32>
-  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
-  omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) {
+  // CHECK: omp.parallel reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
+  omp.parallel reduction(@add2_f32 %0 -> %prv : memref<1xf32>) {
     %1 = arith.constant 2.0 : f32
-    // CHECK: omp.reduction
-    omp.reduction %1, %0 : f32, memref<1xf32>
+    %2 = arith.constant 0 : index
+    %3 = memref.load %prv[%2] : memref<1xf32>
+    // CHECK: llvm.fadd
+    %4 = llvm.fadd %1, %3 : f32
+    memref.store %4, %prv[%2] : memref<1xf32>
     omp.terminator
   }
   return
@@ -813,13 +819,14 @@ func.func @parallel_reduction2() {
 func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) {
-  omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) {
+  // CHECK: omp.parallel reduction(@add2_f32 %{{.*}} -> %{{.+}} : !llvm.ptr) {
+  omp.parallel reduction(@add2_f32 %0 -> %prv : !llvm.ptr) {
     // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}})
     omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
       %1 = arith.constant 2.0 : f32
-      // CHECK: omp.reduction %{{.+}}, %{{.+}} : f32, !llvm.ptr
-      omp.reduction %1, %0 : f32, !llvm.ptr
+      %2 = llvm.load %prv : !llvm.ptr -> f32
+      // CHECK: llvm.fadd %{{.+}}, %{{.+}} : f32
+      %3 = llvm.fadd %1, %2 : f32
       // CHECK: omp.yield
       omp.yield
     }
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction.mlir b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
index 93ab578df9e4e86..dae83c0cf92ed88 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction.mlir
@@ -441,9 +441,11 @@ atomic {
 llvm.func @simple_reduction_parallel() {
   %c1 = llvm.mlir.constant(1 : i32) : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) {
+  omp.parallel reduction(@add_f32 %0 -> %prv : !llvm.ptr) {
     %1 = llvm.mlir.constant(2.0 : f32) : f32
-    omp.reduction %1, %0 : f32, !llvm.ptr
+    %2 = llvm.load %prv : !llvm.ptr -> f32
+    %3 = llvm.fadd %2, %1 : f32
+    llvm.store %3, %prv : f32, !llvm.ptr
     omp.terminator
   }
   llvm.return
@@ -512,10 +514,12 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
   %lb = llvm.mlir.constant(1 : i64) : i64
   %step = llvm.mlir.constant(1 : i64) : i64
   
-  omp.parallel reduction(@add_i32 -> %0 : !llvm.ptr) {
+  omp.parallel reduction(@add_i32 %0 -> %prv : !llvm.ptr) {
     omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
       %ival = llvm.trunc %iv : i64 to i32
-      omp.reduction %ival, %0 : i32, !llvm.ptr
+      %lprv = llvm.load %prv : !llvm.ptr -> i32
+      %add = llvm.add %lprv, %ival : i32
+      llvm.store %add, %prv : i32, !llvm.ptr
       omp.yield
     }
     omp.terminator

>From 7890b53c0e6bfaac1bbe7cf9ea7cc54aa3756a1a Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Fri, 26 Jan 2024 14:34:15 +0000
Subject: [PATCH 2/2] Move parsing and printing reductions to separate function

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 72 ++++++++++++-------
 2 files changed, 46 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f1aace21168bca5..77bea9db5276e1e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -209,7 +209,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocators_vars, type($allocators_vars)
               ) `)`
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
-    ) custom<WsReduction>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
+    ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 075c49f3a5d186b..3947d6af19849b8 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
+#include "mlir/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace mlir::omp;
@@ -428,17 +429,15 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
 // Parser, printer and verifier for ReductionVarList
 //===----------------------------------------------------------------------===//
 
-static ParseResult
-parseWsReduction(OpAsmParser &parser, Region &region,
-                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
-                 SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
-
-  // possibly parse reduction
+ParseResult parseReductionClause(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+    SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
+    SmallVectorImpl<OpAsmParser::Argument> &privates) {
   if (failed(parser.parseOptionalKeyword("reduction")))
-    return parser.parseRegion(region);
+    return failure();
 
   SmallVector<SymbolRefAttr> reductionVec;
-  SmallVector<OpAsmParser::Argument> privates;
 
   if (failed(
           parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
@@ -452,28 +451,46 @@ parseWsReduction(OpAsmParser &parser, Region &region,
           })))
     return failure();
 
-  for (std::size_t i = 0; i < privates.size(); ++i) {
-    privates[i].type = types[i];
+  for (auto [prv, type] : llvm::zip_equal(privates, types)) {
+    prv.type = type;
   }
   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
   reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
-  return parser.parseRegion(region, privates);
-}
-
-void printWsReduction(OpAsmPrinter &p, Operation *op, Region &region,
-                      ValueRange operands, TypeRange types,
-                      ArrayAttr reductionSymbols) {
-  if (reductionSymbols) {
-    p << "reduction(";
-    llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
-                                          region.front().getArguments(), types),
-                          p, [&p](auto t) {
-                            auto [sym, op, arg, type] = t;
-                            p << sym << " " << op << " -> " << arg << " : "
-                              << type;
-                          });
-    p << ") ";
-  }
+  return success();
+}
+
+static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
+                                 ValueRange operands, TypeRange types,
+                                 ArrayAttr reductionSymbols) {
+  p << "reduction(";
+  llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
+                                        region.front().getArguments(), types),
+                        p, [&p](auto t) {
+                          auto [sym, op, arg, type] = t;
+                          p << sym << " " << op << " -> " << arg << " : "
+                            << type;
+                        });
+  p << ") ";
+}
+
+static ParseResult
+parseParallelRegion(OpAsmParser &parser, Region &region,
+                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+                    SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+
+  llvm::SmallVector<OpAsmParser::Argument> privates;
+  if (succeeded(parseReductionClause(parser, region, operands, types,
+                                     reductionSymbols, privates)))
+    return parser.parseRegion(region, privates);
+
+  return parser.parseRegion(region);
+}
+
+static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
+                                ValueRange operands, TypeRange types,
+                                ArrayAttr reductionSymbols) {
+  if (reductionSymbols)
+    printReductionClause(p, op, region, operands, types, reductionSymbols);
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
@@ -1164,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region &region,
   loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
   for (auto &iv : ivs)
     iv.type = loopVarType;
+
   return parser.parseRegion(region, ivs);
 }
 



More information about the Mlir-commits mailing list