[flang-commits] [flang] [mlir] [WIP] Delayed privatization. (PR #79862)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Mon Jan 29 22:26:57 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/79862

>From f2165b004791619c556c5e4ddf8e2744bb48a9d7 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 29 Jan 2024 04:45:18 -0600
Subject: [PATCH 1/2] [WIP] Delayed privatization.

This is a PoC for delayed privatization in OpenMP. Instead of directly
emitting privatization code in the frontend, we add a new op that outlines
the privatization logic for a symbol, together with a call-like mapping
from the host symbol to a block argument in the OpenMP region.

Example:
```
!$omp target private(x)
!$omp end target
```

Would be code-generated by flang as:
```
  func.func @foo() {
    omp.target x.privatizer %x -> %argx: !fir.ref<i32> {
    bb0(%argx: !fir.ref<i32>):
      // ... use %argx ....
    }
  }

  "omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "x.privatizer"}> ({
  ^bb0(%arg0: !fir.ref<i32>):
    %0 = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFprivate_clause_allocatableEx"}
    %1 = fir.load %arg0 : !fir.ref<i32>
    fir.store %1 to %0 : !fir.ref<i32>
    omp.yield(%0 : !fir.ref<i32>)
  }) : () -> ()
```

Later, we would inline the delayed privatizer function-like op in the
OpenMP region to basically get the same code generated directly by the
frontend at the moment.

So far this PoC implements the following:
- Adds the delayed privatization op: `omp.private`.
- For simple symbols, emits the op.

Still TODO:
- Extend the `omp.target` op to somehow model the outlined privatization
  logic.
- Inline the outlined privatizer before emitting LLVM IR.
- Support more complex symbols like allocatables.
---
 flang/include/flang/Lower/AbstractConverter.h | 18 ++--
 flang/include/flang/Lower/SymbolMap.h         |  1 +
 flang/lib/Lower/Bridge.cpp                    | 54 +++++++-----
 flang/lib/Lower/OpenMP.cpp                    | 83 +++++++++++++++----
 .../OpenMP/FIR/delayed_privatization.f90      | 42 ++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 35 +++++++-
 mlir/lib/Dialect/Func/IR/FuncOps.cpp          |  4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 13 +++
 8 files changed, 204 insertions(+), 46 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/FIR/delayed_privatization.f90

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c19dcbdcdb390..b3ca804256ee0 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -16,6 +16,7 @@
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/LoweringOptions.h"
 #include "flang/Lower/PFTDefs.h"
+#include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Semantics/symbol.h"
 #include "mlir/IR/Builders.h"
@@ -92,7 +93,8 @@ class AbstractConverter {
 
   /// Binds the symbol to an fir extended value. The symbol binding will be
   /// added or replaced at the inner-most level of the local symbol map.
-  virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
+  virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval,
+                          Fortran::lower::SymMap *symMap = nullptr) = 0;
 
   /// Override lowering of expression with pre-lowered values.
   /// Associate mlir::Value to evaluate::Expr. All subsequent call to
@@ -111,14 +113,16 @@ class AbstractConverter {
   /// For a given symbol which is host-associated, create a clone using
   /// parameters from the host-associated symbol.
   virtual bool
-  createHostAssociateVarClone(const Fortran::semantics::Symbol &sym) = 0;
+  createHostAssociateVarClone(const Fortran::semantics::Symbol &sym,
+                              Fortran::lower::SymMap *symMap = nullptr) = 0;
 
   virtual void
   createHostAssociateVarCloneDealloc(const Fortran::semantics::Symbol &sym) = 0;
 
-  virtual void copyHostAssociateVar(
-      const Fortran::semantics::Symbol &sym,
-      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) = 0;
+  virtual void
+  copyHostAssociateVar(const Fortran::semantics::Symbol &sym,
+                       mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
+                       Fortran::lower::SymMap *symMap = nullptr) = 0;
 
   /// For a given symbol, check if it is present in the inner-most
   /// level of the symbol map.
@@ -295,6 +299,10 @@ class AbstractConverter {
     return loweringOptions;
   }
 
+  virtual Fortran::lower::SymbolBox
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
+                         Fortran::lower::SymMap *symMap = nullptr) = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index a55e4b133fe0a..1031b479eb619 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -101,6 +101,7 @@ struct SymbolBox : public fir::details::matcher<SymbolBox> {
                  [](const fir::FortranVariableOpInterface &x) {
                    return fir::FortranVariableOpInterface(x).getBase();
                  },
+                 [](const fir::MutableBoxValue &x) { return x.getAddr(); },
                  [](const auto &x) { return x.getAddr(); });
   }
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c3a14125bba85..eaa8eaeae84af 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -498,16 +498,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Add the symbol binding to the inner-most level of the symbol map and
   /// return true if it is not already present. Otherwise, return false.
   bool bindIfNewSymbol(Fortran::lower::SymbolRef sym,
-                       const fir::ExtendedValue &exval) {
-    if (shallowLookupSymbol(sym))
+                       const fir::ExtendedValue &exval,
+                       Fortran::lower::SymMap *symMap = nullptr) {
+    if (shallowLookupSymbol(sym, symMap))
       return false;
-    bindSymbol(sym, exval);
+    bindSymbol(sym, exval, symMap);
     return true;
   }
 
   void bindSymbol(Fortran::lower::SymbolRef sym,
-                  const fir::ExtendedValue &exval) override final {
-    addSymbol(sym, exval, /*forced=*/true);
+                  const fir::ExtendedValue &exval,
+                  Fortran::lower::SymMap *symMap = nullptr) override final {
+    addSymbol(sym, exval, /*forced=*/true, symMap);
   }
 
   void
@@ -610,14 +612,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   }
 
   bool createHostAssociateVarClone(
-      const Fortran::semantics::Symbol &sym) override final {
+      const Fortran::semantics::Symbol &sym,
+      Fortran::lower::SymMap *symMap = nullptr) override final {
     mlir::Location loc = genLocation(sym.name());
     mlir::Type symType = genType(sym);
     const auto *details = sym.detailsIf<Fortran::semantics::HostAssocDetails>();
     assert(details && "No host-association found");
     const Fortran::semantics::Symbol &hsym = details->symbol();
     mlir::Type hSymType = genType(hsym);
-    Fortran::lower::SymbolBox hsb = lookupSymbol(hsym);
+    Fortran::lower::SymbolBox hsb = lookupSymbol(hsym, symMap);
 
     auto allocate = [&](llvm::ArrayRef<mlir::Value> shape,
                         llvm::ArrayRef<mlir::Value> typeParams) -> mlir::Value {
@@ -720,7 +723,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           // Do nothing
         });
 
-    return bindIfNewSymbol(sym, exv);
+    return bindIfNewSymbol(sym, exv, symMap);
   }
 
   void createHostAssociateVarCloneDealloc(
@@ -745,16 +748,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   void copyHostAssociateVar(
       const Fortran::semantics::Symbol &sym,
-      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) override final {
+      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
+      Fortran::lower::SymMap *symMap = nullptr) override final {
     // 1) Fetch the original copy of the variable.
     assert(sym.has<Fortran::semantics::HostAssocDetails>() &&
            "No host-association found");
     const Fortran::semantics::Symbol &hsym = sym.GetUltimate();
-    Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym);
+    Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym, symMap);
     assert(hsb && "Host symbol box not found");
 
     // 2) Fetch the copied one that will mask the original.
-    Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym);
+    Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym, symMap);
     assert(sb && "Host-associated symbol box not found");
     assert(hsb.getAddr() != sb.getAddr() &&
            "Host and associated symbol boxes are the same");
@@ -763,8 +767,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
     if (copyAssignIP && copyAssignIP->isSet())
       builder->restoreInsertionPoint(*copyAssignIP);
-    else
+    else {
       builder->setInsertionPointAfter(sb.getAddr().getDefiningOp());
+    }
 
     Fortran::lower::SymbolBox *lhs_sb, *rhs_sb;
     if (copyAssignIP && copyAssignIP->isSet() &&
@@ -1060,8 +1065,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Find the symbol in the inner-most level of the local map or return null.
   Fortran::lower::SymbolBox
-  shallowLookupSymbol(const Fortran::semantics::Symbol &sym) {
-    if (Fortran::lower::SymbolBox v = localSymbols.shallowLookupSymbol(sym))
+  shallowLookupSymbol(const Fortran::semantics::Symbol &sym,
+                      Fortran::lower::SymMap *symMap = nullptr) {
+    auto &map = (symMap == nullptr ? localSymbols : *symMap);
+    if (Fortran::lower::SymbolBox v = map.shallowLookupSymbol(sym))
       return v;
     return {};
   }
@@ -1069,8 +1076,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
-    if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
+                         Fortran::lower::SymMap *symMap = nullptr) override {
+    auto &map = (symMap == nullptr ? localSymbols : *symMap);
+    if (Fortran::lower::SymbolBox v = map.lookupOneLevelUpSymbol(sym))
       return v;
     return {};
   }
@@ -1079,15 +1088,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
   bool addSymbol(const Fortran::semantics::SymbolRef sym,
-                 fir::ExtendedValue val, bool forced = false) {
-    if (!forced && lookupSymbol(sym))
+                 fir::ExtendedValue val, bool forced = false,
+                 Fortran::lower::SymMap *symMap = nullptr) {
+    auto &map = (symMap == nullptr ? localSymbols : *symMap);
+    if (!forced && lookupSymbol(sym, &map))
       return false;
     if (lowerToHighLevelFIR()) {
-      Fortran::lower::genDeclareSymbol(*this, localSymbols, sym, val,
-                                       fir::FortranVariableFlagsEnum::None,
-                                       forced);
+      Fortran::lower::genDeclareSymbol(
+          *this, map, sym, val, fir::FortranVariableFlagsEnum::None, forced);
     } else {
-      localSymbols.addSymbol(sym, val, forced);
+      map.addSymbol(sym, val, forced);
     }
     return true;
   }
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index be2117efbabc0..93a696020d3ea 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -169,12 +169,15 @@ class DataSharingProcessor {
   void collectSymbolsForPrivatization();
   void insertBarrier();
   void collectDefaultSymbols();
-  void privatize();
+  void
+  privatize(llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers);
   void defaultPrivatize();
   void copyLastPrivatize(mlir::Operation *op);
   void insertLastPrivateCompare(mlir::Operation *op);
-  void cloneSymbol(const Fortran::semantics::Symbol *sym);
-  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
+  void cloneSymbol(const Fortran::semantics::Symbol *sym,
+                   Fortran::lower::SymMap *symMap = nullptr);
+  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym,
+                              Fortran::lower::SymMap *symMap);
   void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
                              mlir::OpBuilder::InsertPoint *lastPrivIP);
   void insertDeallocs();
@@ -197,7 +200,8 @@ class DataSharingProcessor {
   // Step2 performs the copying for lastprivates and requires knowledge of the
   // MLIR operation to insert the last private update. Step2 adds
   // dealocation code as well.
-  void processStep1();
+  void processStep1(llvm::SetVector<mlir::omp::PrivateClauseOp>
+                        *privateInitializers = nullptr);
   void processStep2(mlir::Operation *op, bool isLoop);
 
   void setLoopIV(mlir::Value iv) {
@@ -206,10 +210,11 @@ class DataSharingProcessor {
   }
 };
 
-void DataSharingProcessor::processStep1() {
+void DataSharingProcessor::processStep1(
+    llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers) {
   collectSymbolsForPrivatization();
   collectDefaultSymbols();
-  privatize();
+  privatize(privateInitializers);
   defaultPrivatize();
   insertBarrier();
 }
@@ -239,20 +244,23 @@ void DataSharingProcessor::insertDeallocs() {
     }
 }
 
-void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
+void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym,
+                                       Fortran::lower::SymMap *symMap) {
   // Privatization for symbols which are pre-determined (like loop index
   // variables) happen separately, for everything else privatize here.
   if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
     return;
-  bool success = converter.createHostAssociateVarClone(*sym);
+  bool success = converter.createHostAssociateVarClone(*sym, symMap);
   (void)success;
   assert(success && "Privatization failed due to existing binding");
 }
 
 void DataSharingProcessor::copyFirstPrivateSymbol(
-    const Fortran::semantics::Symbol *sym) {
-  if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
-    converter.copyHostAssociateVar(*sym);
+    const Fortran::semantics::Symbol *sym,
+    Fortran::lower::SymMap *symMap = nullptr) {
+  if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) {
+    converter.copyHostAssociateVar(*sym, nullptr, symMap);
+  }
 }
 
 void DataSharingProcessor::copyLastPrivateSymbol(
@@ -487,8 +495,11 @@ void DataSharingProcessor::collectDefaultSymbols() {
   }
 }
 
-void DataSharingProcessor::privatize() {
+void DataSharingProcessor::privatize(
+    llvm::SetVector<mlir::omp::PrivateClauseOp> *privateInitializers) {
+
   for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
+
     if (const auto *commonDet =
             sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
       for (const auto &mem : commonDet->objects()) {
@@ -496,6 +507,42 @@ void DataSharingProcessor::privatize() {
         copyFirstPrivateSymbol(&*mem);
       }
     } else {
+      if (privateInitializers != nullptr) {
+        auto ip = firOpBuilder.saveInsertionPoint();
+
+        auto moduleOp = firOpBuilder.getInsertionBlock()
+                            ->getParentOp()
+                            ->getParentOfType<mlir::ModuleOp>();
+
+        firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
+                                       moduleOp.getBodyRegion().front().end());
+
+        Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
+        assert(hsb && "Host symbol box not found");
+
+        auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
+            hsb.getAddr().getLoc(), hsb.getAddr().getType(),
+            sym->name().ToString());
+        firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
+
+        Fortran::semantics::Symbol cp = *sym;
+        Fortran::lower::SymMap privatizerSymbolMap;
+        privatizerSymbolMap.addSymbol(cp, privatizerOp.getArgument(0));
+        privatizerSymbolMap.pushScope();
+
+        cloneSymbol(&cp, &privatizerSymbolMap);
+        copyFirstPrivateSymbol(&cp, &privatizerSymbolMap);
+
+        firOpBuilder.create<mlir::omp::YieldOp>(
+            hsb.getAddr().getLoc(),
+            privatizerSymbolMap.shallowLookupSymbol(cp).getAddr());
+
+        firOpBuilder.restoreInsertionPoint(ip);
+      }
+
+      // TODO: This will eventually be an else to the `if` above it. For now, I
+      // emit both the outlined privatizer AND directly emitted cloning and
+      // copying ops while I am testing.
       cloneSymbol(sym);
       copyFirstPrivateSymbol(sym);
     }
@@ -2272,6 +2319,7 @@ static void createBodyOfOp(
     llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
     llvm::SmallVector<mlir::Location> locs(args.size(), loc);
     firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
+
     // The argument is not currently in memory, so make a temporary for the
     // argument, and store it there, then bind that location to the argument.
     mlir::Operation *storeOp = nullptr;
@@ -2291,10 +2339,11 @@ static void createBodyOfOp(
 
   // If it is an unstructured region and is not the outer region of a combined
   // construct, create empty blocks for all evaluations.
-  if (eval.lowerAsUnstructured() && !outerCombined)
+  if (eval.lowerAsUnstructured() && !outerCombined) {
     Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
                                             mlir::omp::YieldOp>(
         firOpBuilder, eval.getNestedEvaluations());
+  }
 
   // Start with privatization, so that the lowering of the nested
   // code will use the right symbols.
@@ -2307,12 +2356,14 @@ static void createBodyOfOp(
   if (privatize) {
     if (!dsp) {
       tempDsp.emplace(converter, *clauses, eval);
-      tempDsp->processStep1();
+      llvm::SetVector<mlir::omp::PrivateClauseOp> privateInitializers;
+      tempDsp->processStep1(&privateInitializers);
     }
   }
 
   if constexpr (std::is_same_v<Op, mlir::omp::ParallelOp>) {
     threadPrivatizeVars(converter, eval);
+
     if (clauses) {
       firOpBuilder.setInsertionPoint(marker);
       ClauseProcessor(converter, *clauses).processCopyin();
@@ -2361,6 +2412,7 @@ static void createBodyOfOp(
     if (exits.size() == 1)
       return exits[0];
     mlir::Block *exit = firOpBuilder.createBlock(&region);
+
     for (mlir::Block *b : exits) {
       firOpBuilder.setInsertionPointToEnd(b);
       firOpBuilder.create<mlir::cf::BranchOp>(loc, exit);
@@ -2382,8 +2434,9 @@ static void createBodyOfOp(
         assert(tempDsp.has_value());
         tempDsp->processStep2(op, isLoop);
       } else {
-        if (isLoop && args.size() > 0)
+        if (isLoop && args.size() > 0) {
           dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
+        }
         dsp->processStep2(op, isLoop);
       }
     }
diff --git a/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90 b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
new file mode 100644
index 0000000000000..7b4d9135c6e07
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/delayed_privatization.f90
@@ -0,0 +1,42 @@
+subroutine private_clause_allocatable()
+        integer :: xxx
+        integer :: yyy
+
+!$OMP PARALLEL FIRSTPRIVATE(xxx, yyy)
+!$OMP END PARALLEL
+
+end subroutine
+
+! This is what flang emits with the PoC:
+! --------------------------------------
+!
+!func.func @_QPprivate_clause_allocatable() {
+!  %0 = fir.alloca i32 {bindc_name = "xxx", uniq_name = "_QFprivate_clause_allocatableExxx"}
+!  %1 = fir.alloca i32 {bindc_name = "yyy", uniq_name = "_QFprivate_clause_allocatableEyyy"}
+!  omp.parallel {
+!    %2 = fir.alloca i32 {bindc_name = "xxx", pinned, uniq_name = "_QFprivate_clause_allocatableExxx"}
+!    %3 = fir.load %0 : !fir.ref<i32>
+!    fir.store %3 to %2 : !fir.ref<i32>
+!    %4 = fir.alloca i32 {bindc_name = "yyy", pinned, uniq_name = "_QFprivate_clause_allocatableEyyy"}
+!    %5 = fir.load %1 : !fir.ref<i32>
+!    fir.store %5 to %4 : !fir.ref<i32>
+!    omp.terminator
+!  }
+!  return
+!}
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "xxx.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "xxx", pinned, uniq_name = "_QFprivate_clause_allocatableExxx"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
+!
+!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "yyy.privatizer"}> ({
+!^bb0(%arg0: !fir.ref<i32>):
+!  %0 = fir.alloca i32 {bindc_name = "yyy", pinned, uniq_name = "_QFprivate_clause_allocatableEyyy"}
+!  %1 = fir.load %arg0 : !fir.ref<i32>
+!  fir.store %1 to %0 : !fir.ref<i32>
+!  omp.yield(%0 : !fir.ref<i32>)
+!}) : () -> ()
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 96c15e775a302..4892b4c24afc5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -16,6 +16,7 @@
 
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -621,7 +622,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
 def YieldOp : OpenMP_Op<"yield",
     [Pure, ReturnLike, Terminator,
      ParentOneOf<["WsLoopOp", "ReductionDeclareOp",
-     "AtomicUpdateOp", "SimdLoopOp"]>]> {
+     "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -1478,6 +1479,38 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
+def PrivateClauseOp : OpenMP_Op<"private", [
+    IsolatedFromAbove, FunctionOpInterface
+  ]> {
+  let summary = "TODO";
+  let description = [{}];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type);
+
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "::mlir::Type":$privateVar,
+    "::llvm::StringRef":$privateVarName
+  )>];
+
+  let extraClassDeclaration = [{
+    ::mlir::Region *getCallableRegion() {
+      return &getBody();
+    }
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+  }];
+}
 
 def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
   let summary = "target construct";
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index d18ec279e85c0..5cfa8c790bbee 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -138,9 +138,7 @@ LogicalResult ConstantOp::verify() {
   return success();
 }
 
-OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
-  return getValueAttr();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
 void ConstantOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13cc16125a273..d8a4a99886835 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1594,6 +1594,19 @@ LogicalResult DataBoundsOp::verify() {
   return success();
 }
 
+void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                            Type privateVarType, StringRef privateVarName) {
+  FunctionType initializerType = FunctionType::get(
+      odsBuilder.getContext(), {privateVarType}, {privateVarType});
+  std::string privatizerName = (privateVarName + ".privatizer").str();
+
+  build(odsBuilder, odsState, privatizerName, initializerType);
+
+  mlir::Block &block = odsState.regions.front()->emplaceBlock();
+  block.addArguments({privateVarType},
+                     SmallVector<Location>(1, odsState.location));
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 

>From aa72c74875fa09b004dcfed2f2d27f683bd41a19 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 30 Jan 2024 00:09:22 -0600
Subject: [PATCH 2/2] Modify the local symbol map in the converter.

---
 flang/include/flang/Lower/AbstractConverter.h | 18 +++----
 flang/lib/Lower/Bridge.cpp                    | 53 ++++++++-----------
 flang/lib/Lower/OpenMP.cpp                    | 32 ++++++-----
 3 files changed, 46 insertions(+), 57 deletions(-)

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index b3ca804256ee0..df200c72c4ae6 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -93,8 +93,7 @@ class AbstractConverter {
 
   /// Binds the symbol to an fir extended value. The symbol binding will be
   /// added or replaced at the inner-most level of the local symbol map.
-  virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval,
-                          Fortran::lower::SymMap *symMap = nullptr) = 0;
+  virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
 
   /// Override lowering of expression with pre-lowered values.
   /// Associate mlir::Value to evaluate::Expr. All subsequent call to
@@ -113,16 +112,14 @@ class AbstractConverter {
   /// For a given symbol which is host-associated, create a clone using
   /// parameters from the host-associated symbol.
   virtual bool
-  createHostAssociateVarClone(const Fortran::semantics::Symbol &sym,
-                              Fortran::lower::SymMap *symMap = nullptr) = 0;
+  createHostAssociateVarClone(const Fortran::semantics::Symbol &sym) = 0;
 
   virtual void
   createHostAssociateVarCloneDealloc(const Fortran::semantics::Symbol &sym) = 0;
 
-  virtual void
-  copyHostAssociateVar(const Fortran::semantics::Symbol &sym,
-                       mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
-                       Fortran::lower::SymMap *symMap = nullptr) = 0;
+  virtual void copyHostAssociateVar(
+      const Fortran::semantics::Symbol &sym,
+      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) = 0;
 
   /// For a given symbol, check if it is present in the inner-most
   /// level of the symbol map.
@@ -300,8 +297,9 @@ class AbstractConverter {
   }
 
   virtual Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
-                         Fortran::lower::SymMap *symMap = nullptr) = 0;
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+
+  virtual Fortran::lower::SymMap *getLocalSymbols() { return nullptr; }
 
 private:
   /// Options controlling lowering behavior.
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index eaa8eaeae84af..1e47f6d1a13da 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -498,18 +498,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Add the symbol binding to the inner-most level of the symbol map and
   /// return true if it is not already present. Otherwise, return false.
   bool bindIfNewSymbol(Fortran::lower::SymbolRef sym,
-                       const fir::ExtendedValue &exval,
-                       Fortran::lower::SymMap *symMap = nullptr) {
-    if (shallowLookupSymbol(sym, symMap))
+                       const fir::ExtendedValue &exval) {
+    if (shallowLookupSymbol(sym))
       return false;
-    bindSymbol(sym, exval, symMap);
+    bindSymbol(sym, exval);
     return true;
   }
 
   void bindSymbol(Fortran::lower::SymbolRef sym,
-                  const fir::ExtendedValue &exval,
-                  Fortran::lower::SymMap *symMap = nullptr) override final {
-    addSymbol(sym, exval, /*forced=*/true, symMap);
+                  const fir::ExtendedValue &exval) override final {
+    addSymbol(sym, exval, /*forced=*/true);
   }
 
   void
@@ -612,15 +610,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   }
 
   bool createHostAssociateVarClone(
-      const Fortran::semantics::Symbol &sym,
-      Fortran::lower::SymMap *symMap = nullptr) override final {
+      const Fortran::semantics::Symbol &sym) override final {
     mlir::Location loc = genLocation(sym.name());
     mlir::Type symType = genType(sym);
     const auto *details = sym.detailsIf<Fortran::semantics::HostAssocDetails>();
     assert(details && "No host-association found");
     const Fortran::semantics::Symbol &hsym = details->symbol();
     mlir::Type hSymType = genType(hsym);
-    Fortran::lower::SymbolBox hsb = lookupSymbol(hsym, symMap);
+    Fortran::lower::SymbolBox hsb = lookupSymbol(hsym);
 
     auto allocate = [&](llvm::ArrayRef<mlir::Value> shape,
                         llvm::ArrayRef<mlir::Value> typeParams) -> mlir::Value {
@@ -723,7 +720,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           // Do nothing
         });
 
-    return bindIfNewSymbol(sym, exv, symMap);
+    return bindIfNewSymbol(sym, exv);
   }
 
   void createHostAssociateVarCloneDealloc(
@@ -748,17 +745,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   void copyHostAssociateVar(
       const Fortran::semantics::Symbol &sym,
-      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr,
-      Fortran::lower::SymMap *symMap = nullptr) override final {
+      mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) override final {
     // 1) Fetch the original copy of the variable.
     assert(sym.has<Fortran::semantics::HostAssocDetails>() &&
            "No host-association found");
     const Fortran::semantics::Symbol &hsym = sym.GetUltimate();
-    Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym, symMap);
+    Fortran::lower::SymbolBox hsb = lookupOneLevelUpSymbol(hsym);
     assert(hsb && "Host symbol box not found");
 
     // 2) Fetch the copied one that will mask the original.
-    Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym, symMap);
+    Fortran::lower::SymbolBox sb = shallowLookupSymbol(sym);
     assert(sb && "Host-associated symbol box not found");
     assert(hsb.getAddr() != sb.getAddr() &&
            "Host and associated symbol boxes are the same");
@@ -998,6 +994,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return name;
   }
 
+  Fortran::lower::SymMap *getLocalSymbols() override { return &localSymbols; }
+
 private:
   FirConverter() = delete;
   FirConverter(const FirConverter &) = delete;
@@ -1065,10 +1063,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Find the symbol in the inner-most level of the local map or return null.
   Fortran::lower::SymbolBox
-  shallowLookupSymbol(const Fortran::semantics::Symbol &sym,
-                      Fortran::lower::SymMap *symMap = nullptr) {
-    auto &map = (symMap == nullptr ? localSymbols : *symMap);
-    if (Fortran::lower::SymbolBox v = map.shallowLookupSymbol(sym))
+  shallowLookupSymbol(const Fortran::semantics::Symbol &sym) {
+    if (Fortran::lower::SymbolBox v = localSymbols.shallowLookupSymbol(sym))
       return v;
     return {};
   }
@@ -1076,10 +1072,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym,
-                         Fortran::lower::SymMap *symMap = nullptr) override {
-    auto &map = (symMap == nullptr ? localSymbols : *symMap);
-    if (Fortran::lower::SymbolBox v = map.lookupOneLevelUpSymbol(sym))
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
+    if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
       return v;
     return {};
   }
@@ -1088,16 +1082,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
   bool addSymbol(const Fortran::semantics::SymbolRef sym,
-                 fir::ExtendedValue val, bool forced = false,
-                 Fortran::lower::SymMap *symMap = nullptr) {
-    auto &map = (symMap == nullptr ? localSymbols : *symMap);
-    if (!forced && lookupSymbol(sym, &map))
+                 fir::ExtendedValue val, bool forced = false) {
+    if (!forced && lookupSymbol(sym))
       return false;
     if (lowerToHighLevelFIR()) {
-      Fortran::lower::genDeclareSymbol(
-          *this, map, sym, val, fir::FortranVariableFlagsEnum::None, forced);
+      Fortran::lower::genDeclareSymbol(*this, localSymbols, sym, val,
+                                       fir::FortranVariableFlagsEnum::None,
+                                       forced);
     } else {
-      map.addSymbol(sym, val, forced);
+      localSymbols.addSymbol(sym, val, forced);
     }
     return true;
   }
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 93a696020d3ea..953b5a6ffc9a8 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -174,10 +174,8 @@ class DataSharingProcessor {
   void defaultPrivatize();
   void copyLastPrivatize(mlir::Operation *op);
   void insertLastPrivateCompare(mlir::Operation *op);
-  void cloneSymbol(const Fortran::semantics::Symbol *sym,
-                   Fortran::lower::SymMap *symMap = nullptr);
-  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym,
-                              Fortran::lower::SymMap *symMap);
+  void cloneSymbol(const Fortran::semantics::Symbol *sym);
+  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
   void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
                              mlir::OpBuilder::InsertPoint *lastPrivIP);
   void insertDeallocs();
@@ -244,22 +242,20 @@ void DataSharingProcessor::insertDeallocs() {
     }
 }
 
-void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym,
-                                       Fortran::lower::SymMap *symMap) {
+void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
   // Privatization for symbols which are pre-determined (like loop index
   // variables) happen separately, for everything else privatize here.
   if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
     return;
-  bool success = converter.createHostAssociateVarClone(*sym, symMap);
+  bool success = converter.createHostAssociateVarClone(*sym);
   (void)success;
   assert(success && "Privatization failed due to existing binding");
 }
 
 void DataSharingProcessor::copyFirstPrivateSymbol(
-    const Fortran::semantics::Symbol *sym,
-    Fortran::lower::SymMap *symMap = nullptr) {
+    const Fortran::semantics::Symbol *sym) {
   if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) {
-    converter.copyHostAssociateVar(*sym, nullptr, symMap);
+    converter.copyHostAssociateVar(*sym, nullptr);
   }
 }
 
@@ -525,18 +521,20 @@ void DataSharingProcessor::privatize(
             sym->name().ToString());
         firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
 
-        Fortran::semantics::Symbol cp = *sym;
-        Fortran::lower::SymMap privatizerSymbolMap;
-        privatizerSymbolMap.addSymbol(cp, privatizerOp.getArgument(0));
-        privatizerSymbolMap.pushScope();
+        converter.getLocalSymbols()->pushScope();
+        converter.getLocalSymbols()->addSymbol(*sym,
+                                               privatizerOp.getArgument(0));
+        converter.getLocalSymbols()->pushScope();
 
-        cloneSymbol(&cp, &privatizerSymbolMap);
-        copyFirstPrivateSymbol(&cp, &privatizerSymbolMap);
+        cloneSymbol(sym);
+        copyFirstPrivateSymbol(sym);
 
         firOpBuilder.create<mlir::omp::YieldOp>(
             hsb.getAddr().getLoc(),
-            privatizerSymbolMap.shallowLookupSymbol(cp).getAddr());
+            converter.getLocalSymbols()->shallowLookupSymbol(*sym).getAddr());
 
+        converter.getLocalSymbols()->popScope();
+        converter.getLocalSymbols()->popScope();
         firOpBuilder.restoreInsertionPoint(ip);
       }
 



More information about the flang-commits mailing list