[flang-commits] [flang] [mlir] [MLIR][OpenMP] Normalize handling of entry block arguments (PR #109808)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Mon Sep 30 04:43:17 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/109808

>From 32fdaf6f83833d2439327748baf6a932c6117cfd Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 18 Sep 2024 11:47:36 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Normalize handling of entry block
 arguments

This patch introduces a new MLIR interface for the OpenMP dialect aimed at
providing a uniform way of verifying and handling entry block arguments defined
by OpenMP clauses.

The approach consists of defining a set of overridable methods that return the
number of block arguments the operation holds for each of the clauses
that may define them. By default these return 0, but they are overridden by the
corresponding clause through the `extraClassDeclaration` mechanism.

Another set of interface methods to get the actual lists of block arguments is
defined, which is implemented based on the previously described methods. These
implicitly define a standardized ordering between the lists of block arguments
associated with each clause, based on the alphabetical ordering of their names.
They should be the preferred way of matching operation arguments and entry
block arguments to that operation's first region.

Some updates are made to the printing/parsing of `omp.parallel` to follow the
expected order between `private` and `reduction` clauses, as well as the MLIR
to LLVM IR translation pass to access block arguments using the new interface.
Unit tests of operations impacted by the additional verification checks and the
sorting of entry block arguments are updated accordingly.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 29 ++++---
 .../delayed-privatization-reduction-byref.f90 |  4 +-
 .../delayed-privatization-reduction.f90       |  4 +-
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 39 ++++++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  7 +-
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 80 +++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 34 ++++----
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 45 +++++------
 mlir/test/Dialect/OpenMP/invalid.mlir         |  4 +
 mlir/test/Dialect/OpenMP/ops.mlir             | 23 ++++--
 mlir/test/Target/LLVMIR/openmp-private.mlir   |  2 +-
 11 files changed, 195 insertions(+), 76 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d528772f28724b..17ebf93edcce1f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
 /// \param [in] infoAccessor       - for a private variable, this returns the
 /// data we want to merge: type or location.
 /// \param [out] allRegionArgsInfo - the merged list of region info.
+/// \param [in] addBeforePrivate - `true` if the passed information goes before
+/// private information.
 template <typename OMPOp, typename InfoTy>
 static void
 mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
                      llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
-                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
+                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
+                     bool addBeforePrivate) {
   mlir::OperandRange privateVars = op.getPrivateVars();
 
-  llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
-                  [](InfoTy i) { return i; });
+  if (addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
+
   llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
                   infoAccessor);
+
+  if (!addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
   mergePrivateVarsInfo(targetOp, mapSymTypes,
                        llvm::function_ref<mlir::Type(mlir::Value)>{
                            [](mlir::Value v) { return v.getType(); }},
-                       allRegionArgTypes);
+                       allRegionArgTypes, /*addBeforePrivate=*/true);
 
   mergePrivateVarsInfo(targetOp, mapSymLocs,
                        llvm::function_ref<mlir::Location(mlir::Value)>{
                            [](mlir::Value v) { return v.getLoc(); }},
-                       allRegionArgLocs);
+                       allRegionArgLocs, /*addBeforePrivate=*/true);
 
   mlir::Block *regionBlock = firOpBuilder.createBlock(
       &region, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     mergePrivateVarsInfo(parallelOp, reductionTypes,
                          llvm::function_ref<mlir::Type(mlir::Value)>{
                              [](mlir::Value v) { return v.getType(); }},
-                         allRegionArgTypes);
+                         allRegionArgTypes, /*addBeforePrivate=*/false);
 
     llvm::SmallVector<mlir::Location> allRegionArgLocs;
     mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
                          llvm::function_ref<mlir::Location(mlir::Value)>{
                              [](mlir::Value v) { return v.getLoc(); }},
-                         allRegionArgLocs);
+                         allRegionArgLocs, /*addBeforePrivate=*/false);
 
     mlir::Region &region = parallelOp.getRegion();
     firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
                              allRegionArgLocs);
 
-    llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
-                      dsp->getDelayedPrivSymbols().end());
+    llvm::SmallVector<const semantics::Symbol *> allSymbols(
+        dsp->getDelayedPrivSymbols());
+    allSymbols.append(reductionSyms.begin(), reductionSyms.end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 29439571179322..6c00bb23f15b96 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
index d814b2b0ff0f31..38139e52ce95cb 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c579ba6e751d2b..876d53766a0ca1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
       return SmallVector<Value>(getInReductionVars().begin(),
                                 getInReductionVars().end());
     }
+
+    unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
+    // Not adding the BlockArgOpenMPOpInterface here because omp.target is the
+    // only operation defining block arguments for `map` clauses.
     MapClauseOwningOpInterface
   ];
 
@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
     OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
       custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
+  }];
+
   // TODO: Add description.
 }
 
@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
   let extraClassDeclaration = [{
     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return getReductionVars().size(); }
+    unsigned numReductionBlockArgs() { return getReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
                                $task_reduction_byref, $task_reduction_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    /// Returns the reduction variables.
+    SmallVector<Value> getReductionVars() {
+      return SmallVector<Value>(getTaskReductionVars().begin(),
+                                getTaskReductionVars().end());
+    }
+
+    unsigned numTaskReductionBlockArgs() {
+      return getTaskReductionVars().size();
+    }
+  }];
+
   let description = [{
     The `task_reduction` clause specifies a reduction among tasks. For each list
     item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
     attribute, and whether the reduction variable should be passed into the
     reduction region by value or by reference in `task_reduction_byref`.
   }];
-
-  let extraClassDeclaration = [{
-    /// Returns the reduction variables.
-    SmallVector<Value> getReductionVars() {
-      return SmallVector<Value>(getTaskReductionVars().begin(),
-                                getTaskReductionVars().end());
-    }
-  }];
 }
 
 def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..326bdd3bbc9463 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 //===----------------------------------------------------------------------===//
 
 def TargetOp : OpenMP_Op<"target", traits = [
-    AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
+    AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+    OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
 
+  let extraClassDeclaration = [{
+    unsigned numMapBlockArgs() { return getMapVars().size(); }
+  }] # clausesExtraClassDeclaration;
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0078e22b1c89a6..030075eaf45b14 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,86 @@
 
 include "mlir/IR/OpBase.td"
 
+def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
+  let description = [{
+    OpenMP operations that define entry block arguments as part of the
+    representation of its clauses.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
+                    "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `map`.",
+                    "unsigned", "numMapBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `private`.",
+                    "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `reduction`.",
+                    "unsigned", "numReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
+                    "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get block arguments defined by `in_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getInReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().take_front(
+          $_op.numInReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `map`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getMapBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `private`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getPrivateBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
+          $_op.numPrivateBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `task_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getTaskReductionBlockArgs", (ins), [{
+      return $_op->getRegion(0).getArguments().slice(
+          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
+          $_op.numTaskReductionBlockArgs());
+    }]>,
+  ];
+
+  let verify = [{
+    auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
+    unsigned expectedArgs = iface.numInReductionBlockArgs() +
+        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
+        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+    if ($_op->getRegion(0).getNumArguments() < expectedArgs)
+      return $_op->emitOpError() << "expected at least " << expectedArgs
+                                 << " entry block argument(s)";
+    return ::mlir::success();
+  }];
+}
+
 def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   let description = [{
     OpenMP operations whose region will be outlined will implement this
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90bf5df67b03ba..b50c036877882a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
 
-  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
-    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
-                                         reductionTypes, reductionByref,
-                                         reductionSyms, regionPrivateArgs)))
-      return failure();
-  }
-
   if (succeeded(parser.parseOptionalKeyword("private"))) {
     auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
     if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
     }
   }
 
+  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
+                                         reductionTypes, reductionByref,
+                                         reductionSyms, regionPrivateArgs)))
+      return failure();
+  }
+
   return parser.parseRegion(region, regionPrivateArgs);
 }
 
@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                 DenseBoolArrayAttr reductionByref,
                                 ArrayAttr reductionSyms, ValueRange privateVars,
                                 TypeRange privateTypes, ArrayAttr privateSyms) {
-  if (reductionSyms) {
-    auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
-    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
-                              reductionTypes, reductionByref, reductionSyms);
-  }
-
   if (privateSyms) {
     auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
-                                 argsBegin + reductionVars.size() +
-                                     privateTypes.size());
+    MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
     mlir::SmallVector<bool> isByRefVec;
     isByRefVec.resize(privateTypes.size(), false);
     DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                               privateTypes, isByRef, privateSyms);
   }
 
+  if (reductionSyms) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
+                                 argsBegin + privateVars.size() +
+                                     reductionTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
+                              reductionTypes, reductionByref, reductionSyms);
+  }
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d788fe1f6165e6..bbc0b518e99bfc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      sectionsOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -954,8 +954,10 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
       // variables
       assert(region.getNumArguments() ==
              sectionsOp.getRegion().getNumArguments());
-      for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
-               sectionsOp.getRegion().getArguments(), region.getArguments())) {
+      for (auto [sectionsArg, sectionArg] :
+           llvm::zip_equal(cast<omp::BlockArgOpenMPOpInterface>(*sectionsOp)
+                               .getReductionBlockArgs(),
+                           region.getArguments())) {
         llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
         assert(llvmVal);
         moduleTranslation.mapValue(sectionArg, llvmVal);
@@ -1216,7 +1218,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      wsloopOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1329,31 +1331,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 class OmpParallelOpConversionManager {
 public:
   OmpParallelOpConversionManager(omp::ParallelOp opInst)
-      : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
-        privateArgBeginIdx(opInst.getNumReductionVars()),
-        privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
-    auto privateVarsIt = privateVars.begin();
-
-    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
-         ++argIdx, ++privateVarsIt)
-      mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
-                                       *privateVarsIt, region);
+      : region(opInst.getRegion()),
+        privateBlockArgs(cast<omp::BlockArgOpenMPOpInterface>(*opInst)
+                             .getPrivateBlockArgs()),
+        privateVars(opInst.getPrivateVars()) {
+    for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+      mlir::replaceAllUsesInRegionWith(blockArg, var, region);
   }
 
   ~OmpParallelOpConversionManager() {
-    auto privateVarsIt = privateVars.begin();
-
-    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
-         ++argIdx, ++privateVarsIt)
-      mlir::replaceAllUsesInRegionWith(*privateVarsIt,
-                                       region.getArgument(argIdx), region);
+    for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+      mlir::replaceAllUsesInRegionWith(var, blockArg, region);
   }
 
 private:
   Region &region;
+  llvm::MutableArrayRef<BlockArgument> privateBlockArgs;
   OperandRange privateVars;
-  unsigned privateArgBeginIdx;
-  unsigned privateArgEndIdx;
 };
 
 /// Converts the OpenMP parallel operation to LLVM IR.
@@ -1382,9 +1376,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     DenseMap<Value, llvm::Value *> reductionVariableMap;
 
     MutableArrayRef<BlockArgument> reductionArgs =
-        opInst.getRegion().getArguments().slice(
-            opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
-            opInst.getNumReductionVars());
+        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
 
     allocaIP =
         InsertPointTy(allocaIP.getBlock(),
@@ -3399,6 +3391,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   auto &targetRegion = targetOp.getRegion();
   DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
   SmallVector<Value> mapVars = targetOp.getMapVars();
+  ArrayRef<BlockArgument> mapBlockArgs =
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs();
   llvm::Function *llvmOutlinedFn = nullptr;
 
   // TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3427,11 +3421,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       llvmOutlinedFn->addFnAttr(attr);
 
     builder.restoreIP(codeGenIP);
-    for (auto [argIndex, mapOp] : llvm::enumerate(mapVars)) {
+    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
       auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
       llvm::Value *mapOpValue =
           moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
-      const auto &arg = targetRegion.front().getArgument(argIndex);
       moduleTranslation.mapValue(arg, mapOpValue);
     }
     llvm::BasicBlock *exitBlock = convertOmpOpRegions(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d8745f1015af83..5e182dea52b40e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1459,6 +1459,7 @@ func.func @omp_sections(%data_var : memref<i32>) -> () {
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected as many reduction symbol references as reduction variables}}
   "omp.sections" (%data_var) ({
+  ^bb0(%arg0: memref<i32>):
     omp.terminator
   }) {operandSegmentSizes = array<i32: 0,0,0,1>} : (memref<i32>) -> ()
   return
@@ -1650,6 +1651,7 @@ func.func @omp_task_depend(%data_var: memref<i32>) {
 func.func @omp_task(%ptr: !llvm.ptr) {
   // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
   omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -1674,6 +1676,7 @@ combiner {
 func.func @omp_task(%ptr: !llvm.ptr) {
   // expected-error @below {{op accumulator variable used more than once}}
   omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr, @add_f32 -> %ptr : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -1704,6 +1707,7 @@ atomic {
 func.func @omp_task(%mem: memref<1xf32>) {
   // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}}
   omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) {
+  ^bb0(%arg0: memref<1xf32>):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e7d3e67ca7e05b..2116071f8523a3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1096,6 +1096,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
   omp.teams reduction(@add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     %1 = arith.constant 2.0 : f32
     // CHECK: omp.terminator
     omp.terminator
@@ -1104,6 +1105,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
   // Test reduction byref
   // CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) {
   omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     %1 = arith.constant 2.0 : f32
     // CHECK: omp.terminator
     omp.terminator
@@ -1125,6 +1127,7 @@ func.func @sections_reduction() {
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr)
   omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.section
     omp.section {
       %1 = arith.constant 2.0 : f32
@@ -1146,6 +1149,7 @@ func.func @sections_reduction_byref() {
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr)
   omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.section
     omp.section {
       %1 = arith.constant 2.0 : f32
@@ -1245,6 +1249,7 @@ func.func @sections_reduction2() {
   %0 = memref.alloca() : memref<1xf32>
   // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
   omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+  ^bb0(%arg0: !llvm.ptr):
     omp.section {
       %1 = arith.constant 2.0 : f32
       omp.terminator
@@ -1901,6 +1906,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
 
     // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr)
   "omp.sections" (%redn_var) ({
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.terminator
     omp.terminator
   }) {operandSegmentSizes = array<i32: 0,0,0,1>, reduction_byref = array<i1: false>, reduction_syms=[@add_f32]} : (!llvm.ptr) -> ()
@@ -1913,6 +1919,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
 
   // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) {
   omp.sections reduction(@add_f32 -> %redn_var : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.terminator
     omp.terminator
   }
@@ -2087,6 +2094,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   %1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
   // CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
   omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2096,6 +2104,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   // Checking `in_reduction` clause (mixed) byref
   // CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
   omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2129,6 +2138,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
       in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr)
       // CHECK-SAME: priority(%[[i32_var]] : i32) untied
       priority(%i32_var : i32) untied {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2306,6 +2316,7 @@ func.func @omp_taskgroup_clauses() -> () {
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
   // CHECK: omp.taskgroup allocate(%{{.+}}: memref<i32> -> %{{.+}}: memref<i32>) task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr)
   omp.taskgroup allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) {
+  ^bb0(%arg0 : !llvm.ptr):
     // CHECK: omp.task
     omp.task {
       "test.foo"() : () -> ()
@@ -2783,15 +2794,15 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
 // CHECK-LABEL: parallel_op_reduction_and_private
 func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
   // CHECK: omp.parallel
-  // CHECK-SAME: reduction(
-  // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
-  // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
-  //
   // CHECK-SAME: private(
   // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr,
   // CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr)
-  omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
-               private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
+  //
+  // CHECK-SAME: reduction(
+  // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
+  // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
+  omp.parallel private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr)
+               reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) {
     // CHECK: llvm.load %[[PRIV_ARG]]
     %0 = llvm.load %priv_arg : !llvm.ptr -> f32
     // CHECK: llvm.load %[[PRIV_ARG2]]
diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir
index 21167668bbee16..a06e44fc5cfe01 100644
--- a/mlir/test/Target/LLVMIR/openmp-private.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-private.mlir
@@ -206,7 +206,7 @@ llvm.func @private_and_reduction_() attributes {fir.internal_name = "_QPprivate_
   %0 = llvm.mlir.constant(1 : i64) : i64
   %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
   %2 = llvm.alloca %0 x f32 {bindc_name = "to_priv"} : (i64) -> !llvm.ptr
-  omp.parallel reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) private(@privatizer.part %2 -> %arg1 : !llvm.ptr) {
+  omp.parallel private(@privatizer.part %2 -> %arg1 : !llvm.ptr) reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) {
     %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
     %4 = llvm.mlir.constant(8.000000e+00 : f32) : f32
     llvm.store %4, %arg1 : f32, !llvm.ptr

>From 4555bc479a8b400a49fc96cad1169541861b73ee Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 27 Sep 2024 16:57:26 +0100
Subject: [PATCH 2/2] Improve interface implementation to make adding clauses
 less error-prone

---
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 46 +++++++++++++++----
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 030075eaf45b14..1aaa4060793995 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -47,38 +47,66 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     }]>,
 
     // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
+                    "unsigned", "getInReductionBlockArgsStart", (ins), [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `map`.",
+                    "unsigned", "getMapBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getInReductionBlockArgsStart() +
+             $_op.numInReductionBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `private`.",
+                    "unsigned", "getPrivateBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
+                    "unsigned", "getReductionBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
+                    "unsigned", "getTaskReductionBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
+    }]>,
+
     InterfaceMethod<"Get block arguments defined by `in_reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getInReductionBlockArgs", (ins), [{
-      return $_op->getRegion(0).getArguments().take_front(
-          $_op.numInReductionBlockArgs());
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
     }]>,
     InterfaceMethod<"Get block arguments defined by `map`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getMapBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
       return $_op->getRegion(0).getArguments().slice(
-          $_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
+          iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
     }]>,
     InterfaceMethod<"Get block arguments defined by `private`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getPrivateBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
       return $_op->getRegion(0).getArguments().slice(
-          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
-          $_op.numPrivateBlockArgs());
+          iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
     }]>,
     InterfaceMethod<"Get block arguments defined by `reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getReductionBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
       return $_op->getRegion(0).getArguments().slice(
-          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
-          $_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
+          iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
     }]>,
     InterfaceMethod<"Get block arguments defined by `task_reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getTaskReductionBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
       return $_op->getRegion(0).getArguments().slice(
-          $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
-          $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
+          iface.getTaskReductionBlockArgsStart(),
           $_op.numTaskReductionBlockArgs());
     }]>,
   ];



More information about the flang-commits mailing list