[Mlir-commits] [mlir] [MLIR][OpenMP] Minor improvements to BlockArgOpenMPOpInterface, NFC (PR #130789)

Sergio Afonso llvmlistbot at llvm.org
Wed Mar 12 04:53:27 PDT 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/130789

>From cc7cb76e135918ffb86ddfc193b1b66c0948e42c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 11 Mar 2025 15:53:06 +0000
Subject: [PATCH] [MLIR][OpenMP] Minor improvements to
 BlockArgOpenMPOpInterface, NFC

This patch makes `convertIgnoredWrapper` use the new `getBlockArgsPairs`
interface method, which avoids having to manually list each applicable clause.

Also, the `numClauseBlockArgs()` interface method is introduced. It simplifies
the implementation of the interface's verifier and allows `getBlockArgsPairs`
to reserve space for all pairs up front, avoiding repeated reallocations.
---
 mlir/docs/Dialects/OpenMPDialect/_index.md         |  2 ++
 .../mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td     | 14 +++++++++-----
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp   | 12 +++++-------
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index adde176750437..1df80fac2a684 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -372,6 +372,8 @@ accessed:
   should be located.
   - `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
   defined by the given clause.
+  - `numClauseBlockArgs()`: Returns the total number of entry block arguments
+  defined by all clauses.
   - `getBlockArgsPairs()`: Returns a list of pairs where the first element is
   the outside value, or operand, and the second element is the corresponding
   entry block argument.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0766b4e8d1472..3fa54d35ed09b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     !foreach(clause, clauses, clause.startMethod),
     !foreach(clause, clauses, clause.blockArgsMethod),
     [
+      InterfaceMethod<
+        "Get the total number of clause-defined entry block arguments",
+        "unsigned", "numClauseBlockArgs", (ins),
+        "return " # !interleave(
+          !foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
+          " + ") # ";"
+      >,
       InterfaceMethod<
         "Populate a vector of pairs representing the matching between operands "
         "and entry block arguments.", "void", "getBlockArgsPairs",
         (ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
         [{
           auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+          pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
         }] # !interleave(!foreach(clause, clauses, [{
         }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
         }] # "  for (auto [var, arg] : ::llvm::zip_equal(" #
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-  }] # "unsigned expectedArgs = "
-     # !interleave(
-         !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
-         " + "
-       ) # ";" # [{
+    unsigned expectedArgs = iface.numClauseBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3373f19a006ba..b9893716980fe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
   // corresponding operand. This is semantically equivalent to this wrapper not
   // being present.
   auto forwardArgs =
-      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
-                           OperandRange operands) {
-        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
+      [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
+        llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+        blockArgIface.getBlockArgsPairs(blockArgsPairs);
+        for (auto [var, arg] : blockArgsPairs)
           moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
       };
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
       .Case([&](omp::SimdOp op) {
-        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
-        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
-        forwardArgs(blockArgIface.getReductionBlockArgs(),
-                    op.getReductionVars());
+        forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
         op.emitWarning() << "simd information on composite construct discarded";
         return success();
       })



More information about the Mlir-commits mailing list