[llvm-branch-commits] [flang] [llvm] [mlir] [flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref (PR #84304)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Mar 7 03:01:43 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

<details>
<summary>Changes</summary>

Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation.

This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).

Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel.
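
As a rough illustration of the two conventions (a standalone C++ sketch, not code from this patch): the by-value regions take and yield plain values, while the by-ref regions allocate, load and store internally and yield an address. The `Arena` type and all function names here are hypothetical; the patch itself emits `fir.alloca`, `fir.load` and `fir.store` inside the regions.

```cpp
#include <cassert>
#include <deque>
#include <vector>

// By-value convention: the init region yields a value and the combiner
// takes two values and yields their combination.
int initByValue() { return 0; }
int combineByValue(int lhs, int rhs) { return lhs + rhs; }

// Hypothetical stand-in for the storage the by-ref init region allocates.
// std::deque keeps element addresses stable when growing at the end.
struct Arena {
  std::deque<int> slots;
  int *allocate() {
    slots.push_back(0);
    return &slots.back();
  }
};

// By-ref convention: the init region allocates storage, stores the
// identity value into it and yields the address.
int *initByRef(Arena &arena) {
  int *slot = arena.allocate();
  *slot = 0;
  return slot;
}

// The by-ref combiner loads both operands, stores the result through the
// first argument and yields that same address.
int *combineByRef(int *lhs, int *rhs) {
  assert(lhs && rhs && "by-ref combiner expects valid addresses");
  int a = *lhs;
  int b = *rhs;
  *lhs = a + b;
  return lhs;
}

// Fold a list of per-thread partial results with each convention.
int reduceByValue(const std::vector<int> &partials) {
  int acc = initByValue();
  for (int p : partials)
    acc = combineByValue(acc, p);
  return acc;
}

int reduceByRef(std::vector<int> partials) {
  Arena arena;
  int *acc = initByRef(arena);
  for (int &p : partials)
    acc = combineByRef(acc, &p);
  return *acc;
}
```

Both folds produce the same result for scalar types; the difference is only where the allocation and memory traffic happen, which is what lets array and derived type reductions manage their own storage.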

Existing lowerings from MLIR are unaffected (these will continue to use by-value argument passing).

Flang will continue to use by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering.
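
The selection rule can be sketched in standalone C++ as follows. This is a simplified model of the decision made by `ReductionProcessor::doReductionByRef` in the diff below, under the assumption that each reduction variable is reduced to a single "is trivial" bit; `ReductionVar`, `shouldReduceByRef` and `usageExample` are hypothetical names, and the hidden flag is modelled as a plain `bool` parameter.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a reduction variable: only records whether
// its (unwrapped) type is trivial, e.g. an integer, real or logical scalar.
struct ReductionVar {
  bool isTrivial;
};

// Returns true when the reduction should use the by-ref lowering.
bool shouldReduceByRef(const std::vector<ReductionVar> &vars,
                       bool forceByRef) {
  // No reduction variables: nothing to pass by reference.
  if (vars.empty())
    return false;
  // The hidden option forces by-ref even for trivial types.
  if (forceByRef)
    return true;
  // A single non-trivial variable switches the whole operation to by-ref.
  for (const ReductionVar &var : vars)
    if (!var.isTrivial)
      return true;
  return false;
}

// Small usage example covering the three cases above.
void usageExample() {
  assert(!shouldReduceByRef({}, /*forceByRef=*/true));
  assert(shouldReduceByRef({{true}}, /*forceByRef=*/true));
  assert(shouldReduceByRef({{true}, {false}}, /*forceByRef=*/false));
}
```

Note that the choice is made per operation rather than per variable: the attribute applies to the whole `omp.wsloop` or `omp.parallel`, so mixing one non-trivial variable with trivial ones moves all of them to by-ref.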

Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions 3/3

---

Patch is 297.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84304.diff


37 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+13-5) 
- (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+93-27) 
- (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+13-5) 
- (added) flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 (+117) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+392) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+46) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+45) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+45) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+187) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+189) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+90) 
- (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+91) 
- (added) flang/test/Lower/OpenMP/default-clause-byref.f90 (+385) 
- (added) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+30) 
- (added) flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 (+125) 
- (added) flang/test/Lower/OpenMP/parallel-reduction-byref.f90 (+44) 
- (added) flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 (+16) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+433) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+58) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+64) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+55) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+64) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+206) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+202) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+207) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+204) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-max-2-byref.f90 (+20) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+152) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+62) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+154) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+414) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+3-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+22-8) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+11-1) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+3-2) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+72-27) 
- (added) mlir/test/Target/LLVMIR/openmp-reduction-byref.mlir (+66) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 185e0316870e94..d9648f3d692cc6 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     return reductionSymbols;
   };
 
+  mlir::UnitAttr byrefAttr;
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefAttr = converter.getFirOpBuilder().getUnitAttr();
+
   OpWithBodyGenInfo genInfo =
       OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
           .setGenNested(genNested)
@@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
             : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                    reductionDeclSymbols),
         procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
-        /*privatizers=*/nullptr);
+        /*privatizers=*/nullptr, byrefAttr);
   }
 
   bool privatize = !outerCombined;
@@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
       delayedPrivatizationInfo.privatizers.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 privatizers));
+                                 privatizers),
+      byrefAttr);
 }
 
 static mlir::omp::SectionOp
@@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
   mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
-  mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
+  mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
   mlir::IntegerAttr orderedClauseOperand;
   mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
   std::size_t loopVarTypeSize;
@@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   convertLoopBounds(converter, loc, lowerBound, upperBound, step,
                     loopVarTypeSize);
 
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefOperand = firOpBuilder.getUnitAttr();
+
   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
       loc, lowerBound, upperBound, step, linearVars, linearStepVars,
       reductionVars,
@@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
                                  reductionDeclSymbols),
       scheduleValClauseOperand, scheduleChunkClauseOperand,
       /*schedule_modifiers=*/nullptr,
-      /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
-      orderClauseOperand,
+      /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
+      orderedClauseOperand, orderClauseOperand,
       /*inclusive=*/firOpBuilder.getUnitAttr());
 
   // Handle attribute based clauses.
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..34959e95dce04a 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -14,9 +14,16 @@
 
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<bool> forceByrefReduction(
+    "force-byref-reduction",
+    llvm::cl::desc("Pass all reduction arguments by reference"),
+    llvm::cl::Hidden);
 
 namespace Fortran {
 namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
 }
 
 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
-                                                 mlir::Type ty) {
+                                                 mlir::Type ty, bool isByRef) {
+  ty = fir::unwrapRefType(ty);
+
+  // extra string to distinguish reduction functions for variables passed by
+  // reference
+  llvm::StringRef byrefAddition{""};
+  if (isByRef)
+    byrefAddition = "_byref";
+
   return (llvm::Twine(name) +
           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
+          llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
       .str();
 }
 
 std::string ReductionProcessor::getReductionName(
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
+    mlir::Type ty, bool isByRef) {
   std::string reductionName;
 
   switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
     break;
   }
 
-  return getReductionName(reductionName, ty);
+  return getReductionName(reductionName, ty, isByRef);
 }
 
 mlir::Value
 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           ReductionIdentifier redId,
                                           fir::FirOpBuilder &builder) {
+  type = fir::unwrapRefType(type);
   assert((fir::isa_integer(type) || fir::isa_real(type) ||
           type.isa<fir::LogicalType>()) &&
          "only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
     mlir::Type type, mlir::Value op1, mlir::Value op2) {
   mlir::Value reductionOp;
+  type = fir::unwrapRefType(type);
   switch (redId) {
   case ReductionIdentifier::MAX:
     reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
 
 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
   mlir::OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
 
@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     return decl;
 
   mlir::OpBuilder modBuilder(module.getBodyRegion());
+  mlir::Type valTy = fir::unwrapRefType(type);
+  if (!isByRef)
+    type = valTy;
 
   decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
                                                           type);
   builder.createBlock(&decl.getInitializerRegion(),
                       decl.getInitializerRegion().end(), {type}, {loc});
   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+
   mlir::Value init = getReductionInitValue(loc, type, redId, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
+  if (isByRef) {
+    mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
+    builder.createStoreWithConvert(loc, init, alloca);
+    builder.create<mlir::omp::YieldOp>(loc, alloca);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, init);
+  }
 
   builder.createBlock(&decl.getReductionRegion(),
                       decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+  mlir::Value outAddr = op1;
+
+  op1 = builder.loadIfRef(loc, op1);
+  op2 = builder.loadIfRef(loc, op2);
 
   mlir::Value reductionOp =
       createScalarCombiner(builder, loc, redId, type, op1, op2);
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  if (isByRef) {
+    builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+    builder.create<mlir::omp::YieldOp>(loc, outAddr);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  }
 
   return decl;
 }
 
+bool ReductionProcessor::doReductionByRef(
+    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+  if (reductionVars.empty())
+    return false;
+  if (forceByrefReduction)
+    return true;
+
+  for (mlir::Value reductionVar : reductionVars) {
+    if (auto declare =
+            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+      reductionVar = declare.getMemref();
+
+    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+      return true;
+  }
+  return false;
+}
+
 void ReductionProcessor::addReductionDecl(
     mlir::Location currentLocation,
     Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
   const auto &redOperator{
       std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
   const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+
+  // initial pass to collect all reduction vars so we can figure out if this
+  // should happen byref
+  for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+    if (const auto *name{
+            Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+      if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+        if (reductionSymbols)
+          reductionSymbols->push_back(symbol);
+        mlir::Value symVal = converter.getSymbolAddress(*symbol);
+        if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+          symVal = declOp.getBase();
+        reductionVars.push_back(symVal);
+      }
+    }
+  }
+  const bool isByRef = doReductionByRef(reductionVars);
+
   if (const auto &redDefinedOp =
           std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
     const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
       if (const auto *name{
               Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
         if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-          if (reductionSymbols)
-            reductionSymbols->push_back(symbol);
           mlir::Value symVal = converter.getSymbolAddress(*symbol);
           if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
             symVal = declOp.getBase();
-          mlir::Type redType =
-              symVal.getType().cast<fir::ReferenceType>().getEleTy();
-          reductionVars.push_back(symVal);
-          if (redType.isa<fir::LogicalType>())
+          auto redType = symVal.getType().cast<fir::ReferenceType>();
+          if (redType.getEleTy().isa<fir::LogicalType>())
             decl = createReductionDecl(
                 firOpBuilder,
-                getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
-                redType, currentLocation);
-          else if (redType.isIntOrIndexOrFloat()) {
-            decl = createReductionDecl(firOpBuilder,
-                                       getReductionName(intrinsicOp, redType),
-                                       redId, redType, currentLocation);
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
+                                 isByRef),
+                redId, redType, currentLocation, isByRef);
+          else if (redType.getEleTy().isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(
+                firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
+                redId, redType, currentLocation, isByRef);
           } else {
             TODO(currentLocation, "Reduction of some types is not supported");
           }
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
         if (const auto *name{
                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
           if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            if (reductionSymbols)
-              reductionSymbols->push_back(symbol);
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
             if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
               symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            assert(redType.isIntOrIndexOrFloat() &&
+            auto redType = symVal.getType().cast<fir::ReferenceType>();
+            assert(redType.getEleTy().isIntOrIndexOrFloat() &&
                    "Unsupported reduction type");
             decl = createReductionDecl(
                 firOpBuilder,
                 getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                 redType),
-                redId, redType, currentLocation);
+                                 redType, isByRef),
+                redId, redType, currentLocation, isByRef);
             reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
                 firOpBuilder.getContext(), decl.getSymName()));
           }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..679580f2a3cac7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -14,6 +14,7 @@
 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
   static const Fortran::semantics::SourceName
   getRealName(const Fortran::parser::ProcedureDesignator &pd);
 
-  static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+  static bool
+  doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
+
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
+                                      bool isByRef);
 
   static std::string getReductionName(
       Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty);
+      mlir::Type ty, bool isByRef);
 
   /// This function returns the identity value of the operator \p
   /// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
   /// symbol table. The declaration has a constant initializer with the neutral
   /// value `initValue`, and the reduction combiner carried over from `reduce`.
   /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+  static mlir::omp::ReductionDeclareOp
+  createReductionDecl(fir::FirOpBuilder &builder,
+                      llvm::StringRef reductionOpName,
+                      const ReductionIdentifier redId, mlir::Type type,
+                      mlir::Location loc, bool isByRef);
 
   /// Creates a reduction declaration and associates it with an OpenMP block
   /// directive.
@@ -124,6 +131,7 @@ mlir::Value
 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Type type, mlir::Location loc,
                                           mlir::Value op1, mlir::Value op2) {
+  type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
          "only integer and float types are currently supported");
   if (type.isIntOrIndex())
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
new file mode 100644
index 00000000000000..ca432662b77c44
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
@@ -0,0 +1,117 @@
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : !fir.ref<f32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<f32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  %[[REF:.*]] = fir.alloca f32
+!CHECK:  fir.store %[[C0_1]] to %[[REF]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<f32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<f32>, %[[ARG1:.*]]: !fir.ref<f32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<f32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<f32>
+!CHECK:  %[[RES:.*]] = arith.addf %[[LD0]], %[[LD1]] {{.*}}: f32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<f32>)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : !fir.ref<i32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  %[[REF:.*]] = fir.alloca i32
+!CHECK:  fir.store %[[C0_1]] to %[[REF]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<i32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<i32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<i32>
+!CHECK:  %[[RES:.*]] = arith.addi %[[LD0]], %[[LD1]] : i32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<i32>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK:    %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_int_add
+    integer :: i
+    i = 0
+
+    !$omp parallel reduction(+:i)
+    i = i + 1
+    !$omp end parallel
+
+    print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  omp.parallel byref reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK:    %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_real_add
+    real :: r
+    r = 0.0
+
+    !$omp parallel reduction(+:r)
+    r = r + 1.5
+    !$omp end parallel
+
+    print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK:    %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK:    fir.store %[[RES1]] to %[[PRV1]]
+!CHECK:    %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK:    %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES0]] t...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/84304

