[Mlir-commits] [mlir] [NFC][OpenMP][MLIR] Refactor code related to collecting privatizer info into a shared util (PR #131582)

Kareem Ergawy llvmlistbot at llvm.org
Wed Mar 19 00:25:37 PDT 2025


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/131582

>From 3c1af0691861cc10485f92cb5faea7d387fc11d1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 17 Mar 2025 03:37:00 -0500
Subject: [PATCH 1/2] [OpenMP][MLIR] Refactor code related to collecting
 privatizer info into a shared util

Moves code needed to collect info about delayed privatizers into a
shared util instead of repeating the same pattern across all relevant
constructs.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 257 ++++++++----------
 1 file changed, 107 insertions(+), 150 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 537558a83cb36..aff874643d41f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -696,20 +696,42 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Populates `privatizations` with privatization declarations used for the
-/// given op.
-template <class OP>
-static void collectPrivatizationDecls(
-    OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
-  std::optional<ArrayAttr> attr = op.getPrivateSyms();
-  if (!attr)
-    return;
+/// A util to collect info needed to convert delayed privatizers from MLIR to
+/// LLVM.
+struct PrivateVarsInfo {
+  template <typename OP>
+  PrivateVarsInfo(OP op)
+      : privateBlockArgs(
+            cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
+    mlirPrivateVars.reserve(privateBlockArgs.size());
+    llvmPrivateVars.reserve(privateBlockArgs.size());
+    collectPrivatizationDecls<OP>(op, privateDecls);
 
-  privatizations.reserve(privatizations.size() + attr->size());
-  for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
-    privatizations.push_back(findPrivatizer(op, symbolRef));
+    for (mlir::Value privateVar : op.getPrivateVars())
+      mlirPrivateVars.push_back(privateVar);
   }
-}
+
+  MutableArrayRef<BlockArgument> privateBlockArgs;
+  SmallVector<mlir::Value> mlirPrivateVars;
+  SmallVector<llvm::Value *> llvmPrivateVars;
+  SmallVector<omp::PrivateClauseOp> privateDecls;
+
+private:
+  /// Populates `privatizations` with privatization declarations used for the
+  /// given op.
+  template <class OP>
+  static void collectPrivatizationDecls(
+      OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
+    std::optional<ArrayAttr> attr = op.getPrivateSyms();
+    if (!attr)
+      return;
+
+    privatizations.reserve(privatizations.size() + attr->size());
+    for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
+      privatizations.push_back(findPrivatizer(op, symbolRef));
+    }
+  }
+};
 
 /// Populates `reductions` with reduction declarations used in the given op.
 template <typename T>
@@ -1384,19 +1406,18 @@ static llvm::Expected<llvm::Value *> initPrivateVar(
 static llvm::Error
 initPrivateVars(llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation,
-                MutableArrayRef<BlockArgument> privateBlockArgs,
-                MutableArrayRef<omp::PrivateClauseOp> privateDecls,
-                MutableArrayRef<mlir::Value> mlirPrivateVars,
-                llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+                PrivateVarsInfo &privateVarsInfo,
                 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
-  if (privateBlockArgs.empty())
+  if (privateVarsInfo.privateBlockArgs.empty())
     return llvm::Error::success();
 
   llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
   setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
 
   for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
-           privateDecls, mlirPrivateVars, privateBlockArgs, llvmPrivateVars))) {
+           privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+           privateVarsInfo.privateBlockArgs,
+           privateVarsInfo.llvmPrivateVars))) {
     auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
     llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
         builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
@@ -1420,10 +1441,7 @@ initPrivateVars(llvm::IRBuilderBase &builder,
 static llvm::Expected<llvm::BasicBlock *>
 allocatePrivateVars(llvm::IRBuilderBase &builder,
                     LLVM::ModuleTranslation &moduleTranslation,
-                    MutableArrayRef<BlockArgument> privateBlockArgs,
-                    MutableArrayRef<omp::PrivateClauseOp> privateDecls,
-                    MutableArrayRef<mlir::Value> mlirPrivateVars,
-                    llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
+                    PrivateVarsInfo &privateVarsInfo,
                     const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
                     llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
   // Allocate private vars
@@ -1449,8 +1467,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
                                ->getDataLayout()
                                .getProgramAddressSpace();
 
-  for (auto [privDecl, mlirPrivVar, blockArg] :
-       llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs)) {
+  for (auto [privDecl, mlirPrivVar, blockArg] : llvm::zip_equal(
+           privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
+           privateVarsInfo.privateBlockArgs)) {
     llvm::Type *llvmAllocType =
         moduleTranslation.convertType(privDecl.getType());
     builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
@@ -1460,7 +1479,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
       llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                    builder.getPtrTy(defaultAS));
 
-    llvmPrivateVars.push_back(llvmPrivateVar);
+    privateVarsInfo.llvmPrivateVars.push_back(llvmPrivateVar);
   }
 
   return afterAllocas;
@@ -1888,19 +1907,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   if (failed(checkImplementationStatus(*taskOp)))
     return failure();
 
-  // Collect delayed privatisation declarations
-  MutableArrayRef<BlockArgument> privateBlockArgs =
-      cast<omp::BlockArgOpenMPOpInterface>(*taskOp).getPrivateBlockArgs();
-  SmallVector<mlir::Value> mlirPrivateVars;
-  SmallVector<llvm::Value *> llvmPrivateVars;
-  SmallVector<omp::PrivateClauseOp> privateDecls;
-  mlirPrivateVars.reserve(privateBlockArgs.size());
-  llvmPrivateVars.reserve(privateBlockArgs.size());
-  collectPrivatizationDecls(taskOp, privateDecls);
+  PrivateVarsInfo privateVarsInfo(taskOp);
   TaskContextStructManager taskStructMgr{builder, moduleTranslation,
-                                         privateDecls};
-  for (mlir::Value privateVar : taskOp.getPrivateVars())
-    mlirPrivateVars.push_back(privateVar);
+                                         privateVarsInfo.privateDecls};
 
   // Allocate and copy private variables before creating the task. This avoids
   // accessing invalid memory if (after this scope ends) the private variables
@@ -1959,7 +1968,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   taskStructMgr.createGEPsToPrivateVars();
 
   for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
-       llvm::zip_equal(privateDecls, mlirPrivateVars, privateBlockArgs,
+       llvm::zip_equal(privateVarsInfo.privateDecls,
+                       privateVarsInfo.mlirPrivateVars,
+                       privateVarsInfo.privateBlockArgs,
                        taskStructMgr.getLLVMPrivateVarGEPs())) {
     // To be handled inside the task.
     if (!privDecl.readsFromMold())
@@ -1998,9 +2009,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
 
   // firstprivate copy region
   setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
-  if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                  taskStructMgr.getLLVMPrivateVarGEPs(),
-                                  privateDecls)))
+  if (failed(copyFirstPrivateVars(
+          builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privateDecls)))
     return llvm::failure();
 
   // Set up for call to createTask()
@@ -2017,9 +2028,11 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     builder.restoreIP(codegenIP);
 
     llvm::BasicBlock *privInitBlock = nullptr;
-    llvmPrivateVars.resize(privateBlockArgs.size());
+    privateVarsInfo.llvmPrivateVars.resize(
+        privateVarsInfo.privateBlockArgs.size());
     for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
-             privateBlockArgs, privateDecls, mlirPrivateVars))) {
+             privateVarsInfo.privateBlockArgs, privateVarsInfo.privateDecls,
+             privateVarsInfo.mlirPrivateVars))) {
       auto [blockArg, privDecl, mlirPrivVar] = zip;
       // This is handled before the task executes
       if (privDecl.readsFromMold())
@@ -2038,23 +2051,25 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
       if (!privateVarOrError)
         return privateVarOrError.takeError();
       moduleTranslation.mapValue(blockArg, privateVarOrError.get());
-      llvmPrivateVars[i] = privateVarOrError.get();
+      privateVarsInfo.llvmPrivateVars[i] = privateVarOrError.get();
     }
 
     taskStructMgr.createGEPsToPrivateVars();
     for (auto [i, llvmPrivVar] :
          llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
       if (!llvmPrivVar) {
-        assert(llvmPrivateVars[i] && "This is added in the loop above");
+        assert(privateVarsInfo.llvmPrivateVars[i] &&
+               "This is added in the loop above");
         continue;
       }
-      llvmPrivateVars[i] = llvmPrivVar;
+      privateVarsInfo.llvmPrivateVars[i] = llvmPrivVar;
     }
 
     // Find and map the addresses of each variable within the task context
     // structure
-    for (auto [blockArg, llvmPrivateVar, privateDecl] :
-         llvm::zip_equal(privateBlockArgs, llvmPrivateVars, privateDecls)) {
+    for (auto [blockArg, llvmPrivateVar, privateDecl] : llvm::zip_equal(
+             privateVarsInfo.privateBlockArgs, privateVarsInfo.llvmPrivateVars,
+             privateVarsInfo.privateDecls)) {
       // This was handled above.
       if (!privateDecl.readsFromMold())
         continue;
@@ -2076,7 +2091,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
-                                  llvmPrivateVars, privateDecls)))
+                                  privateVarsInfo.llvmPrivateVars,
+                                  privateVarsInfo.privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     // Free heap allocated task context structure at the end of the task.
@@ -2171,17 +2187,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
-  MutableArrayRef<BlockArgument> privateBlockArgs =
-      cast<omp::BlockArgOpenMPOpInterface>(*wsloopOp).getPrivateBlockArgs();
-  SmallVector<mlir::Value> mlirPrivateVars;
-  SmallVector<llvm::Value *> llvmPrivateVars;
-  SmallVector<omp::PrivateClauseOp> privateDecls;
-  mlirPrivateVars.reserve(privateBlockArgs.size());
-  llvmPrivateVars.reserve(privateBlockArgs.size());
-  collectPrivatizationDecls(wsloopOp, privateDecls);
-
-  for (mlir::Value privateVar : wsloopOp.getPrivateVars())
-    mlirPrivateVars.push_back(privateVar);
+  PrivateVarsInfo privateVarsInfo(wsloopOp);
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
   collectReductionDecls(wsloopOp, reductionDecls);
@@ -2192,8 +2198,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
       wsloopOp.getNumReductionVars());
 
   llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-      builder, moduleTranslation, privateBlockArgs, privateDecls,
-      mlirPrivateVars, llvmPrivateVars, allocaIP);
+      builder, moduleTranslation, privateVarsInfo, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
     return failure();
 
@@ -2210,15 +2215,14 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                                 deferredStores, isByRef)))
     return failure();
 
-  if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
-                                  privateDecls, mlirPrivateVars,
-                                  llvmPrivateVars),
+  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                   opInst)
           .failed())
     return failure();
 
-  if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                  llvmPrivateVars, privateDecls)))
+  if (failed(copyFirstPrivateVars(
+          builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+          privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
     return failure();
 
   assert(afterAllocas.get()->getSinglePredecessor());
@@ -2271,7 +2275,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
-                            llvmPrivateVars, privateDecls);
+                            privateVarsInfo.llvmPrivateVars,
+                            privateVarsInfo.privateDecls);
 }
 
 /// Converts the OpenMP parallel operation to LLVM IR.
@@ -2286,17 +2291,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (failed(checkImplementationStatus(*opInst)))
     return failure();
 
-  // Collect delayed privatization declarations
-  MutableArrayRef<BlockArgument> privateBlockArgs =
-      cast<omp::BlockArgOpenMPOpInterface>(*opInst).getPrivateBlockArgs();
-  SmallVector<mlir::Value> mlirPrivateVars;
-  SmallVector<llvm::Value *> llvmPrivateVars;
-  SmallVector<omp::PrivateClauseOp> privateDecls;
-  mlirPrivateVars.reserve(privateBlockArgs.size());
-  llvmPrivateVars.reserve(privateBlockArgs.size());
-  collectPrivatizationDecls(opInst, privateDecls);
-  for (mlir::Value privateVar : opInst.getPrivateVars())
-    mlirPrivateVars.push_back(privateVar);
+  PrivateVarsInfo privateVarsInfo(opInst);
 
   // Collect reduction declarations
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2308,8 +2303,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto bodyGenCB = [&](InsertPointTy allocaIP,
                        InsertPointTy codeGenIP) -> llvm::Error {
     llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-        builder, moduleTranslation, privateBlockArgs, privateDecls,
-        mlirPrivateVars, llvmPrivateVars, allocaIP);
+        builder, moduleTranslation, privateVarsInfo, allocaIP);
     if (handleError(afterAllocas, *opInst).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
@@ -2332,15 +2326,15 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     assert(afterAllocas.get()->getSinglePredecessor());
     builder.restoreIP(codeGenIP);
 
-    if (handleError(initPrivateVars(builder, moduleTranslation,
-                                    privateBlockArgs, privateDecls,
-                                    mlirPrivateVars, llvmPrivateVars),
-                    *opInst)
+    if (handleError(
+            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
+            *opInst)
             .failed())
       return llvm::make_error<PreviouslyReportedError>();
 
-    if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                    llvmPrivateVars, privateDecls)))
+    if (failed(copyFirstPrivateVars(
+            builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
+            privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     if (failed(
@@ -2422,7 +2416,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           "failed to inline `cleanup` region of `omp.declare_reduction`");
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
-                                  llvmPrivateVars, privateDecls)))
+                                  privateVarsInfo.llvmPrivateVars,
+                                  privateVarsInfo.privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     builder.restoreIP(oldIP);
@@ -2490,30 +2485,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(checkImplementationStatus(opInst)))
     return failure();
 
-  MutableArrayRef<BlockArgument> privateBlockArgs =
-      cast<omp::BlockArgOpenMPOpInterface>(*simdOp).getPrivateBlockArgs();
-  SmallVector<mlir::Value> mlirPrivateVars;
-  SmallVector<llvm::Value *> llvmPrivateVars;
-  SmallVector<omp::PrivateClauseOp> privateDecls;
-  mlirPrivateVars.reserve(privateBlockArgs.size());
-  llvmPrivateVars.reserve(privateBlockArgs.size());
-  collectPrivatizationDecls(simdOp, privateDecls);
-
-  for (mlir::Value privateVar : simdOp.getPrivateVars())
-    mlirPrivateVars.push_back(privateVar);
+  PrivateVarsInfo privateVarsInfo(simdOp);
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
   llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-      builder, moduleTranslation, privateBlockArgs, privateDecls,
-      mlirPrivateVars, llvmPrivateVars, allocaIP);
+      builder, moduleTranslation, privateVarsInfo, allocaIP);
   if (handleError(afterAllocas, opInst).failed())
     return failure();
 
-  if (handleError(initPrivateVars(builder, moduleTranslation, privateBlockArgs,
-                                  privateDecls, mlirPrivateVars,
-                                  llvmPrivateVars),
+  if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
                   opInst)
           .failed())
     return failure();
@@ -2562,7 +2544,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                         order, simdlen, safelen);
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
-                            llvmPrivateVars, privateDecls);
+                            privateVarsInfo.llvmPrivateVars,
+                            privateVarsInfo.privateDecls);
 }
 
 /// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
@@ -4186,37 +4169,21 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
 
     // DistributeOp has only one region associated with it.
     builder.restoreIP(codeGenIP);
+    PrivateVarsInfo privVarsInfo(distributeOp);
 
-    // TODO This is a recurring pattern in almost all ops that need
-    // privatization. Try to abstract it in a shared util/interface.
-    MutableArrayRef<BlockArgument> privateBlockArgs =
-        cast<omp::BlockArgOpenMPOpInterface>(*distributeOp)
-            .getPrivateBlockArgs();
-    SmallVector<mlir::Value> mlirPrivateVars;
-    SmallVector<llvm::Value *> llvmPrivateVars;
-    SmallVector<omp::PrivateClauseOp> privateDecls;
-    mlirPrivateVars.reserve(privateBlockArgs.size());
-    llvmPrivateVars.reserve(privateBlockArgs.size());
-    collectPrivatizationDecls(distributeOp, privateDecls);
-
-    for (mlir::Value privateVar : distributeOp.getPrivateVars())
-      mlirPrivateVars.push_back(privateVar);
-
-    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-        builder, moduleTranslation, privateBlockArgs, privateDecls,
-        mlirPrivateVars, llvmPrivateVars, allocaIP);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas =
+        allocatePrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP);
     if (handleError(afterAllocas, opInst).failed())
       return llvm::make_error<PreviouslyReportedError>();
 
-    if (handleError(initPrivateVars(builder, moduleTranslation,
-                                    privateBlockArgs, privateDecls,
-                                    mlirPrivateVars, llvmPrivateVars),
+    if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
                     opInst)
             .failed())
       return llvm::make_error<PreviouslyReportedError>();
 
-    if (failed(copyFirstPrivateVars(builder, moduleTranslation, mlirPrivateVars,
-                                    llvmPrivateVars, privateDecls)))
+    if (failed(copyFirstPrivateVars(
+            builder, moduleTranslation, privVarsInfo.mlirPrivateVars,
+            privVarsInfo.llvmPrivateVars, privVarsInfo.privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
@@ -4257,9 +4224,9 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
         return wsloopIP.takeError();
     }
 
-    if (failed(cleanupPrivateVars(builder, moduleTranslation,
-                                  distributeOp.getLoc(), llvmPrivateVars,
-                                  privateDecls)))
+    if (failed(cleanupPrivateVars(
+            builder, moduleTranslation, distributeOp.getLoc(),
+            privVarsInfo.llvmPrivateVars, privVarsInfo.privateDecls)))
       return llvm::make_error<PreviouslyReportedError>();
 
     return llvm::Error::success();
@@ -4876,35 +4843,25 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
     // Do privatization after moduleTranslation has already recorded
     // mapped values.
-    MutableArrayRef<BlockArgument> privateBlockArgs =
-        argIface.getPrivateBlockArgs();
-    SmallVector<mlir::Value> mlirPrivateVars;
-    SmallVector<llvm::Value *> llvmPrivateVars;
-    SmallVector<omp::PrivateClauseOp> privateDecls;
-    mlirPrivateVars.reserve(privateBlockArgs.size());
-    llvmPrivateVars.reserve(privateBlockArgs.size());
-    collectPrivatizationDecls(targetOp, privateDecls);
-    for (mlir::Value privateVar : targetOp.getPrivateVars())
-      mlirPrivateVars.push_back(privateVar);
+    PrivateVarsInfo privateVarsInfo(targetOp);
 
-    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
-        builder, moduleTranslation, privateBlockArgs, privateDecls,
-        mlirPrivateVars, llvmPrivateVars, allocaIP, &mappedPrivateVars);
+    llvm::Expected<llvm::BasicBlock *> afterAllocas =
+        allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
+                            allocaIP, &mappedPrivateVars);
 
     if (failed(handleError(afterAllocas, *targetOp)))
       return llvm::make_error<PreviouslyReportedError>();
 
     builder.restoreIP(codeGenIP);
-    if (handleError(initPrivateVars(builder, moduleTranslation,
-                                    privateBlockArgs, privateDecls,
-                                    mlirPrivateVars, llvmPrivateVars,
+    if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
                                     &mappedPrivateVars),
                     *targetOp)
             .failed())
       return llvm::make_error<PreviouslyReportedError>();
 
     SmallVector<Region *> privateCleanupRegions;
-    llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
+    llvm::transform(privateVarsInfo.privateDecls,
+                    std::back_inserter(privateCleanupRegions),
                     [](omp::PrivateClauseOp privatizer) {
                       return &privatizer.getDeallocRegion();
                     });
@@ -4918,8 +4875,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     builder.SetInsertPoint(*exitBlock);
     if (!privateCleanupRegions.empty()) {
       if (failed(inlineOmpRegionCleanup(
-              privateCleanupRegions, llvmPrivateVars, moduleTranslation,
-              builder, "omp.targetop.private.cleanup",
+              privateCleanupRegions, privateVarsInfo.llvmPrivateVars,
+              moduleTranslation, builder, "omp.targetop.private.cleanup",
               /*shouldLoadCleanupRegionArg=*/false))) {
         return llvm::createStringError(
             "failed to inline `dealloc` region of `omp.private` "

>From b93f7dbc6717470e3486764a7ae71452daf113cf Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 19 Mar 2025 02:25:19 -0500
Subject: [PATCH 2/2] Handle review comments

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 108 +++++++++---------
 1 file changed, 52 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index aff874643d41f..f303c0337a67a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -701,34 +701,33 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
 struct PrivateVarsInfo {
   template <typename OP>
   PrivateVarsInfo(OP op)
-      : privateBlockArgs(
+      : blockArgs(
             cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
-    mlirPrivateVars.reserve(privateBlockArgs.size());
-    llvmPrivateVars.reserve(privateBlockArgs.size());
-    collectPrivatizationDecls<OP>(op, privateDecls);
+    mlirVars.reserve(blockArgs.size());
+    llvmVars.reserve(blockArgs.size());
+    collectPrivatizationDecls<OP>(op);
 
     for (mlir::Value privateVar : op.getPrivateVars())
-      mlirPrivateVars.push_back(privateVar);
+      mlirVars.push_back(privateVar);
   }
 
-  MutableArrayRef<BlockArgument> privateBlockArgs;
-  SmallVector<mlir::Value> mlirPrivateVars;
-  SmallVector<llvm::Value *> llvmPrivateVars;
-  SmallVector<omp::PrivateClauseOp> privateDecls;
+  MutableArrayRef<BlockArgument> blockArgs;
+  SmallVector<mlir::Value> mlirVars;
+  SmallVector<llvm::Value *> llvmVars;
+  SmallVector<omp::PrivateClauseOp> privatizers;
 
 private:
   /// Populates `privatizations` with privatization declarations used for the
   /// given op.
   template <class OP>
-  static void collectPrivatizationDecls(
-      OP op, SmallVectorImpl<omp::PrivateClauseOp> &privatizations) {
+  void collectPrivatizationDecls(OP op) {
     std::optional<ArrayAttr> attr = op.getPrivateSyms();
     if (!attr)
       return;
 
-    privatizations.reserve(privatizations.size() + attr->size());
+    privatizers.reserve(privatizers.size() + attr->size());
     for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
-      privatizations.push_back(findPrivatizer(op, symbolRef));
+      privatizers.push_back(findPrivatizer(op, symbolRef));
     }
   }
 };
@@ -1408,16 +1407,15 @@ initPrivateVars(llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation,
                 PrivateVarsInfo &privateVarsInfo,
                 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
-  if (privateVarsInfo.privateBlockArgs.empty())
+  if (privateVarsInfo.blockArgs.empty())
     return llvm::Error::success();
 
   llvm::BasicBlock *privInitBlock = splitBB(builder, true, "omp.private.init");
   setInsertPointForPossiblyEmptyBlock(builder, privInitBlock);
 
   for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
-           privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
-           privateVarsInfo.privateBlockArgs,
-           privateVarsInfo.llvmPrivateVars))) {
+           privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
+           privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
     auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
     llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
         builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
@@ -1467,9 +1465,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
                                ->getDataLayout()
                                .getProgramAddressSpace();
 
-  for (auto [privDecl, mlirPrivVar, blockArg] : llvm::zip_equal(
-           privateVarsInfo.privateDecls, privateVarsInfo.mlirPrivateVars,
-           privateVarsInfo.privateBlockArgs)) {
+  for (auto [privDecl, mlirPrivVar, blockArg] :
+       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
+                       privateVarsInfo.blockArgs)) {
     llvm::Type *llvmAllocType =
         moduleTranslation.convertType(privDecl.getType());
     builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
@@ -1479,7 +1477,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
       llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
                                                    builder.getPtrTy(defaultAS));
 
-    privateVarsInfo.llvmPrivateVars.push_back(llvmPrivateVar);
+    privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
   }
 
   return afterAllocas;
@@ -1909,7 +1907,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
 
   PrivateVarsInfo privateVarsInfo(taskOp);
   TaskContextStructManager taskStructMgr{builder, moduleTranslation,
-                                         privateVarsInfo.privateDecls};
+                                         privateVarsInfo.privatizers};
 
   // Allocate and copy private variables before creating the task. This avoids
   // accessing invalid memory if (after this scope ends) the private variables
@@ -1968,9 +1966,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   taskStructMgr.createGEPsToPrivateVars();
 
   for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
-       llvm::zip_equal(privateVarsInfo.privateDecls,
-                       privateVarsInfo.mlirPrivateVars,
-                       privateVarsInfo.privateBlockArgs,
+       llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
+                       privateVarsInfo.blockArgs,
                        taskStructMgr.getLLVMPrivateVarGEPs())) {
     // To be handled inside the task.
     if (!privDecl.readsFromMold())
@@ -2010,8 +2007,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
   // firstprivate copy region
   setInsertPointForPossiblyEmptyBlock(builder, copyBlock);
   if (failed(copyFirstPrivateVars(
-          builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
-          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privateDecls)))
+          builder, moduleTranslation, privateVarsInfo.mlirVars,
+          taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers)))
     return llvm::failure();
 
   // Set up for call to createTask()
@@ -2028,11 +2025,10 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     builder.restoreIP(codegenIP);
 
     llvm::BasicBlock *privInitBlock = nullptr;
-    privateVarsInfo.llvmPrivateVars.resize(
-        privateVarsInfo.privateBlockArgs.size());
+    privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
     for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
-             privateVarsInfo.privateBlockArgs, privateVarsInfo.privateDecls,
-             privateVarsInfo.mlirPrivateVars))) {
+             privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
+             privateVarsInfo.mlirVars))) {
       auto [blockArg, privDecl, mlirPrivVar] = zip;
       // This is handled before the task executes
       if (privDecl.readsFromMold())
@@ -2051,25 +2047,25 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
       if (!privateVarOrError)
         return privateVarOrError.takeError();
       moduleTranslation.mapValue(blockArg, privateVarOrError.get());
-      privateVarsInfo.llvmPrivateVars[i] = privateVarOrError.get();
+      privateVarsInfo.llvmVars[i] = privateVarOrError.get();
     }
 
     taskStructMgr.createGEPsToPrivateVars();
     for (auto [i, llvmPrivVar] :
          llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
       if (!llvmPrivVar) {
-        assert(privateVarsInfo.llvmPrivateVars[i] &&
+        assert(privateVarsInfo.llvmVars[i] &&
                "This is added in the loop above");
         continue;
       }
-      privateVarsInfo.llvmPrivateVars[i] = llvmPrivVar;
+      privateVarsInfo.llvmVars[i] = llvmPrivVar;
     }
 
     // Find and map the addresses of each variable within the task context
     // structure
-    for (auto [blockArg, llvmPrivateVar, privateDecl] : llvm::zip_equal(
-             privateVarsInfo.privateBlockArgs, privateVarsInfo.llvmPrivateVars,
-             privateVarsInfo.privateDecls)) {
+    for (auto [blockArg, llvmPrivateVar, privateDecl] :
+         llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
+                         privateVarsInfo.privatizers)) {
       // This was handled above.
       if (!privateDecl.readsFromMold())
         continue;
@@ -2091,8 +2087,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
     builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
-                                  privateVarsInfo.llvmPrivateVars,
-                                  privateVarsInfo.privateDecls)))
+                                  privateVarsInfo.llvmVars,
+                                  privateVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
     // Free heap allocated task context structure at the end of the task.
@@ -2221,8 +2217,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   if (failed(copyFirstPrivateVars(
-          builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
-          privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
+          builder, moduleTranslation, privateVarsInfo.mlirVars,
+          privateVarsInfo.llvmVars, privateVarsInfo.privatizers)))
     return failure();
 
   assert(afterAllocas.get()->getSinglePredecessor());
@@ -2275,8 +2271,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
-                            privateVarsInfo.llvmPrivateVars,
-                            privateVarsInfo.privateDecls);
+                            privateVarsInfo.llvmVars,
+                            privateVarsInfo.privatizers);
 }
 
 /// Converts the OpenMP parallel operation to LLVM IR.
@@ -2333,8 +2329,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       return llvm::make_error<PreviouslyReportedError>();
 
     if (failed(copyFirstPrivateVars(
-            builder, moduleTranslation, privateVarsInfo.mlirPrivateVars,
-            privateVarsInfo.llvmPrivateVars, privateVarsInfo.privateDecls)))
+            builder, moduleTranslation, privateVarsInfo.mlirVars,
+            privateVarsInfo.llvmVars, privateVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
     if (failed(
@@ -2416,8 +2412,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           "failed to inline `cleanup` region of `omp.declare_reduction`");
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
-                                  privateVarsInfo.llvmPrivateVars,
-                                  privateVarsInfo.privateDecls)))
+                                  privateVarsInfo.llvmVars,
+                                  privateVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
     builder.restoreIP(oldIP);
@@ -2544,8 +2540,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                         order, simdlen, safelen);
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
-                            privateVarsInfo.llvmPrivateVars,
-                            privateVarsInfo.privateDecls);
+                            privateVarsInfo.llvmVars,
+                            privateVarsInfo.privatizers);
 }
 
 /// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
@@ -4182,8 +4178,8 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
       return llvm::make_error<PreviouslyReportedError>();
 
     if (failed(copyFirstPrivateVars(
-            builder, moduleTranslation, privVarsInfo.mlirPrivateVars,
-            privVarsInfo.llvmPrivateVars, privVarsInfo.privateDecls)))
+            builder, moduleTranslation, privVarsInfo.mlirVars,
+            privVarsInfo.llvmVars, privVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
@@ -4224,9 +4220,9 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
         return wsloopIP.takeError();
     }
 
-    if (failed(cleanupPrivateVars(
-            builder, moduleTranslation, distributeOp.getLoc(),
-            privVarsInfo.llvmPrivateVars, privVarsInfo.privateDecls)))
+    if (failed(cleanupPrivateVars(builder, moduleTranslation,
+                                  distributeOp.getLoc(), privVarsInfo.llvmVars,
+                                  privVarsInfo.privatizers)))
       return llvm::make_error<PreviouslyReportedError>();
 
     return llvm::Error::success();
@@ -4860,7 +4856,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       return llvm::make_error<PreviouslyReportedError>();
 
     SmallVector<Region *> privateCleanupRegions;
-    llvm::transform(privateVarsInfo.privateDecls,
+    llvm::transform(privateVarsInfo.privatizers,
                     std::back_inserter(privateCleanupRegions),
                     [](omp::PrivateClauseOp privatizer) {
                       return &privatizer.getDeallocRegion();
@@ -4875,7 +4871,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     builder.SetInsertPoint(*exitBlock);
     if (!privateCleanupRegions.empty()) {
       if (failed(inlineOmpRegionCleanup(
-              privateCleanupRegions, privateVarsInfo.llvmPrivateVars,
+              privateCleanupRegions, privateVarsInfo.llvmVars,
               moduleTranslation, builder, "omp.targetop.private.cleanup",
               /*shouldLoadCleanupRegionArg=*/false))) {
         return llvm::createStringError(



More information about the Mlir-commits mailing list