[Mlir-commits] [mlir] [MLIR][OpenMP] Support basic materialization for `omp.private` ops (PR #81715)

Kareem Ergawy llvmlistbot at llvm.org
Tue Feb 27 09:12:18 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/81715

>From a043ecee87ec08d3bace37ae0e0a6c4ffae8ced1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 13 Feb 2024 04:46:50 -0600
Subject: [PATCH 1/2] [MLIR][OpenMP] Support basic materialization for
 `omp.private` ops

Adds basic support for materializing delayed privatization. So far, the
restrictions on the implementation are:
- Only `private` clauses are supported (`firstprivate` support will be
  added in a later PR).
- Only single-block `omp.private -> alloc` regions are supported
  (multi-block ones will be supported in a later PR).
---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |   5 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 201 ++++++++++++++----
 mlir/test/Target/LLVMIR/openmp-private.mlir   | 142 +++++++++++++
 3 files changed, 304 insertions(+), 44 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-private.mlir

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c2b471ab96183f..8a6980e2c6a2d9 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1957,7 +1957,10 @@ LogicalResult PrivateClauseOp::verify() {
   Type symType = getType();
 
   auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
-    if (!terminator->hasSuccessors() && !llvm::isa<YieldOp>(terminator))
+    if (!terminator->getBlock()->getSuccessors().empty())
+      return success();
+
+    if (!llvm::isa<YieldOp>(terminator))
       return mlir::emitError(terminator->getLoc())
              << "expected exit block terminator to be an `omp.yield` op.";
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6e53d801a0d2f0..0c390f68bb6890 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
 
 /// Translates the blocks contained in the given region and appends them to at
 /// the current insertion point of `builder`. The operations of the entry block
-/// are appended to the current insertion block, which is not expected to have a
-/// terminator. If set, `continuationBlockArgs` is populated with translated
-/// values that correspond to the values omp.yield'ed from the region.
+/// are appended to the current insertion block. If set, `continuationBlockArgs`
+/// is populated with translated values that correspond to the values
+/// omp.yield'ed from the region.
 static LogicalResult inlineConvertOmpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
   // Special case for single-block regions that don't create additional blocks:
   // insert operations without creating additional blocks.
   if (llvm::hasSingleElement(region)) {
+    llvm::Instruction *potentialTerminator =
+        builder.GetInsertBlock()->empty() ? nullptr
+                                          : &builder.GetInsertBlock()->back();
+
+    if (potentialTerminator && potentialTerminator->isTerminator())
+      potentialTerminator->removeFromParent();
     moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
+
     if (failed(moduleTranslation.convertBlock(
             region.front(), /*ignoreArguments=*/true, builder)))
       return failure();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
     // Drop the mapping that is no longer necessary so that the same region can
     // be processed multiple times.
     moduleTranslation.forgetMapping(region);
+
+    if (potentialTerminator && potentialTerminator->isTerminator())
+      potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
+
     return success();
   }
 
@@ -1000,11 +1011,41 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Replace the region arguments of the parallel op (which correspond to private
+/// variables) with the actual private variables they correspond to. This
+/// prepares the parallel op so that it matches what is expected by the
+/// OMPIRBuilder. Instead of editing the original op in-place, this function
+/// does the required changes to a cloned version which should then be erased by
+/// the caller.
+static omp::ParallelOp
+prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
+  Region &region = opInst.getRegion();
+  auto privateVars = opInst.getPrivateVars();
+
+  auto privateVarsIt = privateVars.begin();
+  // Reduction precede private arguments, so skip them first.
+  unsigned privateArgBeginIdx = opInst.getNumReductionVars();
+  unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
+
+  mlir::IRMapping mapping;
+  for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
+       ++argIdx, ++privateVarsIt)
+    mapping.map(region.getArgument(argIdx), *privateVarsIt);
+
+  mlir::OpBuilder cloneBuilder(opInst);
+  omp::ParallelOp opInstClone =
+      llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst, mapping));
+
+  return opInstClone;
+}
+
 /// Converts the OpenMP parallel operation to LLVM IR.
 static LogicalResult
 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
+
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
   // relying on captured variables.
   LogicalResult bodyGenStatus = success();
@@ -1013,12 +1054,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
     // Collect reduction declarations
     SmallVector<omp::ReductionDeclareOp> reductionDecls;
-    collectReductionDecls(opInst, reductionDecls);
+    collectReductionDecls(opInstClone, reductionDecls);
 
     // Allocate reduction vars
     SmallVector<llvm::Value *> privateReductionVariables;
     DenseMap<Value, llvm::Value *> reductionVariableMap;
-    allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
+    allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
                        reductionDecls, privateReductionVariables,
                        reductionVariableMap);
 
@@ -1030,7 +1071,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
 
     // Initialize reduction vars
     builder.restoreIP(allocaIP);
-    for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
+    for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
       SmallVector<llvm::Value *> phis;
       if (failed(inlineConvertOmpRegions(
               reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1051,18 +1092,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     // ParallelOp has only one region associated with it.
     builder.restoreIP(codeGenIP);
     auto regionBlock =
-        convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
+        convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
                             moduleTranslation, bodyGenStatus);
 
     // Process the reductions if required.
-    if (opInst.getNumReductionVars() > 0) {
+    if (opInstClone.getNumReductionVars() > 0) {
       // Collect reduction info
       SmallVector<OwningReductionGen> owningReductionGens;
       SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
       SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
-                           owningReductionGens, owningAtomicReductionGens,
-                           privateReductionVariables, reductionInfos);
+      collectReductionInfo(opInstClone, builder, moduleTranslation,
+                           reductionDecls, owningReductionGens,
+                           owningAtomicReductionGens, privateReductionVariables,
+                           reductionInfos);
 
       // Move to region cont block
       builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1075,7 +1117,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                        reductionInfos, false);
       if (!contInsertPoint.getBlock()) {
-        bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
+        bodyGenStatus = opInstClone->emitOpError()
+                        << "failed to convert reductions";
         return;
       }
 
@@ -1086,12 +1129,82 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
 
   // TODO: Perform appropriate actions according to the data-sharing
   // attribute (shared, private, firstprivate, ...) of variables.
-  // Currently defaults to shared.
+  // Currently shared and private are supported.
   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                     llvm::Value &, llvm::Value &vPtr,
                     llvm::Value *&replacementValue) -> InsertPointTy {
     replacementValue = &vPtr;
 
+    // If this is a private value, this lambda will return the corresponding
+    // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
+    // returned.
+    auto [privVar, privatizerClone] =
+        [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
+      if (!opInstClone.getPrivateVars().empty()) {
+        auto privVars = opInstClone.getPrivateVars();
+        auto privatizers = opInstClone.getPrivatizers();
+
+        for (auto [privVar, privatizerAttr] :
+             llvm::zip_equal(privVars, *privatizers)) {
+          // Find the MLIR private variable corresponding to the LLVM value
+          // being privatized.
+          llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
+          if (llvmPrivVar != &vPtr)
+            continue;
+
+          SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
+          omp::PrivateClauseOp privatizer =
+              SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+                  opInstClone, privSym);
+
+          // Clone the privatizer in case it is used by more than one parallel
+          // region. The privatizer is processed in-place (see below) before it
+          // gets inlined in the parallel region and therefore processing the
+          // original op is dangerous.
+          return {privVar, privatizer.clone()};
+        }
+      }
+
+      return {mlir::Value(), omp::PrivateClauseOp()};
+    }();
+
+    if (privVar) {
+      if (privatizerClone.getDataSharingType() ==
+          omp::DataSharingClauseType::FirstPrivate) {
+        privatizerClone.emitOpError(
+            "TODO: delayed privatization is not "
+            "supported for `firstprivate` clauses yet.");
+        bodyGenStatus = failure();
+        return codeGenIP;
+      }
+
+      Region &allocRegion = privatizerClone.getAllocRegion();
+
+      // Replace the privatizer block argument with mlir value being privatized.
+      // This way, the body of the privatizer will be changed from using the
+      // region/block argument to the value being privatized.
+      auto allocRegionArg = allocRegion.getArgument(0);
+      replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
+
+      auto oldIP = builder.saveIP();
+      builder.restoreIP(allocaIP);
+
+      SmallVector<llvm::Value *, 1> yieldedValues;
+      if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
+                                         moduleTranslation, &yieldedValues))) {
+        opInstClone.emitError(
+            "failed to inline `alloc` region of an `omp.private` "
+            "op in the parallel region");
+        bodyGenStatus = failure();
+      } else {
+        assert(yieldedValues.size() == 1);
+        replacementValue = yieldedValues.front();
+      }
+
+      privatizerClone.erase();
+      builder.restoreIP(oldIP);
+    }
+
     return codeGenIP;
   };
 
@@ -1100,13 +1213,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto finiCB = [&](InsertPointTy codeGenIP) {};
 
   llvm::Value *ifCond = nullptr;
-  if (auto ifExprVar = opInst.getIfExprVar())
+  if (auto ifExprVar = opInstClone.getIfExprVar())
     ifCond = moduleTranslation.lookupValue(ifExprVar);
   llvm::Value *numThreads = nullptr;
-  if (auto numThreadsVar = opInst.getNumThreadsVar())
+  if (auto numThreadsVar = opInstClone.getNumThreadsVar())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
-  if (auto bind = opInst.getProcBindVal())
+  if (auto bind = opInstClone.getProcBindVal())
     pbKind = getProcBindKind(*bind);
   // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
@@ -1119,6 +1232,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                  ifCond, numThreads, pbKind, isCancellable));
 
+  opInstClone.erase();
   return bodyGenStatus;
 }
 
@@ -1635,7 +1749,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
 // A small helper structure to contain data gathered
 // for map lowering and coalese it into one area and
 // avoiding extra computations such as searches in the
-// llvm module for lowered mapped varibles or checking
+// llvm module for lowered mapped variables or checking
 // if something is declare target (and retrieving the
 // value) more than neccessary.
 struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2968,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
                                                 moduleTranslation);
               return failure();
             })
-      .Case(
-          "omp.requires",
-          [&](Attribute attr) {
-            if (auto requiresAttr = attr.dyn_cast<omp::ClauseRequiresAttr>()) {
-              using Requires = omp::ClauseRequires;
-              Requires flags = requiresAttr.getValue();
-              llvm::OpenMPIRBuilderConfig &config =
-                  moduleTranslation.getOpenMPBuilder()->Config;
-              config.setHasRequiresReverseOffload(
-                  bitEnumContainsAll(flags, Requires::reverse_offload));
-              config.setHasRequiresUnifiedAddress(
-                  bitEnumContainsAll(flags, Requires::unified_address));
-              config.setHasRequiresUnifiedSharedMemory(
-                  bitEnumContainsAll(flags, Requires::unified_shared_memory));
-              config.setHasRequiresDynamicAllocators(
-                  bitEnumContainsAll(flags, Requires::dynamic_allocators));
-              return success();
-            }
-            return failure();
-          })
+      .Case("omp.requires",
+            [&](Attribute attr) {
+              if (auto requiresAttr =
+                      attr.dyn_cast<omp::ClauseRequiresAttr>()) {
+                using Requires = omp::ClauseRequires;
+                Requires flags = requiresAttr.getValue();
+                llvm::OpenMPIRBuilderConfig &config =
+                    moduleTranslation.getOpenMPBuilder()->Config;
+                config.setHasRequiresReverseOffload(
+                    bitEnumContainsAll(flags, Requires::reverse_offload));
+                config.setHasRequiresUnifiedAddress(
+                    bitEnumContainsAll(flags, Requires::unified_address));
+                config.setHasRequiresUnifiedSharedMemory(
+                    bitEnumContainsAll(flags, Requires::unified_shared_memory));
+                config.setHasRequiresDynamicAllocators(
+                    bitEnumContainsAll(flags, Requires::dynamic_allocators));
+                return success();
+              }
+              return failure();
+            })
       .Default([](Attribute) {
         // Fall through for omp attributes that do not require lowering.
         return success();
@@ -2988,12 +3102,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::TargetOp) {
         return convertOmpTarget(*op, builder, moduleTranslation);
       })
-      .Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
-        // No-op, should be handled by relevant owning operations e.g.
-        // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
-        // discarded
-        return success();
-      })
+      .Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
+          [&](auto op) {
+            // No-op, should be handled by relevant owning operations e.g.
+            // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
+            // discarded
+            return success();
+          })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir
new file mode 100644
index 00000000000000..58bda87c3b7bea
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-private.mlir
@@ -0,0 +1,142 @@
+// Test code-gen for `omp.parallel` ops with delayed privatizers (i.e. using
+// `omp.private` ops).
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @parallel_op_1_private(%arg0: !llvm.ptr) {
+  omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr) {
+    %0 = llvm.load %arg2 : !llvm.ptr -> f32
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @parallel_op_1_private
+// CHECK-SAME: (ptr %[[ORIG:.*]]) {
+// CHECK: %[[OMP_PAR_ARG:.*]] = alloca { ptr }, align 8
+// CHECK: %[[ORIG_GEP:.*]] = getelementptr { ptr }, ptr %[[OMP_PAR_ARG]], i32 0, i32 0
+// CHECK: store ptr %[[ORIG]], ptr %[[ORIG_GEP]], align 8
+// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @parallel_op_1_private..omp_par, ptr %[[OMP_PAR_ARG]])
+// CHECK: }
+
+// CHECK-LABEL: void @parallel_op_1_private..omp_par
+// CHECK-SAME: (ptr noalias %{{.*}}, ptr noalias %{{.*}}, ptr %[[ARG:.*]])
+// CHECK: %[[ORIG_PTR_PTR:.*]] = getelementptr { ptr }, ptr %[[ARG]], i32 0, i32 0
+// CHECK: %[[ORIG_PTR:.*]] = load ptr, ptr %[[ORIG_PTR_PTR]], align 8
+
+// Check that the privatizer alloc region was inlined properly.
+// CHECK: %[[PRIV_ALLOC:.*]] = alloca float, align 4
+// CHECK: %[[ORIG_VAL:.*]] = load float, ptr %[[ORIG_PTR]], align 4
+// CHECK: store float %[[ORIG_VAL]], ptr %[[PRIV_ALLOC]], align 4
+// CHECK-NEXT: br
+
+// Check that the privatized value is used (rather than the original one).
+// CHECK: load float, ptr %[[PRIV_ALLOC]], align 4
+// CHECK: }
+
+llvm.func @parallel_op_2_privates(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
+    %0 = llvm.load %arg2 : !llvm.ptr -> f32
+    %1 = llvm.load %arg3 : !llvm.ptr -> i32
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @parallel_op_2_privates
+// CHECK-SAME: (ptr %[[ORIG1:.*]], ptr %[[ORIG2:.*]]) {
+// CHECK: %[[OMP_PAR_ARG:.*]] = alloca { ptr, ptr }, align 8
+// CHECK: %[[ORIG1_GEP:.*]] = getelementptr { ptr, ptr }, ptr %[[OMP_PAR_ARG]], i32 0, i32 0
+// CHECK: store ptr %[[ORIG1]], ptr %[[ORIG1_GEP]], align 8
+// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @parallel_op_2_privates..omp_par, ptr %[[OMP_PAR_ARG]])
+// CHECK: }
+
+// CHECK-LABEL: void @parallel_op_2_privates..omp_par
+// CHECK-SAME: (ptr noalias %{{.*}}, ptr noalias %{{.*}}, ptr %[[ARG:.*]])
+// CHECK: %[[ORIG1_PTR_PTR:.*]] = getelementptr { ptr, ptr }, ptr %[[ARG]], i32 0, i32 0
+// CHECK: %[[ORIG1_PTR:.*]] = load ptr, ptr %[[ORIG1_PTR_PTR]], align 8
+// CHECK: %[[ORIG2_PTR_PTR:.*]] = getelementptr { ptr, ptr }, ptr %[[ARG]], i32 0, i32 1
+// CHECK: %[[ORIG2_PTR:.*]] = load ptr, ptr %[[ORIG2_PTR_PTR]], align 8
+
+// Check that the privatizer alloc region was inlined properly.
+// CHECK: %[[PRIV1_ALLOC:.*]] = alloca float, align 4
+// CHECK: %[[ORIG1_VAL:.*]] = load float, ptr %[[ORIG1_PTR]], align 4
+// CHECK: store float %[[ORIG1_VAL]], ptr %[[PRIV1_ALLOC]], align 4
+// CHECK: %[[PRIV2_ALLOC:.*]] = alloca i32, align 4
+// CHECK: %[[ORIG2_VAL:.*]] = load i32, ptr %[[ORIG2_PTR]], align 4
+// CHECK: store i32 %[[ORIG2_VAL]], ptr %[[PRIV2_ALLOC]], align 4
+// CHECK-NEXT: br
+
+// Check that the privatized value is used (rather than the original one).
+// CHECK: load float, ptr %[[PRIV1_ALLOC]], align 4
+// CHECK: load i32, ptr %[[PRIV2_ALLOC]], align 4
+// CHECK: }
+
+omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  %0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
+  %1 = llvm.load %arg0 : !llvm.ptr -> f32
+  llvm.store %1, %0 : f32, !llvm.ptr
+  omp.yield(%0 : !llvm.ptr)
+}
+
+omp.private {type = private} @y.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+  %1 = llvm.load %arg0 : !llvm.ptr -> i32
+  llvm.store %1, %0 : i32, !llvm.ptr
+  omp.yield(%0 : !llvm.ptr)
+}
+
+// -----
+
+llvm.func @parallel_op_private_multi_block(%arg0: !llvm.ptr) {
+  omp.parallel private(@multi_block.privatizer %arg0 -> %arg2 : !llvm.ptr) {
+    %0 = llvm.load %arg2 : !llvm.ptr -> f32
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define internal void @parallel_op_private_multi_block..omp_par
+// CHECK: omp.par.entry:
+// CHECK:  %[[ORIG_PTR_PTR:.*]] = getelementptr { ptr }, ptr %{{.*}}, i32 0, i32 0
+// CHECK:  %[[ORIG_PTR:.*]] = load ptr, ptr %[[ORIG_PTR_PTR]], align 8
+// CHECK:   br label %[[PRIV_BB1:.*]]
+
+// Check contents of the first block in the `alloc` region.
+// CHECK: [[PRIV_BB1]]:
+// CHECK-NEXT:   %[[PRIV_ALLOC:.*]] = alloca float, align 4
+// CHECK-NEXT:   br label %[[PRIV_BB2:.*]]
+
+// Check contents of the second block in the `alloc` region.
+// CHECK: [[PRIV_BB2]]:
+// CHECK-NEXT:   %[[ORIG_PTR2:.*]] = phi ptr [ %[[ORIG_PTR]], %[[PRIV_BB1]] ]
+// CHECK-NEXT:   %[[PRIV_ALLOC2:.*]] = phi ptr [ %[[PRIV_ALLOC]], %[[PRIV_BB1]] ]
+// CHECK-NEXT:   %[[ORIG_VAL:.*]] = load float, ptr %[[ORIG_PTR2]], align 4
+// CHECK-NEXT:   store float %[[ORIG_VAL]], ptr %[[PRIV_ALLOC2]], align 4
+// CHECK-NEXT:   br label %[[PRIV_CONT:.*]]
+
+// Check that the privatizer's continuation block yields the private clone's
+// address.
+// CHECK: [[PRIV_CONT]]:
+// CHECK-NEXT:   %[[PRIV_ALLOC3:.*]] = phi ptr [ %[[PRIV_ALLOC2]], %[[PRIV_BB2]] ]
+// CHECK-NEXT:   br label %[[PAR_REG:.*]]
+
+// Check that the body of the parallel region loads from the private clone.
+// CHECK: [[PAR_REG]]:
+// CHECK:        %{{.*}} = load float, ptr %[[PRIV_ALLOC3]], align 4
+
+omp.private {type = private} @multi_block.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+  %c1 = llvm.mlir.constant(1 : i32) : i32
+  %0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
+  llvm.br ^bb1(%arg0, %0 : !llvm.ptr, !llvm.ptr)
+
+^bb1(%arg1: !llvm.ptr, %arg2: !llvm.ptr):
+  %1 = llvm.load %arg1 : !llvm.ptr -> f32
+  llvm.store %1, %arg2 : f32, !llvm.ptr
+  omp.yield(%arg2 : !llvm.ptr)
+}

>From d2d1ee0323ad74d7f03efba7f68efe9a9bd37da3 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 22 Feb 2024 22:41:37 -0600
Subject: [PATCH 2/2] Add RAII object to manage mapping of the op's arguments.

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 105 +++++++++---------
 1 file changed, 55 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0c390f68bb6890..8c20689c4a39dd 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1011,40 +1011,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Replace the region arguments of the parallel op (which correspond to private
-/// variables) with the actual private variables they correspond to. This
-/// prepares the parallel op so that it matches what is expected by the
-/// OMPIRBuilder. Instead of editing the original op in-place, this function
-/// does the required changes to a cloned version which should then be erased by
-/// the caller.
-static omp::ParallelOp
-prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
-  Region &region = opInst.getRegion();
-  auto privateVars = opInst.getPrivateVars();
-
-  auto privateVarsIt = privateVars.begin();
-  // Reduction precede private arguments, so skip them first.
-  unsigned privateArgBeginIdx = opInst.getNumReductionVars();
-  unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
-
-  mlir::IRMapping mapping;
-  for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
-       ++argIdx, ++privateVarsIt)
-    mapping.map(region.getArgument(argIdx), *privateVarsIt);
-
-  mlir::OpBuilder cloneBuilder(opInst);
-  omp::ParallelOp opInstClone =
-      llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst, mapping));
-
-  return opInstClone;
-}
+/// A RAII class that on construction replaces the region arguments of the
+/// parallel op (which correspond to private variables) with the actual private
+/// variables they correspond to. This prepares the parallel op so that it
+/// matches what is expected by the OMPIRBuilder.
+///
+/// On destruction, it restores the original state of the operation so that on
+/// the MLIR side, the op is not affected by conversion to LLVM IR.
+class OmpParallelOpConversionManager {
+public:
+  OmpParallelOpConversionManager(omp::ParallelOp opInst)
+      : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
+        privateArgBeginIdx(opInst.getNumReductionVars()),
+        privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
+    auto privateVarsIt = privateVars.begin();
+
+    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
+         ++argIdx, ++privateVarsIt)
+      mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
+                                       *privateVarsIt, region);
+  }
+
+  ~OmpParallelOpConversionManager() {
+    auto privateVarsIt = privateVars.begin();
+
+    for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
+         ++argIdx, ++privateVarsIt)
+      mlir::replaceAllUsesInRegionWith(*privateVarsIt,
+                                       region.getArgument(argIdx), region);
+  }
+
+private:
+  Region &region;
+  OperandRange privateVars;
+  unsigned privateArgBeginIdx;
+  unsigned privateArgEndIdx;
+};
 
 /// Converts the OpenMP parallel operation to LLVM IR.
 static LogicalResult
 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
-  omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
+  OmpParallelOpConversionManager raii(opInst);
 
   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
   // relying on captured variables.
@@ -1054,12 +1063,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
     // Collect reduction declarations
     SmallVector<omp::ReductionDeclareOp> reductionDecls;
-    collectReductionDecls(opInstClone, reductionDecls);
+    collectReductionDecls(opInst, reductionDecls);
 
     // Allocate reduction vars
     SmallVector<llvm::Value *> privateReductionVariables;
     DenseMap<Value, llvm::Value *> reductionVariableMap;
-    allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
+    allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
                        reductionDecls, privateReductionVariables,
                        reductionVariableMap);
 
@@ -1071,7 +1080,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
 
     // Initialize reduction vars
     builder.restoreIP(allocaIP);
-    for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
+    for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
       SmallVector<llvm::Value *> phis;
       if (failed(inlineConvertOmpRegions(
               reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1092,19 +1101,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     // ParallelOp has only one region associated with it.
     builder.restoreIP(codeGenIP);
     auto regionBlock =
-        convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
+        convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
                             moduleTranslation, bodyGenStatus);
 
     // Process the reductions if required.
-    if (opInstClone.getNumReductionVars() > 0) {
+    if (opInst.getNumReductionVars() > 0) {
       // Collect reduction info
       SmallVector<OwningReductionGen> owningReductionGens;
       SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
       SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-      collectReductionInfo(opInstClone, builder, moduleTranslation,
-                           reductionDecls, owningReductionGens,
-                           owningAtomicReductionGens, privateReductionVariables,
-                           reductionInfos);
+      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
+                           owningReductionGens, owningAtomicReductionGens,
+                           privateReductionVariables, reductionInfos);
 
       // Move to region cont block
       builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1117,8 +1125,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           ompBuilder->createReductions(builder.saveIP(), allocaIP,
                                        reductionInfos, false);
       if (!contInsertPoint.getBlock()) {
-        bodyGenStatus = opInstClone->emitOpError()
-                        << "failed to convert reductions";
+        bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
         return;
       }
 
@@ -1140,9 +1147,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     // returned.
     auto [privVar, privatizerClone] =
         [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
-      if (!opInstClone.getPrivateVars().empty()) {
-        auto privVars = opInstClone.getPrivateVars();
-        auto privatizers = opInstClone.getPrivatizers();
+      if (!opInst.getPrivateVars().empty()) {
+        auto privVars = opInst.getPrivateVars();
+        auto privatizers = opInst.getPrivatizers();
 
         for (auto [privVar, privatizerAttr] :
              llvm::zip_equal(privVars, *privatizers)) {
@@ -1155,7 +1162,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
           omp::PrivateClauseOp privatizer =
               SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
-                  opInstClone, privSym);
+                  opInst, privSym);
 
           // Clone the privatizer in case it is used by more than one parallel
           // region. The privatizer is processed in-place (see below) before it
@@ -1192,9 +1199,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       SmallVector<llvm::Value *, 1> yieldedValues;
       if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
                                          moduleTranslation, &yieldedValues))) {
-        opInstClone.emitError(
-            "failed to inline `alloc` region of an `omp.private` "
-            "op in the parallel region");
+        opInst.emitError("failed to inline `alloc` region of an `omp.private` "
+                         "op in the parallel region");
         bodyGenStatus = failure();
       } else {
         assert(yieldedValues.size() == 1);
@@ -1213,13 +1219,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   auto finiCB = [&](InsertPointTy codeGenIP) {};
 
   llvm::Value *ifCond = nullptr;
-  if (auto ifExprVar = opInstClone.getIfExprVar())
+  if (auto ifExprVar = opInst.getIfExprVar())
     ifCond = moduleTranslation.lookupValue(ifExprVar);
   llvm::Value *numThreads = nullptr;
-  if (auto numThreadsVar = opInstClone.getNumThreadsVar())
+  if (auto numThreadsVar = opInst.getNumThreadsVar())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
-  if (auto bind = opInstClone.getProcBindVal())
+  if (auto bind = opInst.getProcBindVal())
     pbKind = getProcBindKind(*bind);
   // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
@@ -1232,7 +1238,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                  ifCond, numThreads, pbKind, isCancellable));
 
-  opInstClone.erase();
   return bodyGenStatus;
 }
 



More information about the Mlir-commits mailing list