[flang-commits] [flang] [mlir] [WIP] Delayed privatization. (PR #79862)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Wed Jan 31 03:06:24 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/79862

>From ef847abd555d2bb84494c4dd06b8a04ac9413641 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 29 Jan 2024 04:45:18 -0600
Subject: [PATCH] [WIP] Delayed privatization.

This is a PoC for delayed privatization in OpenMP. Instead of directly
emitting privatization code in the frontend, we add a new op that outlines
the privatization logic for a symbol, plus a call-like mapping from the
host symbol to that outlined, function-like privatizer op.

Later, we would inline the delayed privatizer function-like op into the
OpenMP region to get essentially the same code that the frontend
currently generates directly.
---
 flang/include/flang/Lower/AbstractConverter.h |   4 +
 flang/lib/Lower/Bridge.cpp                    |   2 +-
 flang/lib/Lower/OpenMP.cpp                    | 132 ++++++++++++++----
 .../OpenMP/FIR/delayed_privatization.f90      |  43 ++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  43 +++++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  92 +++++++++++-
 mlir/test/Dialect/OpenMP/roundtrip.mlir       |  36 +++++
 8 files changed, 326 insertions(+), 30 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
 create mode 100644 mlir/test/Dialect/OpenMP/roundtrip.mlir

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c19dcbdcdb390..e7507c6fd705f 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -16,6 +16,7 @@
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/LoweringOptions.h"
 #include "flang/Lower/PFTDefs.h"
+#include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Semantics/symbol.h"
 #include "mlir/IR/Builders.h"
@@ -295,6 +296,9 @@ class AbstractConverter {
     return loweringOptions;
   }
 
+  virtual Fortran::lower::SymbolBox
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index d657075d53efb..0dd8fda0d28ec 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
     if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
       return v;
     return {};
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index be2117efbabc0..4d012c45108fd 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -161,6 +161,11 @@ class DataSharingProcessor {
   const Fortran::parser::OmpClauseList &opClauseList;
   Fortran::lower::pft::Evaluation &eval;
 
+  bool useDelayedPrivatizationWhenPossible;
+  Fortran::lower::SymMap *symTable;
+  llvm::SetVector<mlir::SymbolRefAttr> privateInitializers;
+  llvm::SetVector<mlir::Value> privateSymHostAddrsses;
+
   bool needBarrier();
   void collectSymbols(Fortran::semantics::Symbol::Flag flag);
   void collectOmpObjectListSymbol(
@@ -182,10 +187,14 @@ class DataSharingProcessor {
 public:
   DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
                        const Fortran::parser::OmpClauseList &opClauseList,
-                       Fortran::lower::pft::Evaluation &eval)
+                       Fortran::lower::pft::Evaluation &eval,
+                       bool useDelayedPrivatizationWhenPossible = false,
+                       Fortran::lower::SymMap *symTable = nullptr)
       : hasLastPrivateOp(false), converter(converter),
         firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
-        eval(eval) {}
+        eval(eval), useDelayedPrivatizationWhenPossible(
+                        useDelayedPrivatizationWhenPossible),
+        symTable(symTable) {}
   // Privatisation is split into two steps.
   // Step1 performs cloning of all privatisation clauses and copying for
   // firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +213,14 @@ class DataSharingProcessor {
     assert(!loopIV && "Loop iteration variable already set");
     loopIV = iv;
   }
+
+  const llvm::SetVector<mlir::SymbolRefAttr> &getPrivateInitializers() const {
+    return privateInitializers;
+  };
+
+  const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
+    return privateSymHostAddrsses;
+  }
 };
 
 void DataSharingProcessor::processStep1() {
@@ -496,8 +513,46 @@ void DataSharingProcessor::privatize() {
         copyFirstPrivateSymbol(&*mem);
       }
     } else {
-      cloneSymbol(sym);
-      copyFirstPrivateSymbol(sym);
+      if (useDelayedPrivatizationWhenPossible) {
+        auto ip = firOpBuilder.saveInsertionPoint();
+
+        auto moduleOp = firOpBuilder.getInsertionBlock()
+                            ->getParentOp()
+                            ->getParentOfType<mlir::ModuleOp>();
+
+        firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
+                                       moduleOp.getBodyRegion().front().end());
+
+        Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
+        assert(hsb && "Host symbol box not found");
+
+        auto symType = hsb.getAddr().getType();
+        auto symLoc = hsb.getAddr().getLoc();
+        auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
+            symLoc, symType, sym->name().ToString());
+        firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
+
+        symTable->pushScope();
+        symTable->addSymbol(*sym, privatizerOp.getArgument(0));
+        symTable->pushScope();
+
+        cloneSymbol(sym);
+        copyFirstPrivateSymbol(sym);
+
+        firOpBuilder.create<mlir::omp::YieldOp>(
+            hsb.getAddr().getLoc(),
+            symTable->shallowLookupSymbol(*sym).getAddr());
+
+        symTable->popScope();
+        symTable->popScope();
+        firOpBuilder.restoreInsertionPoint(ip);
+
+        privateInitializers.insert(mlir::SymbolRefAttr::get(privatizerOp));
+        privateSymHostAddrsses.insert(hsb.getAddr());
+      } else {
+        cloneSymbol(sym);
+        copyFirstPrivateSymbol(sym);
+      }
     }
   }
 }
@@ -2463,12 +2518,12 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
                           Fortran::lower::pft::Evaluation &eval, bool genNested,
                           mlir::Location currentLocation, bool outerCombined,
                           const Fortran::parser::OmpClauseList *clauseList,
-                          Args &&...args) {
+                          DataSharingProcessor *dsp, Args &&...args) {
   auto op = converter.getFirOpBuilder().create<OpTy>(
       currentLocation, std::forward<Args>(args)...);
   createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
                        clauseList,
-                       /*args=*/{}, outerCombined);
+                       /*args=*/{}, outerCombined, dsp);
   return op;
 }
 
@@ -2480,6 +2535,7 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
                                             currentLocation,
                                             /*outerCombined=*/false,
                                             /*clauseList=*/nullptr,
+                                            /*dsp=*/nullptr,
                                             /*resultTypes=*/mlir::TypeRange());
 }
 
@@ -2487,14 +2543,17 @@ static mlir::omp::OrderedRegionOp
 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval, bool genNested,
                    mlir::Location currentLocation) {
-  return genOpWithBody<mlir::omp::OrderedRegionOp>(
-      converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false,
-      /*clauseList=*/nullptr, /*simd=*/false);
+  return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
+                                                   currentLocation,
+                                                   /*outerCombined=*/false,
+                                                   /*clauseList=*/nullptr,
+                                                   /*dsp=*/nullptr,
+                                                   /*simd=*/false);
 }
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
+              Fortran::lower::SymMap &symTable,
               Fortran::lower::pft::Evaluation &eval, bool genNested,
               mlir::Location currentLocation,
               const Fortran::parser::OmpClauseList &clauseList,
@@ -2516,16 +2575,37 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   if (!outerCombined)
     cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
 
+  bool privatize = !outerCombined;
+  DataSharingProcessor dsp(converter, clauseList, eval,
+                           /*useDelayedPrivatizationWhenPossible=*/true,
+                           &symTable);
+
+  if (privatize) {
+    dsp.processStep1();
+  }
+
+  llvm::SmallVector<mlir::Attribute> privateInits(
+      dsp.getPrivateInitializers().begin(), dsp.getPrivateInitializers().end());
+
+  llvm::SmallVector<mlir::Value> privateSymAddresses(
+      dsp.getPrivateSymHostAddrsses().begin(),
+      dsp.getPrivateSymHostAddrsses().end());
+
   return genOpWithBody<mlir::omp::ParallelOp>(
       converter, eval, genNested, currentLocation, outerCombined, &clauseList,
+      &dsp,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
-      reductionVars,
+      reductionVars, privateSymAddresses,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
-      procBindKindAttr);
+      procBindKindAttr,
+      privateInits.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 privateInits));
 }
 
 static mlir::omp::SectionOp
@@ -2537,7 +2617,8 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
   // all privatization is done within `omp.section` operations.
   return genOpWithBody<mlir::omp::SectionOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &sectionsClauseList);
+      /*outerCombined=*/false, &sectionsClauseList,
+      /*dsp=*/nullptr);
 }
 
 static mlir::omp::SingleOp
@@ -2558,8 +2639,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::SingleOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &beginClauseList, allocateOperands,
-      allocatorOperands, nowaitAttr);
+      /*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
+      allocateOperands, allocatorOperands, nowaitAttr);
 }
 
 static mlir::omp::TaskOp
@@ -2591,8 +2672,8 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::TaskOp>(
       converter, eval, genNested, currentLocation,
-      /*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
-      untiedAttr, mergeableAttr,
+      /*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
+      finalClauseOperand, untiedAttr, mergeableAttr,
       /*in_reduction_vars=*/mlir::ValueRange(),
       /*in_reductions=*/nullptr, priorityClauseOperand,
       dependTypeOperands.empty()
@@ -2615,6 +2696,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
   return genOpWithBody<mlir::omp::TaskGroupOp>(
       converter, eval, genNested, currentLocation,
       /*outerCombined=*/false, &clauseList,
+      /*dsp=*/nullptr,
       /*task_reduction_vars=*/mlir::ValueRange(),
       /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
 }
@@ -2994,6 +3076,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::TeamsOp>(
       converter, eval, genNested, currentLocation, outerCombined, &clauseList,
+      /*dsp=*/nullptr,
       /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
       threadLimitClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -3392,8 +3475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
-                    loopOpClauseList,
+      genParallelOp(converter, symTable, eval, /*genNested=*/false,
+                    currentLocation, loopOpClauseList,
                     /*outerCombined=*/true);
     }
   }
@@ -3481,8 +3564,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
-                  beginClauseList);
+    genParallelOp(converter, symTable, eval, /*genNested=*/true,
+                  currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_single:
     genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
@@ -3541,8 +3624,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
           .test(directive.v)) {
     bool outerCombined =
         directive.v != llvm::omp::Directive::OMPD_target_parallel;
-    genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
-                  beginClauseList, outerCombined);
+    genParallelOp(converter, symTable, eval, /*genNested=*/false,
+                  currentLocation, beginClauseList, outerCombined);
     combinedDirective = true;
   }
   if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
@@ -3625,7 +3708,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
 
   // Parallel wrapper of PARALLEL SECTIONS construct
   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
-    genParallelOp(converter, eval,
+    genParallelOp(converter, symTable, eval,
                   /*genNested=*/false, currentLocation, sectionsClauseList,
                   /*outerCombined=*/true);
   } else {
@@ -3642,6 +3725,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                                        /*genNested=*/false, currentLocation,
                                        /*outerCombined=*/false,
                                        /*clauseList=*/nullptr,
+                                       /*dsp=*/nullptr,
                                        /*reduction_vars=*/mlir::ValueRange(),
                                        /*reductions=*/nullptr, allocateOperands,
                                        allocatorOperands, nowaitClauseOperand);
diff --git a/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90 b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
new file mode 100644
index 0000000000000..6a668e6fb6660
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
@@ -0,0 +1,43 @@
+subroutine delayed_privatization()
+  integer :: var1
+  integer :: var2
+
+!$OMP PARALLEL FIRSTPRIVATE(var1, var2)
+  var1 = var1 + var2 + 2
+!$OMP END PARALLEL
+
+end subroutine
+
+! This is what flang emits with the PoC:
+! --------------------------------------
+!
+!func.func @_QPdelayed_privatization() {
+!  %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
+!  %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
+!  omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
+!    %2 = fir.load %0 : !fir.ref<i32>
+!    %3 = fir.load %1 : !fir.ref<i32>
+!    %4 = arith.addi %2, %3 : i32
+!    %c2_i32 = arith.constant 2 : i32
+!    %5 = arith.addi %4, %c2_i32 : i32
+!    fir.store %5 to %0 : !fir.ref<i32>
+!    omp.terminator
+!  }
+!  return
+!}
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 96c15e775a302..23e058d372c79 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -16,6 +16,7 @@
 
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -187,8 +188,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
+             Variadic<AnyType>:$private_vars,
              OptionalAttr<SymbolRefArrayAttr>:$reductions,
-             OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
+             OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
+             OptionalAttr<SymbolRefArrayAttr>:$private_inits);
 
   let regions = (region AnyRegion:$region);
 
@@ -212,6 +215,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocators_vars, type($allocators_vars)
               ) `)`
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
+          | `private` `(`
+              custom<PrivateVarList>(
+                $private_vars, type($private_vars), $private_inits
+              ) `)`
     ) $region attr-dict
   }];
   let hasVerifier = 1;
@@ -621,7 +628,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
 def YieldOp : OpenMP_Op<"yield",
     [Pure, ReturnLike, Terminator,
      ParentOneOf<["WsLoopOp", "ReductionDeclareOp",
-     "AtomicUpdateOp", "SimdLoopOp"]>]> {
+     "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -1478,6 +1485,38 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
+def PrivateClauseOp : OpenMP_Op<"private", [
+    IsolatedFromAbove, FunctionOpInterface
+  ]> {
+  let summary = "TODO";
+  let description = [{}];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type);
+
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "::mlir::Type":$privateVar,
+    "::llvm::StringRef":$privateVarName
+  )>];
+
+  let extraClassDeclaration = [{
+    ::mlir::Region *getCallableRegion() {
+      return &getBody();
+    }
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+  }];
+}
 
 def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
   let summary = "target construct";
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 2f8b3f7e11de1..b381aaf20bf89 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -419,8 +419,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocators_vars = */ llvm::SmallVector<Value>{},
         /* reduction_vars = */ llvm::SmallVector<Value>{},
+        /*private_vars=*/mlir::ValueRange{},
         /* reductions = */ ArrayAttr{},
-        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
+        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
+        /*private_inits*/ nullptr);
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13cc16125a273..c4ef7ef3f2fb5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -989,8 +989,9 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(
       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
-      /*proc_bind_val=*/nullptr);
+      /*reduction_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
+      /*reductions=*/nullptr,
+      /*proc_bind_val=*/nullptr, /*private_inits*/ nullptr);
   state.addAttributes(attributes);
 }
 
@@ -1594,6 +1595,93 @@ LogicalResult DataBoundsOp::verify() {
   return success();
 }
 
+void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                            Type privateVarType, StringRef privateVarName) {
+  FunctionType initializerType = FunctionType::get(
+      odsBuilder.getContext(), {privateVarType}, {privateVarType});
+  std::string privatizerName = (privateVarName + ".privatizer").str();
+
+  build(odsBuilder, odsState, privatizerName, initializerType);
+
+  mlir::Block &block = odsState.regions.front()->emplaceBlock();
+  block.addArgument(privateVarType, odsState.location);
+}
+
+static ParseResult parsePrivateVarList(
+    OpAsmParser &parser,
+    llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
+    llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privateInitsAttr) {
+  SymbolRefAttr privatizerSym;
+  OpAsmParser::UnresolvedOperand arg;
+  OpAsmParser::UnresolvedOperand blockArg;
+  Type argType;
+
+  SmallVector<SymbolRefAttr> privateInitsVec;
+
+  auto parsePrivatizers = [&]() -> ParseResult {
+    if (parser.parseAttribute(privatizerSym) || parser.parseOperand(arg)) {
+      return failure();
+    }
+
+    privateInitsVec.push_back(privatizerSym);
+    privateVarsOperands.push_back(arg);
+    return success();
+  };
+
+  auto parseTypes = [&]() -> ParseResult {
+    if (parser.parseType(argType))
+      return failure();
+    privateVarsTypes.push_back(argType);
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parsePrivatizers))
+    return failure();
+
+  SmallVector<Attribute> privateInits(privateInitsVec.begin(),
+                                      privateInitsVec.end());
+  privateInitsAttr = ArrayAttr::get(parser.getContext(), privateInits);
+
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseCommaSeparatedList(parseTypes))
+    return failure();
+
+  return success();
+}
+
+static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
+                                OperandRange privateVars,
+                                TypeRange privateVarTypes,
+                                std::optional<ArrayAttr> privateInitsAttr) {
+  unsigned argIndex = 0;
+  assert(privateVars.size() == privateVarTypes.size() &&
+         ((privateVars.empty()) ||
+          (*privateInitsAttr &&
+           (privateInitsAttr->size() == privateVars.size()))));
+
+  for (const auto &privateVar : privateVars) {
+    assert(privateInitsAttr);
+    const auto &privateInitSym = (*privateInitsAttr)[argIndex];
+    printer << privateInitSym << " " << privateVar;
+
+    argIndex++;
+    if (argIndex < privateVars.size())
+      printer << ", ";
+  }
+
+  printer << " : ";
+
+  argIndex = 0;
+  for (const auto &mapType : privateVarTypes) {
+    printer << mapType;
+    argIndex++;
+    if (argIndex < privateVarTypes.size())
+      printer << ", ";
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/roundtrip.mlir b/mlir/test/Dialect/OpenMP/roundtrip.mlir
new file mode 100644
index 0000000000000..c6e9fab6f7f98
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/roundtrip.mlir
@@ -0,0 +1,36 @@
+// RUN: fir-opt -verify-diagnostics %s | fir-opt | FileCheck %s
+
+// CHECK-LABEL: _QPprivate_clause
+func.func @_QPprivate_clause() {
+  %0 = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFprivate_clause_allocatableEx"}
+  %1 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFprivate_clause_allocatableEy"}
+
+  // CHECK: omp.parallel private(@x.privatizer %0, @y.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>)
+  omp.parallel private(@x.privatizer %0, @y.privatizer %1: !fir.ref<i32>, !fir.ref<i32>) {
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
+"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
+// CHECK: ^bb0(%arg0: {{.*}}):
+^bb0(%arg0: !fir.ref<i32>):
+
+  // CHECK: %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
+  %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
+
+  // CHECK: omp.yield(%0 : !fir.ref<i32>)
+  omp.yield(%0 : !fir.ref<i32>)
+}) : () -> ()
+
+// CHECK: "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
+"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "y.privatizer"}> ({
+^bb0(%arg0: !fir.ref<i32>):
+
+  // CHECK: %0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}
+  %0 = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFprivate_clause_allocatableEy"}
+
+  // CHECK: omp.yield(%0 : !fir.ref<i32>)
+  omp.yield(%0 : !fir.ref<i32>)
+}) : () -> ()



More information about the flang-commits mailing list