[flang-commits] [mlir] [flang] [WIP] Delayed privatization. (PR #79862)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Fri Feb 2 02:14:56 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/79862

>From 1ad8a5c1161e8bcf2ca35c4c75bf9e08093d943b Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 29 Jan 2024 04:45:18 -0600
Subject: [PATCH 1/4] [WIP] Delayed privatization.

This is a PoC for delayed privatization in OpenMP. Instead of directly
emitting privatization code in the frontend, we add a new op that
outlines the privatization logic for a symbol, plus a call-like mapping
from the host symbol to its outlined, function-like privatizer op.

Later, the delayed function-like privatizer op would be inlined into
the OpenMP region, yielding essentially the same code that the frontend
currently generates directly.
---
 flang/include/flang/Lower/AbstractConverter.h |   4 +
 flang/lib/Lower/Bridge.cpp                    |   2 +-
 flang/lib/Lower/OpenMP.cpp                    | 132 ++++++++++++++----
 .../OpenMP/FIR/delayed_privatization.f90      |  43 ++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  43 +++++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  92 +++++++++++-
 mlir/test/Dialect/OpenMP/roundtrip.mlir       |  36 +++++
 8 files changed, 326 insertions(+), 30 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
 create mode 100644 mlir/test/Dialect/OpenMP/roundtrip.mlir

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 796933a4eb5f6..55bc33e76e5f6 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -16,6 +16,7 @@
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/LoweringOptions.h"
 #include "flang/Lower/PFTDefs.h"
+#include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Semantics/symbol.h"
 #include "mlir/IR/Builders.h"
@@ -296,6 +297,9 @@ class AbstractConverter {
     return loweringOptions;
   }
 
+  virtual Fortran::lower::SymbolBox
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 4c8e0cb128744..f511212fd34f5 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
     if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
       return v;
     return {};
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index be2117efbabc0..4d012c45108fd 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -161,6 +161,11 @@ class DataSharingProcessor {
   const Fortran::parser::OmpClauseList &opClauseList;
   Fortran::lower::pft::Evaluation &eval;
 
+  bool useDelayedPrivatizationWhenPossible;
+  Fortran::lower::SymMap *symTable;
+  llvm::SetVector<mlir::SymbolRefAttr> privateInitializers;
+  llvm::SetVector<mlir::Value> privateSymHostAddrsses;
+
   bool needBarrier();
   void collectSymbols(Fortran::semantics::Symbol::Flag flag);
   void collectOmpObjectListSymbol(
@@ -182,10 +187,14 @@ class DataSharingProcessor {
 public:
   DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
                        const Fortran::parser::OmpClauseList &opClauseList,
-                       Fortran::lower::pft::Evaluation &eval)
+                       Fortran::lower::pft::Evaluation &eval,
+                       bool useDelayedPrivatizationWhenPossible = false,
+                       Fortran::lower::SymMap *symTable = nullptr)
       : hasLastPrivateOp(false), converter(converter),
         firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
-        eval(eval) {}
+        eval(eval), useDelayedPrivatizationWhenPossible(
+                        useDelayedPrivatizationWhenPossible),
+        symTable(symTable) {}
   // Privatisation is split into two steps.
   // Step1 performs cloning of all privatisation clauses and copying for
   // firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +213,14 @@ class DataSharingProcessor {
     assert(!loopIV && "Loop iteration variable already set");
     loopIV = iv;
   }
+
+  const llvm::SetVector<mlir::SymbolRefAttr> &getPrivateInitializers() const {
+    return privateInitializers;
+  };
+
+  const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
+    return privateSymHostAddrsses;
+  }
 };
 
 void DataSharingProcessor::processStep1() {
@@ -496,8 +513,46 @@ void DataSharingProcessor::privatize() {
         copyFirstPrivateSymbol(&*mem);
       }
     } else {
-      cloneSymbol(sym);
-      copyFirstPrivateSymbol(sym);
+      if (useDelayedPrivatizationWhenPossible) {
+        auto ip = firOpBuilder.saveInsertionPoint();
+
+        auto moduleOp = firOpBuilder.getInsertionBlock()
+                            ->getParentOp()
+                            ->getParentOfType<mlir::ModuleOp>();
+
+        firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
+                                       moduleOp.getBodyRegion().front().end());
+
+        Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
+        assert(hsb && "Host symbol box not found");
+
+        auto symType = hsb.getAddr().getType();
+        auto symLoc = hsb.getAddr().getLoc();
+        auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
+            symLoc, symType, sym->name().ToString());
+        firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
+
+        symTable->pushScope();
+        symTable->addSymbol(*sym, privatizerOp.getArgument(0));
+        symTable->pushScope();
+
+        cloneSymbol(sym);
+        copyFirstPrivateSymbol(sym);
+
+        firOpBuilder.create<mlir::omp::YieldOp>(
+            hsb.getAddr().getLoc(),
+            symTable->shallowLookupSymbol(*sym).getAddr());
+
+        symTable->popScope();
+        symTable->popScope();
+        firOpBuilder.restoreInsertionPoint(ip);
+
+        privateInitializers.insert(mlir::SymbolRefAttr::get(privatizerOp));
+        privateSymHostAddrsses.insert(hsb.getAddr());
+      } else {
+        cloneSymbol(sym);
+        copyFirstPrivateSymbol(sym);
+      }
     }
   }
 }
@@ -2463,12 +2518,12 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
                           Fortran::lower::pft::Evaluation &eval, bool genNested,
                           mlir::Location currentLocation, bool outerCombined,
                           const Fortran::parser::OmpClauseList *clauseList,
-                          Args &&...args) {
+                          DataSharingProcessor *dsp, Args &&...args) {
   auto op = converter.getFirOpBuilder().create<OpTy>(
       currentLocation, std::forward<Args>(args)...);
   createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
                        clauseList,
-                       /*args=*/{}, outerCombined);
+                       /*args=*/{}, outerCombined, dsp);
   return op;
 }
 
@@ -2480,6 +2535,7 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
                                             currentLocation,
                                             /*outerCombined=*/false,
                                             /*clauseList=*/nullptr,
+                                            /*dsp=*/nullptr,
                                             /*resultTypes=*/mlir::TypeRange());
 }
 
@@ -2487,14 +2543,17 @@ static mlir::omp::OrderedRegionOp
 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval, bool genNested,
                    mlir::Location currentLocation) {
-  return genOpWithBody<mlir::omp::OrderedRegionOp>(
-      converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false,
-      /*clauseList=*/nullptr, /*simd=*/false);
+  return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
+                                                   currentLocation,
+                                                   /*outerCombined=*/false,
+                                                   /*clauseList=*/nullptr,
+                                                   /*dsp=*/nullptr,
+                                                   /*simd=*/false);
 }
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
+              Fortran::lower::SymMap &symTable,
               Fortran::lower::pft::Evaluation &eval, bool genNested,
               mlir::Location currentLocation,
               const Fortran::parser::OmpClauseList &clauseList,
@@ -2516,16 +2575,37 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   if (!outerCombined)
     cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
 
+  bool privatize = !outerCombined;
+  DataSharingProcessor dsp(converter, clauseList, eval,
+                           /*useDelayedPrivatizationWhenPossible=*/true,
+                           &symTable);
+
+  if (privatize) {
+    dsp.processStep1();
+  }
+
+  llvm::SmallVector<mlir::Attribute> privateInits(
+      dsp.getPrivateInitializers().begin(), dsp.getPrivateInitializers().end());
+
+  llvm::SmallVector<mlir::Value> privateSymAddresses(
+      dsp.getPrivateSymHostAddrsses().begin(),
+      dsp.getPrivateSymHostAddrsses().end());
+
   return genOpWithBody<mlir::omp::ParallelOp>(
       converter, eval, genNested, currentLocation, outerCombined, &clauseList,
+      &dsp,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
-      reductionVars,
+      reductionVars, privateSymAddresses,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
-      procBindKindAttr);
+      procBindKindAttr,
+      privateInits.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 privateInits));
 }
 
 static mlir::omp::SectionOp
@@ -2537,7 +2617,8 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
   // all privatization is done within `omp.section` operations.
   return genOpWithBody<mlir::omp::SectionOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &sectionsClauseList);
+      /*outerCombined=*/false, &sectionsClauseList,
+      /*dsp=*/nullptr);
 }
 
 static mlir::omp::SingleOp
@@ -2558,8 +2639,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::SingleOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &beginClauseList, allocateOperands,
-      allocatorOperands, nowaitAttr);
+      /*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
+      allocateOperands, allocatorOperands, nowaitAttr);
 }
 
 static mlir::omp::TaskOp
@@ -2591,8 +2672,8 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::TaskOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
-      untiedAttr, mergeableAttr,
+      /*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
+      finalClauseOperand, untiedAttr, mergeableAttr,
       /*in_reduction_vars=*/mlir::ValueRange(),
       /*in_reductions=*/nullptr, priorityClauseOperand,
       dependTypeOperands.empty()
@@ -2615,6 +2696,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
   return genOpWithBody<mlir::omp::TaskGroupOp>(
       converter, eval, genNested, currentLocation,
       /*outerCombined=*/false, &clauseList,
+      /*dsp=*/nullptr,
       /*task_reduction_vars=*/mlir::ValueRange(),
       /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
 }
@@ -2994,6 +3076,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::TeamsOp>(
       converter, eval, genNested, currentLocation, outerCombined, &clauseList,
+      /*dsp=*/nullptr,
       /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
       threadLimitClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -3392,8 +3475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
-                    loopOpClauseList,
+      genParallelOp(converter, symTable, eval, /*genNested=*/false,
+                    currentLocation, loopOpClauseList,
                     /*outerCombined=*/true);
     }
   }
@@ -3481,8 +3564,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
-                  beginClauseList);
+    genParallelOp(converter, symTable, eval, /*genNested=*/true,
+                  currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_single:
     genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
@@ -3541,8 +3624,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
           .test(directive.v)) {
     bool outerCombined =
         directive.v != llvm::omp::Directive::OMPD_target_parallel;
-    genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
-                  beginClauseList, outerCombined);
+    genParallelOp(converter, symTable, eval, /*genNested=*/false,
+                  currentLocation, beginClauseList, outerCombined);
     combinedDirective = true;
   }
   if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
@@ -3625,7 +3708,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
 
   // Parallel wrapper of PARALLEL SECTIONS construct
   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
-    genParallelOp(converter, eval,
+    genParallelOp(converter, symTable, eval,
                   /*genNested=*/false, currentLocation, sectionsClauseList,
                   /*outerCombined=*/true);
   } else {
@@ -3642,6 +3725,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                                        /*genNested=*/false, currentLocation,
                                        /*outerCombined=*/false,
                                        /*clauseList=*/nullptr,
+                                       /*dsp=*/nullptr,
                                        /*reduction_vars=*/mlir::ValueRange(),
                                        /*reductions=*/nullptr, allocateOperands,
                                        allocatorOperands, nowaitClauseOperand);
diff --git a/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90 b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
new file mode 100644
index 0000000000000..6a668e6fb6660
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
@@ -0,0 +1,43 @@
+subroutine delayed_privatization()
+  integer :: var1
+  integer :: var2
+
+!$OMP PARALLEL FIRSTPRIVATE(var1, var2)
+  var1 = var1 + var2 + 2
+!$OMP END PARALLEL
+
+end subroutine
+
+! This is what flang emits with the PoC:
+! --------------------------------------
+!
+!func.func @_QPdelayed_privatization() {
+!  %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
+!  %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
+!  omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
+!    %2 = fir.load %0 : !fir.ref<i32>
+!    %3 = fir.load %1 : !fir.ref<i32>
+!    %4 = arith.addi %2, %3 : i32
+!    %c2_i32 = arith.constant 2 : i32
+!    %5 = arith.addi %4, %c2_i32 : i32
+!    fir.store %5 to %0 : !fir.ref<i32>
+!    omp.terminator
+!  }
+!  return
+!}
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 96c15e775a302..23e058d372c79 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -16,6 +16,7 @@
 
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -187,8 +188,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+             Variadic<AnyType>:$private_vars,
              OptionalAttr<SymbolRefArrayAttr>:$reductions,
-             OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
+             OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
+             OptionalAttr<SymbolRefArrayAttr>:$private_inits);
 
   let regions = (region AnyRegion:$region);
 
@@ -212,6 +215,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocators_vars, type($allocators_vars)
               ) `)`
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
+          | `private` `(`
+              custom<PrivateVarList>(
+                $private_vars, type($private_vars), $private_inits
+              ) `)`
     ) $region attr-dict
   }];
   let hasVerifier = 1;
@@ -621,7 +628,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
 def YieldOp : OpenMP_Op<"yield",
     [Pure, ReturnLike, Terminator,
      ParentOneOf<["WsLoopOp", "ReductionDeclareOp",
-     "AtomicUpdateOp", "SimdLoopOp"]>]> {
+     "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -1478,6 +1485,38 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
+def PrivateClauseOp : OpenMP_Op<"private", [
+    IsolatedFromAbove, FunctionOpInterface
+  ]> {
+  let summary = "TODO";
+  let description = [{}];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type);
+
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "::mlir::Type":$privateVar,
+    "::llvm::StringRef":$privateVarName
+  )>];
+
+  let extraClassDeclaration = [{
+    ::mlir::Region *getCallableRegion() {
+      return &getBody();
+    }
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+  }];
+}
 
 def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
   let summary = "target construct";
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 2f8b3f7e11de1..b381aaf20bf89 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -419,8 +419,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocators_vars = */ llvm::SmallVector<Value>{},
         /* reduction_vars = */ llvm::SmallVector<Value>{},
+        /*private_vars=*/mlir::ValueRange{},
         /* reductions = */ ArrayAttr{},
-        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
+        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
+        /*private_inits*/ nullptr);
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13cc16125a273..c4ef7ef3f2fb5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -989,8 +989,9 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(
       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
-      /*proc_bind_val=*/nullptr);
+      /*reduction_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
+      /*reductions=*/nullptr,
+      /*proc_bind_val=*/nullptr, /*private_inits*/ nullptr);
   state.addAttributes(attributes);
 }
 
@@ -1594,6 +1595,93 @@ LogicalResult DataBoundsOp::verify() {
   return success();
 }
 
+void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                            Type privateVarType, StringRef privateVarName) {
+  FunctionType initializerType = FunctionType::get(
+      odsBuilder.getContext(), {privateVarType}, {privateVarType});
+  std::string privatizerName = (privateVarName + ".privatizer").str();
+
+  build(odsBuilder, odsState, privatizerName, initializerType);
+
+  mlir::Block &block = odsState.regions.front()->emplaceBlock();
+  block.addArgument(privateVarType, odsState.location);
+}
+
+static ParseResult parsePrivateVarList(
+    OpAsmParser &parser,
+    llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
+    llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privateInitsAttr) {
+  SymbolRefAttr privatizerSym;
+  OpAsmParser::UnresolvedOperand arg;
+  OpAsmParser::UnresolvedOperand blockArg;
+  Type argType;
+
+  SmallVector<SymbolRefAttr> privateInitsVec;
+
+  auto parsePrivatizers = [&]() -> ParseResult {
+    if (parser.parseAttribute(privatizerSym) || parser.parseOperand(arg)) {
+      return failure();
+    }
+
+    privateInitsVec.push_back(privatizerSym);
+    privateVarsOperands.push_back(arg);
+    return success();
+  };
+
+  auto parseTypes = [&]() -> ParseResult {
+    if (parser.parseType(argType))
+      return failure();
+    privateVarsTypes.push_back(argType);
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parsePrivatizers))
+    return failure();
+
+  SmallVector<Attribute> privateInits(privateInitsVec.begin(),
+                                      privateInitsVec.end());
+  privateInitsAttr = ArrayAttr::get(parser.getContext(), privateInits);
+
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseCommaSeparatedList(parseTypes))
+    return failure();
+
+  return success();
+}
+
+static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
+                                OperandRange privateVars,
+                                TypeRange privateVarTypes,
+                                std::optional<ArrayAttr> privateInitsAttr) {
+  unsigned argIndex = 0;
+  assert(privateVars.size() == privateVarTypes.size() &&
+         ((privateVars.empty()) ||
+          (*privateInitsAttr &&
+           (privateInitsAttr->size() == privateVars.size()))));
+
+  for (const auto &privateVar : privateVars) {
+    assert(privateInitsAttr);
+    const auto &privateInitSym = (*privateInitsAttr)[argIndex];
+    printer << privateInitSym << " " << privateVar;
+
+    argIndex++;
+    if (argIndex < privateVars.size())
+      printer << ", ";
+  }
+
+  printer << " : ";
+
+  argIndex = 0;
+  for (const auto &mapType : privateVarTypes) {
+    printer << mapType;
+    argIndex++;
+    if (argIndex < privateVarTypes.size())
+      printer << ", ";
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/roundtrip.mlir b/mlir/test/Dialect/OpenMP/roundtrip.mlir
new file mode 100644
index 0000000000000..c6e9fab6f7f98
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/roundtrip.mlir
@@ -0,0 +1,36 @@
+// RUN: fir-opt -verify-diagnostics %s | fir-opt | FileCheck %s
+
+// CHECK-LABEL: _QPprivate_clause
+func.func @_QPprivate_clause() {
+  %0 = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFprivate_clause_allocatableEx"}
+  %1 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFprivate_clause_allocatableEy"}
+
+  // CHECK: omp.parallel private(@x.privatizer %0, @y.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>)
+  omp.parallel private(@x.privatizer %0, @y.privatizer %1: !fir.ref<i32>, !fir.ref<i32>) {
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
+"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
+// CHECK: ^bb0(%arg0: {{.*}}):
+^bb0(%arg0: !fir.ref<i32>):
+
+  // CHECK: %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
+  %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
+
+  // CHECK: omp.yield(%0 : !fir.ref<i32>)
+  omp.yield(%0 : !fir.ref<i32>)
+}) : () -> ()
+
+// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
+"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
+^bb0(%arg0: !fir.ref<i32>):
+
+  // CHECK: %0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}
+  %0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}
+
+  // CHECK: omp.yield(%0 : !fir.ref<i32>)
+  omp.yield(%0 : !fir.ref<i32>)
+}) : () -> ()

>From 89039513bea77eb9a9e1b8278e9afbc065538505 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 31 Jan 2024 08:34:54 -0600
Subject: [PATCH 2/4] Add conversion pattern for `PrivateClauseOp`.

---
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  | 23 +++++++++++++++----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 730858ffc67a7..d4ccbdf608293 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -46,6 +46,17 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
                                            *this->getTypeConverter())))
       return failure();
 
+    if constexpr (std::is_same_v<OpType, mlir::omp::PrivateClauseOp>) {
+      auto llvmType = this->getTypeConverter()->convertType(
+          adaptor.getFunctionType().getInput(0));
+
+      if (!llvmType)
+        return rewriter.notifyMatchFailure(curOp,
+                                           "signature conversion failed");
+      newOp.setFunctionType(
+          FunctionType::get(rewriter.getContext(), {llvmType}, {llvmType}));
+    }
+
     rewriter.eraseOp(curOp);
     return success();
   }
@@ -231,11 +242,12 @@ void mlir::configureOpenMPToLLVMConversionLegality(
       mlir::omp::DataOp, mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp,
       mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp,
       mlir::omp::SectionOp, mlir::omp::SectionsOp, mlir::omp::SingleOp,
-      mlir::omp::TaskGroupOp, mlir::omp::TaskOp>([&](Operation *op) {
-    return typeConverter.isLegal(&op->getRegion(0)) &&
-           typeConverter.isLegal(op->getOperandTypes()) &&
-           typeConverter.isLegal(op->getResultTypes());
-  });
+      mlir::omp::TaskGroupOp, mlir::omp::TaskOp, mlir::omp::PrivateClauseOp>(
+      [&](Operation *op) {
+        return typeConverter.isLegal(&op->getRegion(0)) &&
+               typeConverter.isLegal(op->getOperandTypes()) &&
+               typeConverter.isLegal(op->getResultTypes());
+      });
   target.addDynamicallyLegalOp<
       mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
       mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, mlir::omp::EnterDataOp,
@@ -275,6 +287,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionOpConversion<omp::SimdLoopOp>, RegionOpConversion<omp::SingleOp>,
       RegionOpConversion<omp::TaskGroupOp>, RegionOpConversion<omp::TaskOp>,
       RegionOpConversion<omp::DataOp>, RegionOpConversion<omp::TargetOp>,
+      RegionOpConversion<omp::PrivateClauseOp>,
       RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
       RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>,
       RegionLessOpWithVarOperandsConversion<omp::FlushOp>,

>From 0d05d4aa0b2316ba4020e95ded04bddb67a80234 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Fri, 2 Feb 2024 01:33:01 -0600
Subject: [PATCH 3/4] Convert private clauses to LLVM.

---
 .../OpenMP/FIR/delayed_privatization.f90      | 192 +++++++++++++++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  82 +++++++-
 2 files changed, 240 insertions(+), 34 deletions(-)

diff --git a/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90 b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
index 6a668e6fb6660..c6c8523beb676 100644
--- a/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
+++ b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
@@ -1,43 +1,179 @@
+! TODO Convert this file into a bunch of lit tests for each conversion step.
+
 subroutine delayed_privatization()
   integer :: var1
   integer :: var2
 
+  var1 = 111
+  var2 = 222
+
 !$OMP PARALLEL FIRSTPRIVATE(var1, var2)
   var1 = var1 + var2 + 2
 !$OMP END PARALLEL
 
 end subroutine
 
-! This is what flang emits with the PoC:
-! --------------------------------------
+! -----------------------------------------
+! ## This is what flang emits with the PoC:
+! -----------------------------------------
 !
-!func.func @_QPdelayed_privatization() {
-!  %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
-!  %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
-!  omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
-!    %2 = fir.load %0 : !fir.ref<i32>
-!    %3 = fir.load %1 : !fir.ref<i32>
-!    %4 = arith.addi %2, %3 : i32
-!    %c2_i32 = arith.constant 2 : i32
-!    %5 = arith.addi %4, %c2_i32 : i32
-!    fir.store %5 to %0 : !fir.ref<i32>
-!    omp.terminator
+! ----------------------------
+! ### Conversion to FIR + OMP:
+! ----------------------------
+!module {
+!  func.func @_QPdelayed_privatization() {
+!    %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
+!    %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
+!    %c111_i32 = arith.constant 111 : i32
+!    fir.store %c111_i32 to %0 : !fir.ref<i32>
+!    %c222_i32 = arith.constant 222 : i32
+!    fir.store %c222_i32 to %1 : !fir.ref<i32>
+!    omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
+!      %2 = fir.load %0 : !fir.ref<i32>
+!      %3 = fir.load %1 : !fir.ref<i32>
+!      %4 = arith.addi %2, %3 : i32
+!      %c2_i32 = arith.constant 2 : i32
+!      %5 = arith.addi %4, %c2_i32 : i32
+!      fir.store %5 to %0 : !fir.ref<i32>
+!      omp.terminator
+!    }
+!    return
 !  }
-!  return
+!  "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
+!  ^bb0(%arg0: !fir.ref<i32>):
+!    %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
+!    %1 = fir.load %arg0 : !fir.ref<i32>
+!    fir.store %1 to %0 : !fir.ref<i32>
+!    omp.yield(%0 : !fir.ref<i32>)
+!  }) : () -> ()
+!  "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
+!  ^bb0(%arg0: !fir.ref<i32>):
+!    %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
+!    %1 = fir.load %arg0 : !fir.ref<i32>
+!    fir.store %1 to %0 : !fir.ref<i32>
+!    omp.yield(%0 : !fir.ref<i32>)
+!  }) : () -> ()
 !}
 !
-!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
-!^bb0(%arg0: !fir.ref<i32>):
-!  %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
-!  %1 = fir.load %arg0 : !fir.ref<i32>
-!  fir.store %1 to %0 : !fir.ref<i32>
-!  omp.yield(%0 : !fir.ref<i32>)
-!}) : () -> ()
+! -----------------------------
+! ### Conversion to LLVM + OMP:
+! -----------------------------
+!module {
+!  llvm.func @_QPdelayed_privatization() {
+!    %0 = llvm.mlir.constant(1 : i64) : i64
+!    %1 = llvm.alloca %0 x i32 {bindc_name = "var1"} : (i64) -> !llvm.ptr
+!    %2 = llvm.mlir.constant(1 : i64) : i64
+!    %3 = llvm.alloca %2 x i32 {bindc_name = "var2"} : (i64) -> !llvm.ptr
+!    %4 = llvm.mlir.constant(111 : i32) : i32
+!    llvm.store %4, %1 : i32, !llvm.ptr
+!    %5 = llvm.mlir.constant(222 : i32) : i32
+!    llvm.store %5, %3 : i32, !llvm.ptr
+!    omp.parallel private(@var1.privatizer %1, @var2.privatizer %3 : !llvm.ptr, !llvm.ptr) {
+!      %6 = llvm.load %1 : !llvm.ptr -> i32
+!      %7 = llvm.load %3 : !llvm.ptr -> i32
+!      %8 = llvm.add %6, %7  : i32
+!      %9 = llvm.mlir.constant(2 : i32) : i32
+!      %10 = llvm.add %8, %9  : i32
+!      llvm.store %10, %1 : i32, !llvm.ptr
+!      omp.terminator
+!    }
+!    llvm.return
+!  }
+!  "omp.private"() <{function_type = (!llvm.ptr) -> !llvm.ptr, sym_name = "var1.privatizer"}> ({
+!  ^bb0(%arg0: !llvm.ptr):
+!    %0 = llvm.mlir.constant(1 : i64) : i64
+!    %1 = llvm.alloca %0 x i32 {bindc_name = "var1", pinned} : (i64) -> !llvm.ptr
+!    %2 = llvm.load %arg0 : !llvm.ptr -> i32
+!    llvm.store %2, %1 : i32, !llvm.ptr
+!    omp.yield(%1 : !llvm.ptr)
+!  }) : () -> ()
+!  "omp.private"() <{function_type = (!llvm.ptr) -> !llvm.ptr, sym_name = "var2.privatizer"}> ({
+!  ^bb0(%arg0: !llvm.ptr):
+!    %0 = llvm.mlir.constant(1 : i64) : i64
+!    %1 = llvm.alloca %0 x i32 {bindc_name = "var2", pinned} : (i64) -> !llvm.ptr
+!    %2 = llvm.load %arg0 : !llvm.ptr -> i32
+!    llvm.store %2, %1 : i32, !llvm.ptr
+!    omp.yield(%1 : !llvm.ptr)
+!  }) : () -> ()
+!}
 !
-!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
-!^bb0(%arg0: !fir.ref<i32>):
-!  %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
-!  %1 = fir.load %arg0 : !fir.ref<i32>
-!  fir.store %1 to %0 : !fir.ref<i32>
-!  omp.yield(%0 : !fir.ref<i32>)
-!}) : () -> ()
+! --------------------------
+! ### Conversion to LLVM IR:
+! --------------------------
+!%struct.ident_t = type { i32, i32, i32, i32, ptr }
+
+!@0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
+!@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
+
+!define void @_QPdelayed_privatization() {
+!  %structArg = alloca { ptr, ptr }, align 8
+!  %1 = alloca i32, i64 1, align 4
+!  %2 = alloca i32, i64 1, align 4
+!  store i32 111, ptr %1, align 4
+!  store i32 222, ptr %2, align 4
+!  br label %entry
+
+!entry:                                            ; preds = %0
+!  %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr @1)
+!  br label %omp_parallel
+
+!omp_parallel:                                     ; preds = %entry
+!  %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
+!  store ptr %1, ptr %gep_, align 8
+!  %gep_2 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
+!  store ptr %2, ptr %gep_2, align 8
+!  call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @_QPdelayed_privatization..omp_par, ptr %structArg)
+!  br label %omp.par.outlined.exit
+
+!omp.par.outlined.exit:                            ; preds = %omp_parallel
+!  br label %omp.par.exit.split
+
+!omp.par.exit.split:                               ; preds = %omp.par.outlined.exit
+!  ret void
+!}
+
+!; Function Attrs: nounwind
+!define internal void @_QPdelayed_privatization..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+!omp.par.entry:
+!  %gep_ = getelementptr { ptr, ptr }, ptr %0, i32 0, i32 0
+!  %loadgep_ = load ptr, ptr %gep_, align 8
+!  %gep_1 = getelementptr { ptr, ptr }, ptr %0, i32 0, i32 1
+!  %loadgep_2 = load ptr, ptr %gep_1, align 8
+!  %tid.addr.local = alloca i32, align 4
+!  %1 = load i32, ptr %tid.addr, align 4
+!  store i32 %1, ptr %tid.addr.local, align 4
+!  %tid = load i32, ptr %tid.addr.local, align 4
+!  %2 = alloca i32, i64 1, align 4
+!  %3 = load i32, ptr %loadgep_, align 4
+!  store i32 %3, ptr %2, align 4
+!  %4 = alloca i32, i64 1, align 4
+!  %5 = load i32, ptr %loadgep_2, align 4
+!  store i32 %5, ptr %4, align 4
+!  br label %omp.par.region
+
+!omp.par.region:                                   ; preds = %omp.par.entry
+!  br label %omp.par.region1
+
+!omp.par.region1:                                  ; preds = %omp.par.region
+!  %6 = load i32, ptr %2, align 4
+!  %7 = load i32, ptr %4, align 4
+!  %8 = add i32 %6, %7
+!  %9 = add i32 %8, 2
+!  store i32 %9, ptr %2, align 4
+!  br label %omp.region.cont
+
+!omp.region.cont:                                  ; preds = %omp.par.region1
+!  br label %omp.par.pre_finalize
+
+!omp.par.pre_finalize:                             ; preds = %omp.region.cont
+!  br label %omp.par.outlined.exit.exitStub
+
+!omp.par.outlined.exit.exitStub:                   ; preds = %omp.par.pre_finalize
+!  ret void
+!}
+
+!; Function Attrs: nounwind
+!declare i32 @__kmpc_global_thread_num(ptr) #0
+
+!; Function Attrs: nounwind
+!declare !callback !2 void @__kmpc_fork_call(ptr, i32, ptr, ...) #0
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 23e101f1e4527..253f06d1e4d5f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1092,6 +1092,75 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                     llvm::Value *&replacementValue) -> InsertPointTy {
     replacementValue = &vPtr;
 
+    // If this is a private value, this lambda will return the corresponding
+    // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
+    // returned.
+    auto [privVar,
+          privInit] = [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
+      if (!opInst.getPrivateVars().empty()) {
+        auto privVars = opInst.getPrivateVars();
+        auto privInits = opInst.getPrivateInits();
+        assert(privInits && privInits->size() == privVars.size());
+
+        const auto *privInitIt = privInits->begin();
+        for (auto privVarIt = privVars.begin(); privVarIt != privVars.end();
+             ++privVarIt, ++privInitIt) {
+          auto *llvmPrivVarOp = moduleTranslation.lookupValue(*privVarIt);
+          if (llvmPrivVarOp != &vPtr) {
+            continue;
+          }
+
+          auto privSym = llvm::cast<SymbolRefAttr>(*privInitIt);
+          auto privOp =
+              SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+                  opInst, privSym);
+
+          return {*privVarIt, privOp};
+        }
+      }
+
+      return {mlir::Value(), omp::PrivateClauseOp()};
+    }();
+
+    if (privVar) {
+
+      // Replace the privatizer block argument with the mlir value being
+      // privatized. NOTE(review): this rewrites the (possibly shared, since
+      // it is looked up by symbol) `omp.private` op in place, not a clone.
+      assert(privInit->getRegions().front().getNumArguments() == 1);
+
+      auto arg = privInit->getRegions().front().getArgument(0);
+      // RAUW also covers uses nested in regions of the privatizer's ops,
+      // which a per-op `replaceUsesOfWith` over the entry block misses.
+      arg.replaceAllUsesWith(privVar);
+
+      auto oldIP = builder.saveIP();
+      builder.restoreIP(allocaIP);
+
+      // Temporarily unlink the terminator from its parent since
+      // `inlineConvertOmpRegions` expects the insertion block to **not**
+      // contain a terminator.
+      auto &allocaTerminator = builder.GetInsertBlock()->back();
+      assert(allocaTerminator.isTerminator());
+      allocaTerminator.removeFromParent();
+
+      SmallVector<llvm::Value *, 1> yieldedValues;
+      if (failed(inlineConvertOmpRegions(privInit->getRegion(0),
+                                         "omp.privatizer", builder,
+                                         moduleTranslation, &yieldedValues))) {
+        // TODO proper error-handling (terminator is not re-inserted here).
+        builder.restoreIP(oldIP);
+        return codeGenIP;
+      }
+
+      allocaTerminator.insertAfter(&builder.GetInsertBlock()->back());
+
+      assert(yieldedValues.size() == 1);
+      replacementValue = yieldedValues.front();
+
+      builder.restoreIP(oldIP);
+    }
+
     return codeGenIP;
   };
 
@@ -2774,12 +2843,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::TargetOp) {
         return convertOmpTarget(*op, builder, moduleTranslation);
       })
-      .Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
-        // No-op, should be handled by relevant owning operations e.g.
-        // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
-        // discarded
-        return success();
-      })
+      .Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
+          [&](auto op) {
+            // No-op, should be handled by relevant owning operations e.g.
+            // TargetOp, EnterDataOp, ExitDataOp, DataOp, ParallelOp etc. and
+            // then discarded
+            return success();
+          })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();

>From 4692aadde0fdd67b69319e29e315565255569333 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Fri, 2 Feb 2024 04:10:29 -0600
Subject: [PATCH 4/4] Handle some comments.

---
 flang/lib/Lower/OpenMP.cpp                    | 20 +++++------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  6 ++--
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  4 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 34 ++++++++++---------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  2 +-
 5 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 4d012c45108fd..9e66a52d8d958 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -163,7 +163,7 @@ class DataSharingProcessor {
 
   bool useDelayedPrivatizationWhenPossible;
   Fortran::lower::SymMap *symTable;
-  llvm::SetVector<mlir::SymbolRefAttr> privateInitializers;
+  llvm::SetVector<mlir::SymbolRefAttr> privatizers;
   llvm::SetVector<mlir::Value> privateSymHostAddrsses;
 
   bool needBarrier();
@@ -214,8 +214,8 @@ class DataSharingProcessor {
     loopIV = iv;
   }
 
-  const llvm::SetVector<mlir::SymbolRefAttr> &getPrivateInitializers() const {
-    return privateInitializers;
+  const llvm::SetVector<mlir::SymbolRefAttr> &getPrivatizers() const {
+    return privatizers;
   };
 
   const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
@@ -547,7 +547,7 @@ void DataSharingProcessor::privatize() {
         symTable->popScope();
         firOpBuilder.restoreInsertionPoint(ip);
 
-        privateInitializers.insert(mlir::SymbolRefAttr::get(privatizerOp));
+        privatizers.insert(mlir::SymbolRefAttr::get(privatizerOp));
         privateSymHostAddrsses.insert(hsb.getAddr());
       } else {
         cloneSymbol(sym);
@@ -2584,8 +2584,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     dsp.processStep1();
   }
 
-  llvm::SmallVector<mlir::Attribute> privateInits(
-      dsp.getPrivateInitializers().begin(), dsp.getPrivateInitializers().end());
+  llvm::SmallVector<mlir::Attribute> privatizers(dsp.getPrivatizers().begin(),
+                                                 dsp.getPrivatizers().end());
 
   llvm::SmallVector<mlir::Value> privateSymAddresses(
       dsp.getPrivateSymHostAddrsses().begin(),
@@ -2596,16 +2596,16 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
       &dsp,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
-      reductionVars, privateSymAddresses,
+      reductionVars,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
-      procBindKindAttr,
-      privateInits.empty()
+      procBindKindAttr, privateSymAddresses,
+      privatizers.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 privateInits));
+                                 privatizers));
 }
 
 static mlir::omp::SectionOp
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 23e058d372c79..daeb9c911f1c2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -188,10 +188,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
-             Variadic<AnyType>:$private_vars,
              OptionalAttr<SymbolRefArrayAttr>:$reductions,
              OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
-             OptionalAttr<SymbolRefArrayAttr>:$private_inits);
+             Variadic<AnyType>:$private_vars,
+             OptionalAttr<SymbolRefArrayAttr>:$privatizers);
 
   let regions = (region AnyRegion:$region);
 
@@ -217,7 +217,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
           | `private` `(`
               custom<PrivateVarList>(
-                $private_vars, type($private_vars), $private_inits
+                $private_vars, type($private_vars), $privatizers
               ) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index b381aaf20bf89..889aa755d8ba4 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -419,10 +419,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocators_vars = */ llvm::SmallVector<Value>{},
         /* reduction_vars = */ llvm::SmallVector<Value>{},
-        /*private_vars=*/mlir::ValueRange{},
         /* reductions = */ ArrayAttr{},
         /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
-        /*private_inits*/ nullptr);
+        /*private_vars=*/mlir::ValueRange{},
+        /*privatizers=*/nullptr);
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c4ef7ef3f2fb5..a227473c9cdca 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -989,9 +989,10 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(
       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-      /*reduction_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
+      /*reduction_vars=*/ValueRange(),
       /*reductions=*/nullptr,
-      /*proc_bind_val=*/nullptr, /*private_inits*/ nullptr);
+      /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
+      /*privatizers*/ nullptr);
   state.addAttributes(attributes);
 }
 
@@ -1610,20 +1611,20 @@ void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 static ParseResult parsePrivateVarList(
     OpAsmParser &parser,
     llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
-    llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privateInitsAttr) {
+    llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privatizersAttr) {
   SymbolRefAttr privatizerSym;
   OpAsmParser::UnresolvedOperand arg;
   OpAsmParser::UnresolvedOperand blockArg;
   Type argType;
 
-  SmallVector<SymbolRefAttr> privateInitsVec;
+  SmallVector<SymbolRefAttr> privatizersVec;
 
   auto parsePrivatizers = [&]() -> ParseResult {
     if (parser.parseAttribute(privatizerSym) || parser.parseOperand(arg)) {
       return failure();
     }
 
-    privateInitsVec.push_back(privatizerSym);
+    privatizersVec.push_back(privatizerSym);
     privateVarsOperands.push_back(arg);
     return success();
   };
@@ -1638,9 +1639,9 @@ static ParseResult parsePrivateVarList(
   if (parser.parseCommaSeparatedList(parsePrivatizers))
     return failure();
 
-  SmallVector<Attribute> privateInits(privateInitsVec.begin(),
-                                      privateInitsVec.end());
-  privateInitsAttr = ArrayAttr::get(parser.getContext(), privateInits);
+  SmallVector<Attribute> privatizers(privatizersVec.begin(),
+                                     privatizersVec.end());
+  privatizersAttr = ArrayAttr::get(parser.getContext(), privatizers);
 
   if (parser.parseColon())
     return failure();
@@ -1654,17 +1655,18 @@ static ParseResult parsePrivateVarList(
 static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
                                 OperandRange privateVars,
                                 TypeRange privateVarTypes,
-                                std::optional<ArrayAttr> privateInitsAttr) {
+                                std::optional<ArrayAttr> privatizersAttr) {
   unsigned argIndex = 0;
-  assert(privateVars.size() == privateVarTypes.size() &&
-         ((privateVars.empty()) ||
-          (*privateInitsAttr &&
-           (privateInitsAttr->size() == privateVars.size()))));
+  // TODO Add an op verifier instead of this assertion.
+  assert(
+      privateVars.size() == privateVarTypes.size() &&
+      ((privateVars.empty()) ||
+       (*privatizersAttr && (privatizersAttr->size() == privateVars.size()))));
 
   for (const auto &privateVar : privateVars) {
-    assert(privateInitsAttr);
-    const auto &privateInitSym = (*privateInitsAttr)[argIndex];
-    printer << privateInitSym << " " << privateVar;
+    assert(privatizersAttr);
+    const auto &privatizerSym = (*privatizersAttr)[argIndex];
+    printer << privatizerSym << " " << privateVar;
 
     argIndex++;
     if (argIndex < privateVars.size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 253f06d1e4d5f..2d0bbc8442aa9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1099,7 +1099,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
           privInit] = [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
       if (!opInst.getPrivateVars().empty()) {
         auto privVars = opInst.getPrivateVars();
-        auto privInits = opInst.getPrivateInits();
+        auto privInits = opInst.getPrivatizers();
         assert(privInits && privInits->size() == privVars.size());
 
         const auto *privInitIt = privInits->begin();



More information about the flang-commits mailing list