[Mlir-commits] [flang] [mlir] [MLIR][OpenMP] Normalize handling of entry block arguments (PR #109808)

Sergio Afonso llvmlistbot at llvm.org
Tue Oct 1 06:29:41 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/109808

>From e5a0b3c4a940b9fd90ac730678e9651eff9c2a90 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 18 Sep 2024 11:47:36 +0100
Subject: [PATCH] [MLIR][OpenMP] Normalize handling of entry block arguments

This patch introduces a new MLIR interface for the OpenMP dialect aimed at
providing a uniform way of verifying and handling entry block arguments defined
by OpenMP clauses.

The approach consists in defining a set of overridable methods that return the
number of block arguments the operation holds regarding each of the clauses
that may define them. These by default return 0, but they are overridden by the
corresponding clause through the `extraClassDeclaration` mechanism.

Another set of interface methods to get the actual lists of block arguments is
defined, which is implemented based on the previously described methods. These
implicitly define a standardized ordering between the list of block arguments
associated with each clause, based on the alphabetical ordering of their names.
They should be the preferred way of matching operation arguments and entry
block arguments to that operation's first region.

Some updates are made to the printing/parsing of `omp.parallel` to follow the
expected order between `private` and `reduction` clauses, as well as the MLIR
to LLVM IR translation pass to access block arguments using the new interface.
Unit tests of operations impacted by additional verification checks and
sorting of entry block arguments are updated accordingly.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  29 +++--
 .../delayed-privatization-reduction-byref.f90 |   4 +-
 .../delayed-privatization-reduction.f90       |   4 +-
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  39 +++++--
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   7 +-
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 108 ++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  34 +++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  50 +++-----
 mlir/test/Dialect/OpenMP/invalid.mlir         |   4 +
 mlir/test/Dialect/OpenMP/ops.mlir             |  23 +++-
 mlir/test/Target/LLVMIR/openmp-private.mlir   |   2 +-
 11 files changed, 222 insertions(+), 82 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d528772f28724b..17ebf93edcce1f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
 /// \param [in] infoAccessor       - for a private variable, this returns the
 /// data we want to merge: type or location.
 /// \param [out] allRegionArgsInfo - the merged list of region info.
+/// \param [in] addBeforePrivate - `true` if the passed information goes before
+/// private information.
 template <typename OMPOp, typename InfoTy>
 static void
 mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
                      llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
-                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
+                     llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
+                     bool addBeforePrivate) {
   mlir::OperandRange privateVars = op.getPrivateVars();
 
-  llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
-                  [](InfoTy i) { return i; });
+  if (addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
+
   llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
                   infoAccessor);
+
+  if (!addBeforePrivate)
+    llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+                    [](InfoTy i) { return i; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
   mergePrivateVarsInfo(targetOp, mapSymTypes,
                        llvm::function_ref<mlir::Type(mlir::Value)>{
                            [](mlir::Value v) { return v.getType(); }},
-                       allRegionArgTypes);
+                       allRegionArgTypes, /*addBeforePrivate=*/true);
 
   mergePrivateVarsInfo(targetOp, mapSymLocs,
                        llvm::function_ref<mlir::Location(mlir::Value)>{
                            [](mlir::Value v) { return v.getLoc(); }},
-                       allRegionArgLocs);
+                       allRegionArgLocs, /*addBeforePrivate=*/true);
 
   mlir::Block *regionBlock = firOpBuilder.createBlock(
       &region, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     mergePrivateVarsInfo(parallelOp, reductionTypes,
                          llvm::function_ref<mlir::Type(mlir::Value)>{
                              [](mlir::Value v) { return v.getType(); }},
-                         allRegionArgTypes);
+                         allRegionArgTypes, /*addBeforePrivate=*/false);
 
     llvm::SmallVector<mlir::Location> allRegionArgLocs;
     mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
                          llvm::function_ref<mlir::Location(mlir::Value)>{
                              [](mlir::Value v) { return v.getLoc(); }},
-                         allRegionArgLocs);
+                         allRegionArgLocs, /*addBeforePrivate=*/false);
 
     mlir::Region &region = parallelOp.getRegion();
     firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
                              allRegionArgLocs);
 
-    llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
-                      dsp->getDelayedPrivSymbols().end());
+    llvm::SmallVector<const semantics::Symbol *> allSymbols(
+        dsp->getDelayedPrivSymbols());
+    allSymbols.append(reductionSyms.begin(), reductionSyms.end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 29439571179322..6c00bb23f15b96 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
index d814b2b0ff0f31..38139e52ce95cb 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c579ba6e751d2b..876d53766a0ca1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
       return SmallVector<Value>(getInReductionVars().begin(),
                                 getInReductionVars().end());
     }
+
+    unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
+    // Not adding the BlockArgOpenMPOpInterface here because omp.target is the
+    // only operation defining block arguments for `map` clauses.
     MapClauseOwningOpInterface
   ];
 
@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
   let arguments = (ins
     Variadic<AnyType>:$private_vars,
     OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
       custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
+  }];
+
   // TODO: Add description.
 }
 
@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
   let extraClassDeclaration = [{
     /// Returns the number of reduction variables.
     unsigned getNumReductionVars() { return getReductionVars().size(); }
+    unsigned numReductionBlockArgs() { return getReductionVars().size(); }
   }];
 
   // Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let traits = [
-    ReductionClauseInterface
+    BlockArgOpenMPOpInterface, ReductionClauseInterface
   ];
 
   let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
                                $task_reduction_byref, $task_reduction_syms) `)`
   }];
 
+  let extraClassDeclaration = [{
+    /// Returns the reduction variables.
+    SmallVector<Value> getReductionVars() {
+      return SmallVector<Value>(getTaskReductionVars().begin(),
+                                getTaskReductionVars().end());
+    }
+
+    unsigned numTaskReductionBlockArgs() {
+      return getTaskReductionVars().size();
+    }
+  }];
+
   let description = [{
     The `task_reduction` clause specifies a reduction among tasks. For each list
     item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
     attribute, and whether the reduction variable should be passed into the
     reduction region by value or by reference in `task_reduction_byref`.
   }];
-
-  let extraClassDeclaration = [{
-    /// Returns the reduction variables.
-    SmallVector<Value> getReductionVars() {
-      return SmallVector<Value>(getTaskReductionVars().begin(),
-                                getTaskReductionVars().end());
-    }
-  }];
 }
 
 def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..326bdd3bbc9463 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 //===----------------------------------------------------------------------===//
 
 def TargetOp : OpenMP_Op<"target", traits = [
-    AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
+    AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+    OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
 
+  let extraClassDeclaration = [{
+    unsigned numMapBlockArgs() { return getMapVars().size(); }
+  }] # clausesExtraClassDeclaration;
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index ea1e3ebecef7b4..2602384744f233 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,114 @@
 
 include "mlir/IR/OpBase.td"
 
+def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
+  let description = [{
+    OpenMP operations that define entry block arguments as part of the
+    representation of its clauses.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    // Default-implemented methods to be overridden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
+                    "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `map`.",
+                    "unsigned", "numMapBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `private`.",
+                    "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `reduction`.",
+                    "unsigned", "numReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
+                    "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
+
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
+                    "unsigned", "getInReductionBlockArgsStart", (ins), [{
+      return 0;
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `map`.",
+                    "unsigned", "getMapBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getInReductionBlockArgsStart() +
+             $_op.numInReductionBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `private`.",
+                    "unsigned", "getPrivateBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
+                    "unsigned", "getReductionBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
+    }]>,
+    InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
+                    "unsigned", "getTaskReductionBlockArgsStart", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
+    }]>,
+
+    InterfaceMethod<"Get block arguments defined by `in_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getInReductionBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `map`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getMapBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `private`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getPrivateBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getReductionBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
+    }]>,
+    InterfaceMethod<"Get block arguments defined by `task_reduction`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getTaskReductionBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getTaskReductionBlockArgsStart(),
+          $_op.numTaskReductionBlockArgs());
+    }]>,
+  ];
+
+  let verify = [{
+    auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
+    unsigned expectedArgs = iface.numInReductionBlockArgs() +
+        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
+        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+    if ($_op->getRegion(0).getNumArguments() < expectedArgs)
+      return $_op->emitOpError() << "expected at least " << expectedArgs
+                                 << " entry block argument(s)";
+    return ::mlir::success();
+  }];
+}
+
 def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
   let description = [{
     OpenMP operations whose region will be outlined will implement this
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 59e71ecc6ec5df..6b1abbc186a19c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
 
-  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
-    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
-                                         reductionTypes, reductionByref,
-                                         reductionSyms, regionPrivateArgs)))
-      return failure();
-  }
-
   if (succeeded(parser.parseOptionalKeyword("private"))) {
     auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
     if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
     }
   }
 
+  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
+                                         reductionTypes, reductionByref,
+                                         reductionSyms, regionPrivateArgs)))
+      return failure();
+  }
+
   return parser.parseRegion(region, regionPrivateArgs);
 }
 
@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                 DenseBoolArrayAttr reductionByref,
                                 ArrayAttr reductionSyms, ValueRange privateVars,
                                 TypeRange privateTypes, ArrayAttr privateSyms) {
-  if (reductionSyms) {
-    auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
-    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
-                              reductionTypes, reductionByref, reductionSyms);
-  }
-
   if (privateSyms) {
     auto *argsBegin = region.front().getArguments().begin();
-    MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
-                                 argsBegin + reductionVars.size() +
-                                     privateTypes.size());
+    MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
     mlir::SmallVector<bool> isByRefVec;
     isByRefVec.resize(privateTypes.size(), false);
     DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
                               privateTypes, isByRef, privateSyms);
   }
 
+  if (reductionSyms) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
+                                 argsBegin + privateVars.size() +
+                                     reductionTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
+                              reductionTypes, reductionByref, reductionSyms);
+  }
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index c22d9a189a7e0d..7c89d3bd6ec5a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      sectionsOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1216,7 +1216,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   DenseMap<Value, llvm::Value *> reductionVariableMap;
 
   MutableArrayRef<BlockArgument> reductionArgs =
-      wsloopOp.getRegion().getArguments();
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
 
   if (failed(allocAndInitializeReductionVars(
           wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1329,31 +1329,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 class OmpParallelOpConversionManager {
 public:
   OmpParallelOpConversionManager(omp::ParallelOp opInst)
-      : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
-        privateArgBeginIdx(opInst.getNumReductionVars()),
-        privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
-    auto privateVarsIt = privateVars.begin();
-
-    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
-         ++argIdx, ++privateVarsIt)
-      mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
-                                       *privateVarsIt, region);
+      : region(opInst.getRegion()),
+        privateBlockArgs(cast<omp::BlockArgOpenMPOpInterface>(*opInst)
+                             .getPrivateBlockArgs()),
+        privateVars(opInst.getPrivateVars()) {
+    for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+      mlir::replaceAllUsesInRegionWith(blockArg, var, region);
   }
 
   ~OmpParallelOpConversionManager() {
-    auto privateVarsIt = privateVars.begin();
-
-    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
-         ++argIdx, ++privateVarsIt)
-      mlir::replaceAllUsesInRegionWith(*privateVarsIt,
-                                       region.getArgument(argIdx), region);
+    for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+      mlir::replaceAllUsesInRegionWith(var, blockArg, region);
   }
 
 private:
   Region &region;
+  llvm::MutableArrayRef<BlockArgument> privateBlockArgs;
   OperandRange privateVars;
-  unsigned privateArgBeginIdx;
-  unsigned privateArgEndIdx;
 };
 
 // Looks up from the operation from and returns the PrivateClauseOp with
@@ -1417,9 +1409,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     DenseMap<Value, llvm::Value *> reductionVariableMap;
 
     MutableArrayRef<BlockArgument> reductionArgs =
-        opInst.getRegion().getArguments().slice(
-            opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
-            opInst.getNumReductionVars());
+        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
 
     allocaIP =
         InsertPointTy(allocaIP.getBlock(),
@@ -3414,6 +3404,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   auto &targetRegion = targetOp.getRegion();
   DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
   SmallVector<Value> mapVars = targetOp.getMapVars();
+  ArrayRef<BlockArgument> mapBlockArgs =
+      cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs();
   llvm::Function *llvmOutlinedFn = nullptr;
 
   // TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3442,11 +3434,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       llvmOutlinedFn->addFnAttr(attr);
 
     builder.restoreIP(codeGenIP);
-    for (auto [argIndex, mapOp] : llvm::enumerate(mapVars)) {
+    for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
       auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
       llvm::Value *mapOpValue =
           moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
-      const auto &arg = targetRegion.front().getArgument(argIndex);
       moduleTranslation.mapValue(arg, mapOpValue);
     }
 
@@ -3457,18 +3448,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
       OperandRange privateVars = targetOp.getPrivateVars();
       std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
-      unsigned numMapVars = targetOp.getMapVars().size();
-      Block &firstTargetBlock = targetRegion.front();
-      BlockArgument *blockArgsStart = firstTargetBlock.getArguments().begin();
-      BlockArgument *privArgsStart = blockArgsStart + numMapVars;
-      BlockArgument *privArgsEnd =
-          privArgsStart + targetOp.getPrivateVars().size();
-      MutableArrayRef privateBlockArgs(privArgsStart, privArgsEnd);
+      MutableArrayRef<BlockArgument> privateBlockArgs =
+          cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
 
       for (auto [privVar, privatizerNameAttr, privBlockArg] :
            llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) {
 
-        SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerNameAttr);
+        SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
         omp::PrivateClauseOp privatizer = findPrivatizer(&opInst, privSym);
         if (privatizer.getDataSharingType() ==
                 omp::DataSharingClauseType::FirstPrivate ||
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 35a8883e3a317e..4899583ac3bff2 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1471,6 +1471,7 @@ func.func @omp_sections(%data_var : memref<i32>) -> () {
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected as many reduction symbol references as reduction variables}}
   "omp.sections" (%data_var) ({
+  ^bb0(%arg0: memref<i32>):
     omp.terminator
   }) {operandSegmentSizes = array<i32: 0,0,0,1>} : (memref<i32>) -> ()
   return
@@ -1662,6 +1663,7 @@ func.func @omp_task_depend(%data_var: memref<i32>) {
 func.func @omp_task(%ptr: !llvm.ptr) {
   // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
   omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -1686,6 +1688,7 @@ combiner {
 func.func @omp_task(%ptr: !llvm.ptr) {
   // expected-error @below {{op accumulator variable used more than once}}
   omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr, @add_f32 -> %ptr : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -1716,6 +1719,7 @@ atomic {
 func.func @omp_task(%mem: memref<1xf32>) {
   // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}}
   omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) {
+  ^bb0(%arg0: memref<1xf32>):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e7d3e67ca7e05b..2116071f8523a3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1096,6 +1096,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
   omp.teams reduction(@add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     %1 = arith.constant 2.0 : f32
     // CHECK: omp.terminator
     omp.terminator
@@ -1104,6 +1105,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
   // Test reduction byref
   // CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) {
   omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     %1 = arith.constant 2.0 : f32
     // CHECK: omp.terminator
     omp.terminator
@@ -1125,6 +1127,7 @@ func.func @sections_reduction() {
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr)
   omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.section
     omp.section {
       %1 = arith.constant 2.0 : f32
@@ -1146,6 +1149,7 @@ func.func @sections_reduction_byref() {
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
   // CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr)
   omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.section
     omp.section {
       %1 = arith.constant 2.0 : f32
@@ -1245,6 +1249,7 @@ func.func @sections_reduction2() {
   %0 = memref.alloca() : memref<1xf32>
   // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
   omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+  ^bb0(%arg0: !llvm.ptr):
     omp.section {
       %1 = arith.constant 2.0 : f32
       omp.terminator
@@ -1901,6 +1906,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
 
     // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr)
   "omp.sections" (%redn_var) ({
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.terminator
     omp.terminator
   }) {operandSegmentSizes = array<i32: 0,0,0,1>, reduction_byref = array<i1: false>, reduction_syms=[@add_f32]} : (!llvm.ptr) -> ()
@@ -1913,6 +1919,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
 
   // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) {
   omp.sections reduction(@add_f32 -> %redn_var : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr):
     // CHECK: omp.terminator
     omp.terminator
   }
@@ -2087,6 +2094,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   %1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
   // CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
   omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2096,6 +2104,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
   // Checking `in_reduction` clause (mixed) byref
   // CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
   omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2129,6 +2138,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
       in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr)
       // CHECK-SAME: priority(%[[i32_var]] : i32) untied
       priority(%i32_var : i32) untied {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     // CHECK: "test.foo"() : () -> ()
     "test.foo"() : () -> ()
     // CHECK: omp.terminator
@@ -2306,6 +2316,7 @@ func.func @omp_taskgroup_clauses() -> () {
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
   // CHECK: omp.taskgroup allocate(%{{.+}}: memref<i32> -> %{{.+}}: memref<i32>) task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr)
   omp.taskgroup allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) {
+  ^bb0(%arg0 : !llvm.ptr):
     // CHECK: omp.task
     omp.task {
       "test.foo"() : () -> ()
@@ -2783,15 +2794,15 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
 // CHECK-LABEL: parallel_op_reduction_and_private
 func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
   // CHECK: omp.parallel
-  // CHECK-SAME: reduction(
-  // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
-  // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
-  //
   // CHECK-SAME: private(
   // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr,
   // CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr)
-  omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
-               private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
+  //
+  // CHECK-SAME: reduction(
+  // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
+  // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
+  omp.parallel private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr)
+               reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) {
     // CHECK: llvm.load %[[PRIV_ARG]]
     %0 = llvm.load %priv_arg : !llvm.ptr -> f32
     // CHECK: llvm.load %[[PRIV_ARG2]]
diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir
index 21167668bbee16..a06e44fc5cfe01 100644
--- a/mlir/test/Target/LLVMIR/openmp-private.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-private.mlir
@@ -206,7 +206,7 @@ llvm.func @private_and_reduction_() attributes {fir.internal_name = "_QPprivate_
   %0 = llvm.mlir.constant(1 : i64) : i64
   %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
   %2 = llvm.alloca %0 x f32 {bindc_name = "to_priv"} : (i64) -> !llvm.ptr
-  omp.parallel reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) private(@privatizer.part %2 -> %arg1 : !llvm.ptr) {
+  omp.parallel private(@privatizer.part %2 -> %arg1 : !llvm.ptr) reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) {
     %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
     %4 = llvm.mlir.constant(8.000000e+00 : f32) : f32
     llvm.store %4, %arg1 : f32, !llvm.ptr



More information about the Mlir-commits mailing list