[flang-commits] [flang] [mlir] [MLIR][OpenMP] Normalize handling of entry block arguments (PR #109808)

via flang-commits flang-commits at lists.llvm.org
Tue Sep 24 07:56:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch introduces a new MLIR interface for the OpenMP dialect aimed at providing a uniform way of verifying and handling entry block arguments defined by OpenMP clauses.

The approach consists of defining a set of overridable methods that return the number of entry block arguments the operation defines for each clause that may define them. By default these return 0, but they are overridden by the corresponding clause through the `extraClassDeclaration` mechanism.
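As a rough illustration of the "default to 0, override per clause" idea, here is a minimal standalone C++ sketch. It is not the generated MLIR interface code: `BlockArgCountsBase`, `ToyParallelOp` and `numClauseBlockArgs` are hypothetical names, and CRTP method shadowing is only an assumed analogue of how declarations injected through `extraClassDeclaration` take precedence over the interface defaults.

```cpp
#include <vector>

// Hypothetical analogue of the interface defaults: every count is 0 unless
// the derived "operation" declares a method with the same name.
template <typename Derived>
struct BlockArgCountsBase {
  unsigned numInReductionBlockArgs() { return 0; }
  unsigned numMapBlockArgs() { return 0; }
  unsigned numPrivateBlockArgs() { return 0; }
  unsigned numReductionBlockArgs() { return 0; }
  unsigned numTaskReductionBlockArgs() { return 0; }

  // Total number of clause-defined entry block arguments. The calls go
  // through Derived so that shadowing methods are picked up.
  unsigned numClauseBlockArgs() {
    Derived &op = static_cast<Derived &>(*this);
    return op.numInReductionBlockArgs() + op.numMapBlockArgs() +
           op.numPrivateBlockArgs() + op.numReductionBlockArgs() +
           op.numTaskReductionBlockArgs();
  }
};

// Toy operation carrying only `private` and `reduction` clauses. The int
// vectors stand in for operand lists; they are not MLIR values.
struct ToyParallelOp : BlockArgCountsBase<ToyParallelOp> {
  std::vector<int> privateVars;
  std::vector<int> reductionVars;

  unsigned numPrivateBlockArgs() {
    return static_cast<unsigned>(privateVars.size());
  }
  unsigned numReductionBlockArgs() {
    return static_cast<unsigned>(reductionVars.size());
  }
};
```

In this sketch, clauses the toy operation does not carry (such as `map`) keep reporting 0 without any extra code in the operation.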

Another set of interface methods, implemented in terms of the methods described above, returns the actual lists of block arguments. These implicitly define a standardized ordering between the lists of block arguments associated with each clause, based on the alphabetical ordering of the clause names. They should be the preferred way of matching operation arguments to the entry block arguments of that operation's first region.
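The ordering can be pictured as slicing one flat argument list by running offsets. The following standalone C++ sketch assumes the clause order used in this patch (`in_reduction`, `map`, `private`, `reduction`, `task_reduction`); the types `ClauseArgCounts` and `ArgRange` and the helpers `blockArgRange` and `hasEnoughEntryBlockArgs` are hypothetical and only model the index arithmetic, not the MLIR API.

```cpp
#include <array>
#include <cstddef>

// Hypothetical per-clause counts of entry block arguments for one operation.
struct ClauseArgCounts {
  unsigned inReduction = 0;
  unsigned map = 0;
  unsigned priv = 0; // `private` is a C++ keyword, hence the short name.
  unsigned reduction = 0;
  unsigned taskReduction = 0;
};

// Enumerators are listed in alphabetical order of the clause names, which is
// the ordering of the entry block argument lists.
enum class Clause { InReduction, Map, Private, Reduction, TaskReduction };

// Half-open range [begin, begin + size) into the entry block argument list.
struct ArgRange {
  unsigned begin;
  unsigned size;
};

// The range of a clause starts where the ranges of all alphabetically
// earlier clauses end.
inline ArgRange blockArgRange(const ClauseArgCounts &c, Clause clause) {
  const std::array<unsigned, 5> counts = {c.inReduction, c.map, c.priv,
                                          c.reduction, c.taskReduction};
  const std::size_t idx = static_cast<std::size_t>(clause);
  unsigned begin = 0;
  for (std::size_t i = 0; i < idx; ++i)
    begin += counts[i];
  return {begin, counts[idx]};
}

// Mirrors the spirit of the verifier check: the region must hold at least
// as many entry block arguments as the clauses define.
inline bool hasEnoughEntryBlockArgs(const ClauseArgCounts &c,
                                    unsigned numEntryBlockArgs) {
  return numEntryBlockArgs >=
         c.inReduction + c.map + c.priv + c.reduction + c.taskReduction;
}
```

With one `private` and one `reduction` variable, the private range starts at index 0 and the reduction range at index 1, which matches the `%arg0`/`%arg1` swap in the updated Flang lowering tests.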

Some updates are made to the printing/parsing of `omp.parallel` to follow the expected order between `private` and `reduction` clauses, and to the MLIR to LLVM IR translation pass so that it accesses block arguments using the new interface. Unit tests of operations impacted by the additional verification checks and the sorting of entry block arguments are updated accordingly.

---

Patch is 32.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109808.diff


11 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+19-10) 
- (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+2-2) 
- (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 (+2-2) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+28-11) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+6-1) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+80) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+17-17) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+19-26) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+4) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17-6) 
- (modified) mlir/test/Target/LLVMIR/openmp-private.mlir (+1-1) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 960286732c90c2..e9095d631beb7b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
 /// \param [in] infoAccessor       - for a private variable, this returns the
 /// data we want to merge: type or location.
 /// \param [out] allRegionArgsInfo - the merged list of region info.
+/// \param [in] addBeforePrivate - `true` if the passed information goes before
+/// private information.
 template <typename OMPOp, typename InfoTy>
 static void
 mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
                      llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
-                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
+                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
+                     bool addBeforePrivate) {
   mlir::OperandRange privateVars = op.getPrivateVars();
 
-  llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
-                  [](InfoTy i) { return i; });
+  if (addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
+
   llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
                   infoAccessor);
+
+  if (!addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
   mergePrivateVarsInfo(targetOp, mapSymTypes,
                        llvm::function_ref<mlir::Type(mlir::Value)>{
                            [](mlir::Value v) { return v.getType(); }},
-                       allRegionArgTypes);
+                       allRegionArgTypes, /*addBeforePrivate=*/true);
 
   mergePrivateVarsInfo(targetOp, mapSymLocs,
                        llvm::function_ref<mlir::Location(mlir::Value)>{
                            [](mlir::Value v) { return v.getLoc(); }},
-                       allRegionArgLocs);
+                       allRegionArgLocs, /*addBeforePrivate=*/true);
 
   mlir::Block *regionBlock = firOpBuilder.createBlock(
       &region, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     mergePrivateVarsInfo(parallelOp, reductionTypes,
                          llvm::function_ref<mlir::Type(mlir::Value)>{
                              [](mlir::Value v) { return v.getType(); }},
-                         allRegionArgTypes);
+                         allRegionArgTypes, /*addBeforePrivate=*/false);
 
     llvm::SmallVector<mlir::Location> allRegionArgLocs;
     mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
                          llvm::function_ref<mlir::Location(mlir::Value)>{
                              [](mlir::Value v) { return v.getLoc(); }},
-                         allRegionArgLocs);
+                         allRegionArgLocs, /*addBeforePrivate=*/false);
 
     mlir::Region &region = parallelOp.getRegion();
     firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
                              allRegionArgLocs);
 
-    llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
-                      dsp->getDelayedPrivSymbols().end());
+    llvm::SmallVector<const semantics::Symbol *> allSymbols(
+        dsp->getDelayedPrivSymbols());
+    allSymbols.append(reductionSyms.begin(), reductionSyms.end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 29439571179322..6c00bb23f15b96 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
index d814b2b0ff0f31..38139e52ce95cb 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c579ba6e751d2b..876d53766a0ca1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
       return SmallVector<Value>(getInReductionVars().begin(),
                                 getInReductionVars().end());
     }
+
+    unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
+    // Not adding the BlockArgOpenMPOpInterface here because omp.target is the
+    // only operation defining block arguments for `map` clauses.
     MapClauseOwningOpInterface
   ];
 
@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
     OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
       custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
+  }];
+
   // TODO: Add description.
 }
 
@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
   let extraClassDeclaration = [{
     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return getReductionVars().size(); }
+    unsigned numReductionBlockArgs() { return getReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
                                $task_reduction_byref, $task_reduction_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    /// Returns the reduction variables.
+    SmallVector<Value> getReductionVars() {
+      return SmallVector<Value>(getTaskReductionVars().begin(),
+                                getTaskReductionVars().end());
+    }
+
+    unsigned numTaskReductionBlockArgs() {
+      return getTaskReductionVars().size();
+    }
+  }];
+
   let description = [{
     The `task_reduction` clause specifies a reduction among tasks. For each list
     item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
     attribute, and whether the reduction variable should be passed into the
     reduction region by value or by reference in `task_reduction_byref`.
   }];
-
-  let extraClassDeclaration = [{
-    /// Returns the reduction variables.
-    SmallVector<Value> getReductionVars() {
-      return SmallVector<Value>(getTaskReductionVars().begin(),
-                                getTaskReductionVars().end());
-    }
-  }];
 }
 
 def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..326bdd3bbc9463 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 //===----------------------------------------------------------------------===//
 
 def TargetOp : OpenMP_Op<"target", traits = [
-    AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
+    AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+    OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
 
+  let extraClassDeclaration = [{
+    unsigned numMapBlockArgs() { return getMapVars().size(); }
+  }] # clausesExtraClassDeclaration;
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0078e22b1c89a6..030075eaf45b14 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,86 @@
 
 include "mlir/IR/OpBase.td"
 
+def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
+  let description = [{
+    OpenMP operations that define entry block arguments as part of the
+    representation of its clauses.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
+                    "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `map`.",
+                    "unsigned", "numMapBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `private`.",
+                    "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `reduction`.",
+                    "unsigned", "numReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
+                    "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get block arguments defined by `in_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getInReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().take_front(
+          $_op.numInReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `map`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getMapBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `private`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getPrivateBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
+          $_op.numPrivateBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `task_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getTaskReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
+          $_op.numTaskReductionBlockArgs());
+    }]>,
+  ];
+
+  let verify = [{
+    auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
+    unsigned expectedArgs = iface.numInReductionBlockArgs() +
+        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
+        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+    if ($_op->getRegion(0).getNumArguments() < expectedArgs)
+      return $_op->emitOpError() << "expected at least " << expectedArgs
+                                 << " entry block argument(s)";
+    return ::mlir::success();
+  }];
+}
+
 def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   let description = [{
     OpenMP operations whose region will be outlined will implement this
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db47276dcefe95..7ca7a2afbdbbc4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
 
-  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
-    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
-                                         reductionTypes, reductionByref,
-                                         reductionSyms, regionPrivateArgs)))
-      return failure();
-  }
-
   if (succeeded(parser.parseOptionalKeyword("private"))) {
     auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
     if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
     }
   }
 
+  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
+                                         reductionTypes, reductionByref,
+                                         reductionSyms, regionPrivateArgs)))
+      return failure();
+  }
+
   return parser.parseRegion(region, regionPrivateArgs);
 }
 
@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                 DenseBoolArrayAttr reductionByref,
                                 ArrayAttr reductionSyms, ValueRange privateVars,
                                 TypeRange privateTypes, ArrayAttr privateSyms) {
-  if (reductionSyms) {
-    auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
-    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
-                              reductionTypes, reductionByref, reductionSyms);
-  }
-
   if (privateSyms) {
     auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
-                                 argsBegin + reductionVars.size() +
-                                     privateTypes.size());
+    MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
     mlir::SmallVector<bool> isByRefVec;
     isByRefVec.resize(privateTypes.size(), false);
     DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                               privateTypes, isByRef, privateSyms);
   }
 
+  if (reductionSyms) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
+                                 argsBegin + privateVars.size() +
+                                     reductionTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
+                              reductionTypes, reductionByref, reductionSyms);
+  }
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0cba8d80681f13..769cdd57656b51 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      sectionsOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -954,8 +954,10 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
       // variables
       assert(region.getNumArguments() ==
              sectionsOp.getRegion().getNumArguments());
-      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
-               sectionsOp.getRegion().getArguments(), region.getArguments())) {
+      for (auto [sectionsArg, sectionArg] :
+           llvm::zip_equal(cast<omp::BlockArgOpenMPOpInterface>(*sectionsOp)
+                               .getReductionBlockArgs(),
+                           region.getArguments())) {
         llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
         assert(llvmVal);
         moduleTranslation.mapValue(sectionArg, llvmVal);
@@ -1216,7 +1218,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      wsloopOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1329,31 +1331,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 class OmpParallelOpConversionManager {
 public:
   OmpParallelOpConversionManager(omp::ParallelOp opInst)
-      : region(opInst.getRegion()), privateVars(opInst.ge...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/109808

