[flang-commits] [flang] [mlir] WIP : Delayed privatisation of index variables (PR #75836)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Mon Dec 18 09:56:20 PST 2023


https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/75836

Not for merge.

This patch is a WIP to delay privatisation of index variables. 

>From f2089f8e44963dae9e8af47612f13e02b1eb168c Mon Sep 17 00:00:00 2001
From: Raghu Maddhipatla <Raghu.Maddhipatla at amd.com>
Date: Wed, 25 Oct 2023 18:19:08 -0500
Subject: [PATCH 1/2] [Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add
 lowering support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET
 directive.

---
 flang/lib/Lower/OpenMP.cpp                    | 54 +++++++++++++++++--
 flang/test/Lower/OpenMP/FIR/target.f90        | 41 +++++++++++++-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 18 +++++--
 mlir/test/Dialect/OpenMP/ops.mlir             |  8 +--
 4 files changed, 110 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..f0d6a3f382e5ff 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -604,6 +604,18 @@ class ClauseProcessor {
                       llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                           &useDeviceSymbols) const;
+  bool
+  processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                     llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                     llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                         &isDeviceSymbols) const;
+  bool
+  processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+                       llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+                       llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                           &isDeviceSymbols) const;
 
   // Call this method for these clauses that should be supported but are not
   // implemented yet. It triggers a compilation error if any of the given
@@ -1890,6 +1902,34 @@ bool ClauseProcessor::processUseDevicePtr(
       });
 }
 
+bool ClauseProcessor::processIsDevicePtr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::IsDevicePtr>(
+      [&](const ClauseTy::IsDevicePtr *devPtrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
+bool ClauseProcessor::processHasDeviceAddr(
+    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
+    llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
+    const {
+  return findRepeatableClause<ClauseTy::HasDeviceAddr>(
+      [&](const ClauseTy::HasDeviceAddr *devAddrClause,
+          const Fortran::parser::CharBlock &) {
+        addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
+                           isDeviceLocs, isDeviceSymbols);
+      });
+}
+
 template <typename... Ts>
 void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                   llvm::omp::Directive directive) const {
@@ -2617,6 +2657,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Type> mapSymTypes;
   llvm::SmallVector<mlir::Location> mapSymLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+  llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
+  llvm::SmallVector<mlir::Type> useDeviceTypes;
+  llvm::SmallVector<mlir::Location> useDeviceLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
@@ -2626,11 +2670,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
                 mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
+  cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
+                        useDeviceSymbols);
+  cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
+                          useDeviceSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Private,
                  Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
-                 Fortran::parser::OmpClause::IsDevicePtr,
-                 Fortran::parser::OmpClause::HasDeviceAddr,
                  Fortran::parser::OmpClause::Reduction,
                  Fortran::parser::OmpClause::InReduction,
                  Fortran::parser::OmpClause::Allocate,
@@ -2705,7 +2751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nowaitAttr, mapOperands);
+      nowaitAttr, devicePtrOperands, deviceAddrOperands, mapOperands);
 
   genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
                     mapSymbols, currentLocation);
@@ -3101,6 +3147,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
         !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
       TODO(clauseLocation, "OpenMP Block construct clause");
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 2034ac84334e54..8f38261fc1aa62 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -411,4 +411,43 @@ subroutine omp_target_parallel_do
    !CHECK: omp.terminator
    !CHECK: }
    !$omp end target parallel do
- end subroutine omp_target_parallel_do
+end subroutine omp_target_parallel_do
+
+!===============================================================================
+! Target `is_device_ptr` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
+subroutine omp_target_is_device_ptr
+   use iso_c_binding, only : c_ptr, c_loc
+   !CHECK: %[[DEV_PTR:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
+   type(c_ptr) :: a
+   !CHECK %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
+   integer, target :: b
+   !CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
+   !CHECK: %[[MAP_1:.*]] = omp.map_info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32)   map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
+   !CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]], %[[MAP_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+   !$omp target map(tofrom: a,b) is_device_ptr(a)
+      !CHECK: {{.*}} = fir.coordinate_of %[[DEV_PTR:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
+      a = c_loc(b)
+   !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+end subroutine omp_target_is_device_ptr
+
+ !===============================================================================
+ ! Target `has_device_addr` clause
+ !===============================================================================
+
+ !CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
+ subroutine omp_target_has_device_addr
+   integer, pointer :: a
+   !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
+   !CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+   !$omp target has_device_addr(a)
+   !CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+      a = 10
+   !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+end subroutine omp_target_has_device_addr
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..92ab32b0131fa6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1389,10 +1389,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
 
     The optional $thread_limit specifies the limit on the number of threads
 
-    The optional $nowait elliminates the implicit barrier so the parent task can make progress
+    The optional $nowait eliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
 
-    TODO:  is_device_ptr, depend, defaultmap, in_reduction
+    The optional $is_device_ptr indicates that list items are device pointers.
+
+    The optional $has_device_addr indicates that list items already have device
+    addresses, so they may be directly accessed from the target device. May
+    include array sections.
+
+    The optional $map_operands maps data from the task’s environment to the
+    device environment.
+
+    TODO:  depend, defaultmap, in_reduction
 
   }];
 
@@ -1400,8 +1409,9 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
                        Optional<AnyInteger>:$device,
                        Optional<AnyInteger>:$thread_limit,
                        UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
+                       Variadic<OpenMP_PointerLikeType>:$has_device_addr,
                        Variadic<AnyType>:$map_operands);
-
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
@@ -1409,6 +1419,8 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
     | `device` `(` $device `:` type($device) `)`
     | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
     | `nowait` $nowait
+    | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
+    | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16..b153b1b8221d80 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -480,22 +480,22 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
 }
 
 // CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
 
     // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
     // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     // CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
+    // CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
     %mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     %mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
+    omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
     ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
       omp.terminator
     }

>From f6c8748647fd4160f43dc7eeb995f59fc235b89e Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Mon, 11 Dec 2023 22:51:12 +0000
Subject: [PATCH 2/2] [Flang][MLIR][OpenMP] WIP: Privatisation for index
 variables

---
 flang/lib/Lower/OpenMP.cpp                    | 118 +++++++++++++++++-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   2 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  73 ++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  20 +++
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |   2 +-
 mlir/test/Dialect/OpenMP/ops.mlir             |  10 +-
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |   6 +-
 7 files changed, 217 insertions(+), 14 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index f0d6a3f382e5ff..428b4b171bbabe 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2115,6 +2115,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
 
   mlir::Type tempTy = converter.genType(*sym);
+  llvm::outs() << "Temp type = " << tempTy << "\n";
   mlir::Value temp = firOpBuilder.create<fir::AllocaOp>(
       loc, tempTy, /*pinned=*/true, /*lengthParams=*/mlir::ValueRange{},
       /*shapeParams*/ mlir::ValueRange{},
@@ -2128,6 +2129,103 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   return storeOp;
 }
 
+/// Create the body (block) for an OpenMP Loop Operation.
+///
+/// \param [in]    op - the operation the body belongs to.
+/// \param [inout] converter - converter to use for the clauses.
+/// \param [in]    loc - location in source code.
+/// \param [in]    eval - current PFT node/evaluation.
+/// \param [in]    clauses - list of clauses to process.
+/// \param [in]    args - block arguments (induction variable[s]) for the
+///                       region.
+/// \param [in]    outerCombined - is this an outer operation - prevents
+///                                privatization.
+template <typename Op>
+static void createBodyOfLoopOp(
+    Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpClauseList *clauses = nullptr,
+    const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
+    bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  // If an argument for the region is provided then create the block with that
+  // argument. Also update the symbol's address with the mlir argument value.
+  // e.g. For loops the argument is the induction variable. And all further
+  // uses of the induction variable should use this mlir value.
+  mlir::Operation *storeOp = nullptr;
+  assert(args.size() > 0);
+  std::size_t loopVarTypeSize = 0;
+  for (const Fortran::semantics::Symbol *arg : args)
+    loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
+  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+  llvm::SmallVector<mlir::Type> tiv;
+  llvm::SmallVector<mlir::Location> locs;
+  for (int i = 0; i < (int)args.size(); i++) {
+    tiv.push_back(loopVarType);
+    locs.push_back(loc);
+  }
+  int offset = 0;
+  // The argument is not currently in memory, so make a temporary for the
+  // argument, and store it there, then bind that location to the argument.
+  for (const Fortran::semantics::Symbol *arg : args) {
+    mlir::Type symType = converter.genType(*arg);
+    mlir::Type symRefType = firOpBuilder.getRefType(symType);
+    tiv.push_back(symRefType);
+    locs.push_back(loc);
+    offset++;
+  }
+  firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
+
+  int argIndex = 0;
+  for (const Fortran::semantics::Symbol *arg : args) {
+    mlir::Value addrVal =
+          fir::getBase(op.getRegion().front().getArgument(argIndex+offset));
+    converter.bindSymbol(*arg, addrVal);
+    mlir::Type symType = converter.genType(*arg);
+    mlir::Value indexVal =
+          fir::getBase(op.getRegion().front().getArgument(argIndex));
+    mlir::Value cvtVal = firOpBuilder.createConvert(loc, symType, indexVal);
+    addrVal = converter.getSymbolAddress(*arg);
+    storeOp = firOpBuilder.create<fir::StoreOp>(loc, cvtVal, addrVal);
+    argIndex++;
+  }
+  // Set the insert for the terminator operation to go at the end of the
+  // block - this is either empty or the block with the stores above,
+  // the end of the block works for both.
+  mlir::Block &block = op.getRegion().back();
+  firOpBuilder.setInsertionPointToEnd(&block);
+
+  // If it is an unstructured region and is not the outer region of a combined
+  // construct, create empty blocks for all evaluations.
+  if (eval.lowerAsUnstructured() && !outerCombined)
+    Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
+                                            mlir::omp::YieldOp>(
+        firOpBuilder, eval.getNestedEvaluations());
+
+  // Insert the terminator.
+  Fortran::lower::genOpenMPTerminator(firOpBuilder, op.getOperation(), loc);
+  // Reset the insert point to before the terminator.
+  resetBeforeTerminator(firOpBuilder, storeOp, block);
+
+  // Handle privatization. Do not privatize if this is the outer operation.
+  if (clauses && !outerCombined) {
+    constexpr bool isLoop = std::is_same_v<Op, mlir::omp::WsLoopOp> ||
+                            std::is_same_v<Op, mlir::omp::SimdLoopOp>;
+    if (!dsp) {
+      DataSharingProcessor proc(converter, *clauses, eval);
+      proc.processStep1();
+      proc.processStep2(op, isLoop);
+    } else {
+      if (isLoop && args.size() > 0)
+        dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
+      dsp->processStep2(op, isLoop);
+    }
+
+    if (storeOp)
+      firOpBuilder.setInsertionPointAfter(storeOp);
+  }
+}
+
 /// Create the body (block) for an OpenMP Operation.
 ///
 /// \param [in]    op - the operation the body belongs to.
@@ -2960,7 +3058,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
-      linearStepVars, reductionVars;
+      linearStepVars, privateVars, reductionVars;
   mlir::Value scheduleChunkClauseOperand;
   mlir::IntegerAttr orderedClauseOperand;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
@@ -3069,9 +3167,23 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     return;
   }
 
+  // Add the do loop's index variable (from its bounds) to the private list.
+  Fortran::lower::pft::Evaluation *doConstructEval =
+      &eval.getFirstNestedEvaluation();
+  Fortran::lower::pft::Evaluation *doLoop =
+      &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
+  const Fortran::parser::LoopControl::Bounds *bounds =
+      std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+  assert(bounds && "Expected bounds for worksharing do loop");
+  privateVars.push_back(converter.getSymbolAddress(*bounds->name.thing.symbol));
+
   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
       currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
-      reductionVars,
+      privateVars, reductionVars,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(firOpBuilder.getContext(),
@@ -3107,7 +3219,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
       wsLoopOp.setNowaitAttr(nowaitClauseOperand);
   }
 
-  createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
+  createBodyOfLoopOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
                                       eval, &loopOpClauseList, iv,
                                       /*outer=*/false, &dsp);
 }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 92ab32b0131fa6..b19efcaa867676 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -479,6 +479,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
              Variadic<IntLikeType>:$step,
              Variadic<AnyType>:$linear_vars,
              Variadic<I32>:$linear_step_vars,
+             Variadic<OpenMP_PointerLikeType>:$privates,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
              OptionalAttr<SymbolRefArrayAttr>:$reductions,
              OptionalAttr<ScheduleKindAttr>:$schedule_val,
@@ -517,6 +518,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
           |`nowait` $nowait
           |`ordered` `(` $ordered_val `)`
           |`order` `(` custom<ClauseAttr>($order_val) `)`
+          |`private` `(` custom<PrivateEntries>($privates, type($privates)) `)`
           |`reduction` `(`
               custom<ReductionVarList>(
                 $reduction_vars, type($reduction_vars), $reductions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..3da37231e3c52c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -178,6 +178,67 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
   p << stringifyEnum(attr.getValue());
 }
 
+static ParseResult
+parsePrivateEntries(OpAsmParser &parser,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+                SmallVectorImpl<Type> &privateOperandTypes) {
+  OpAsmParser::UnresolvedOperand arg;
+  OpAsmParser::UnresolvedOperand blockArg;
+  Type argType;
+  auto parseEntries = [&]() -> ParseResult {
+    if (parser.parseOperand(arg) || parser.parseArrow() ||
+        parser.parseOperand(blockArg))
+      return failure();
+    privateOperands.push_back(arg);
+    return success();
+  };
+
+  auto parseTypes = [&]() -> ParseResult {
+    if (parser.parseType(argType))
+      return failure();
+    privateOperandTypes.push_back(argType);
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parseEntries))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseCommaSeparatedList(parseTypes))
+    return failure();
+
+  return success();
+}
+
+static void printPrivateEntries(OpAsmPrinter &p, Operation *op,
+                            OperandRange privateOperands,
+                            TypeRange privateOperandTypes) {
+  auto &region = op->getRegion(0);
+
+  unsigned argIndex = 0;
+  unsigned offset = 0;
+  if (auto wsLoop = dyn_cast<WsLoopOp>(op))
+    offset = wsLoop.getNumLoops();
+  for (const auto &privOperand : privateOperands) {
+    const auto &blockArg = region.front().getArgument(argIndex+offset);
+    p << privOperand << " -> " << blockArg;
+    argIndex++;
+    if (argIndex < privateOperands.size())
+      p << ", ";
+  }
+  p << " : ";
+
+  argIndex = 0;
+  for (const auto &privOperandType : privateOperandTypes) {
+    p << privOperandType;
+    argIndex++;
+    if (argIndex < privateOperands.size())
+      p << ", ";
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -1086,7 +1147,14 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
                       ValueRange steps, TypeRange loopVarTypes,
                       UnitAttr inclusive) {
   auto args = region.front().getArguments();
-  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
+  p << " (";
+  unsigned numLoops = steps.size();
+  for (unsigned i=0; i<numLoops; i++) {
+    if (i != 0)
+      p << ", ";
+    p << args[i];
+  }
+  p << ") : " << args[0].getType() << " = (" << lowerBound
     << ") to (" << upperBound << ") ";
   if (inclusive)
     p << "inclusive ";
@@ -1269,7 +1337,8 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
   build(builder, state, lowerBound, upperBound, step,
         /*linear_vars=*/ValueRange(),
-        /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
+        /*linear_step_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
+	/*reduction_vars=*/ValueRange(),
         /*reductions=*/nullptr, /*schedule_val=*/nullptr,
         /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
         /*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..26ec12e427ff35 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -833,6 +833,24 @@ static void collectReductionInfo(
   }
 }
 
+/// Allocate space for the privatized variables of a workshare loop.
+void
+allocPrivatizationVars(omp::WsLoopOp loop, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation,
+                   llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
+  unsigned offset = loop.getNumLoops();
+  unsigned numArgs = loop.getRegion().front().getNumArguments();
+  llvm::IRBuilderBase::InsertPointGuard guard(builder);
+  builder.restoreIP(allocaIP);
+  for (unsigned i = offset; i < numArgs; ++i) {
+    if (auto op = loop.getPrivates()[i-offset].getDefiningOp<LLVM::AllocaOp>()) {
+      llvm::Value *var = builder.CreateAlloca(moduleTranslation.convertType(op.getResultPtrElementType()));
+      //    moduleTranslation.convertType(loop.getPrivates()[i-offset].getType()));
+      moduleTranslation.mapValue(loop.getRegion().front().getArgument(i), var);
+    }
+  }
+}
+
 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -861,6 +879,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
+  allocPrivatizationVars(loop, builder, moduleTranslation, allocaIP);
+
   SmallVector<llvm::Value *> privateReductionVariables;
   DenseMap<Value, llvm::Value *> reductionVariableMap;
   allocReductionVars(loop, builder, moduleTranslation, allocaIP, reductionDecls,
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3fbeaebb592a4d..86576831f67dff 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -79,7 +79,7 @@ func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4:
       // CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
       "test.payload"(%arg6, %arg7) : (index, index) -> ()
       omp.yield
-    }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
+    }) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
     omp.terminator
   }
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b153b1b8221d80..134cd4227af485 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -141,7 +141,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
   "omp.wsloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, ordered_val = 1} :
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, ordered_val = 1} :
     (index, index, index) -> ()
 
   // CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(static)
@@ -149,7 +149,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
   "omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0>, schedule_val = #omp<schedulekind static>} :
+  }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,0>, schedule_val = #omp<schedulekind static>} :
     (index, index, index, memref<i32>, i32) -> ()
 
   // CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>, %{{.*}} = %{{.*}} : memref<i32>) schedule(static)
@@ -157,7 +157,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
   "omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %linear_var, %linear_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0>, schedule_val = #omp<schedulekind static>} :
+  }) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0,0>, schedule_val = #omp<schedulekind static>} :
     (index, index, index, memref<i32>, memref<i32>, i32, i32) -> ()
 
   // CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(dynamic = %{{.*}}) ordered(2)
@@ -165,7 +165,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
   "omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var, %chunk_var) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
+  }) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
     (index, index, index, memref<i32>, i32, i32) -> ()
 
   // CHECK: omp.wsloop schedule(auto) nowait
@@ -173,7 +173,7 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
   "omp.wsloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
     (index, index, index) -> ()
 
   return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 1c02c0265462c2..e0caaf06c03a10 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -310,7 +310,7 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
       llvm.store %3, %4 : f32, !llvm.ptr
       omp.yield
       // CHECK: call void @__kmpc_for_static_fini(ptr @[[$loc_struct]],
-    }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+    }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
     omp.terminator
   }
   llvm.return
@@ -330,7 +330,7 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
     %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     llvm.store %3, %4 : f32, !llvm.ptr
     omp.yield
-  }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
   llvm.return
 }
 
@@ -348,7 +348,7 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr) {
     %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     llvm.store %3, %4 : f32, !llvm.ptr
     omp.yield
-  }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
   llvm.return
 }
 



More information about the flang-commits mailing list