[flang-commits] [flang] [mlir] [flang][OpenMP][MLIR] Basic support for delayed privatization code-gen (PR #81833)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Thu Feb 15 22:16:53 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/81833

>From a8a577013ace78e38d0443e532213840cbcab0ff Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Sun, 11 Feb 2024 09:37:59 -0600
Subject: [PATCH 1/2] [MLIR][OpenMP] Add `private` clause to `omp.parallel`

Extends the `omp.parallel` op by adding a `private` clause to model
[first]private variables. This uses the `omp.private` op to map
privatized variables to their corresponding privatizers.

Example `omp.parallel` op with a `private` variable:
```
omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) {
    // ... use %arg1 ...
    omp.terminator
}
```

Whether the variable is private or firstprivate is determined by the
attributes of the corresponding `omp.private` op.
---
 flang/lib/Lower/OpenMP.cpp                    |   3 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   8 +-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 160 ++++++++++++++----
 mlir/test/Dialect/OpenMP/invalid.mlir         |  56 ++++++
 mlir/test/Dialect/OpenMP/ops-2.mlir           |  74 ++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             |  10 +-
 mlir/test/Dialect/OpenMP/roundtrip.mlir       |  21 ---
 8 files changed, 269 insertions(+), 67 deletions(-)
 create mode 100644 mlir/test/Dialect/OpenMP/ops-2.mlir
 delete mode 100644 mlir/test/Dialect/OpenMP/roundtrip.mlir

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 7519da844eebba..9397af8b8bd05e 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
-      procBindKindAttr);
+      procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
+      /*privatizers=*/nullptr);
 }
 
 static mlir::omp::SectionOp
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 024f43f7e7e3b6..f907e21e9b4de2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
              Variadic<AnyType>:$allocators_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
              OptionalAttr<SymbolRefArrayAttr>:$reductions,
-             OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
+             OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
+             Variadic<AnyType>:$private_vars,
+             OptionalAttr<SymbolRefArrayAttr>:$privatizers);
 
   let regions = (region AnyRegion:$region);
 
@@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocators_vars, type($allocators_vars)
               ) `)`
           | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
-    ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
+    ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars),
+                             $reductions, $private_vars, type($private_vars),
+                             $privatizers) attr-dict
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ea5f31ee8c6aa7..464a647564aced 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocators_vars = */ llvm::SmallVector<Value>{},
         /* reduction_vars = */ llvm::SmallVector<Value>{},
         /* reductions = */ ArrayAttr{},
-        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
+        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
+        /* private_vars = */ ValueRange(),
+        /* privatizers = */ nullptr);
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 82e32da91aaeeb..c2b471ab96183f 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
 // Parser, printer and verifier for ReductionVarList
 //===----------------------------------------------------------------------===//
 
-ParseResult
-parseReductionClause(OpAsmParser &parser, Region &region,
-                     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
-                     SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
-                     SmallVectorImpl<OpAsmParser::Argument> &privates) {
-  if (failed(parser.parseOptionalKeyword("reduction")))
-    return failure();
-
+ParseResult parseClauseWithRegionArgs(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+    SmallVectorImpl<Type> &types, ArrayAttr &symbols,
+    SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
   SmallVector<SymbolRefAttr> reductionVec;
+  unsigned regionArgOffset = regionPrivateArgs.size();
 
   if (failed(
           parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
             if (parser.parseAttribute(reductionVec.emplace_back()) ||
                 parser.parseOperand(operands.emplace_back()) ||
                 parser.parseArrow() ||
-                parser.parseArgument(privates.emplace_back()) ||
+                parser.parseArgument(regionPrivateArgs.emplace_back()) ||
                 parser.parseColonType(types.emplace_back()))
               return failure();
             return success();
           })))
     return failure();
 
-  for (auto [prv, type] : llvm::zip_equal(privates, types)) {
+  auto *argsBegin = regionPrivateArgs.begin();
+  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
+                               argsBegin + regionArgOffset + types.size());
+  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
     prv.type = type;
   }
   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
-  reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+  symbols = ArrayAttr::get(parser.getContext(), reductions);
   return success();
 }
 
-static void printReductionClause(OpAsmPrinter &p, Operation *op,
-                                 ValueRange reductionArgs, ValueRange operands,
-                                 TypeRange types, ArrayAttr reductionSymbols) {
-  p << "reduction(";
+static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
+                                      ValueRange argsSubrange,
+                                      StringRef clauseName, ValueRange operands,
+                                      TypeRange types, ArrayAttr symbols) {
+  p << clauseName << "(";
   llvm::interleaveComma(
-      llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
-      [&p](auto t) {
+      llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
   p << ") ";
 }
 
-static ParseResult
-parseParallelRegion(OpAsmParser &parser, Region &region,
-                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
-                    SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+static ParseResult parseParallelRegion(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
+    SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
+    llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
+    llvm::SmallVectorImpl<Type> &privateVarsTypes,
+    ArrayAttr &privatizerSymbols) {
+  llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
 
-  llvm::SmallVector<OpAsmParser::Argument> privates;
-  if (succeeded(parseReductionClause(parser, region, operands, types,
-                                     reductionSymbols, privates)))
-    return parser.parseRegion(region, privates);
+  if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
+                                         reductionVarTypes, reductionSymbols,
+                                         regionPrivateArgs)))
+      return failure();
+  }
 
-  return parser.parseRegion(region);
+  if (succeeded(parser.parseOptionalKeyword("private"))) {
+    if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
+                                         privateVarsTypes, privatizerSymbols,
+                                         regionPrivateArgs)))
+      return failure();
+  }
+
+  return parser.parseRegion(region, regionPrivateArgs);
 }
 
 static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
-                                ValueRange operands, TypeRange types,
-                                ArrayAttr reductionSymbols) {
-  if (reductionSymbols)
-    printReductionClause(p, op, region.front().getArguments(), operands, types,
-                         reductionSymbols);
+                                ValueRange reductionVarOperands,
+                                TypeRange reductionVarTypes,
+                                ArrayAttr reductionSymbols,
+                                ValueRange privateVarOperands,
+                                TypeRange privateVarTypes,
+                                ArrayAttr privatizerSymbols) {
+  if (reductionSymbols) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin,
+                                 argsBegin + reductionVarTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
+                              reductionVarOperands, reductionVarTypes,
+                              reductionSymbols);
+  }
+
+  if (privatizerSymbols) {
+    auto *argsBegin = region.front().getArguments().begin();
+    MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
+                                 argsBegin + reductionVarOperands.size() +
+                                     privateVarTypes.size());
+    printClauseWithRegionArgs(p, op, argsSubrange, "private",
+                              privateVarOperands, privateVarTypes,
+                              privatizerSymbols);
+  }
+
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
@@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
       /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
-      /*proc_bind_val=*/nullptr);
+      /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
+      /*privatizers=*/nullptr);
   state.addAttributes(attributes);
 }
 
+template <typename OpType>
+static LogicalResult verifyPrivateVarList(OpType &op) {
+  auto privateVars = op.getPrivateVars();
+  auto privatizers = op.getPrivatizersAttr();
+
+  if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
+    return success();
+
+  auto numPrivateVars = privateVars.size();
+  auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
+
+  if (numPrivateVars != numPrivatizers)
+    return op.emitError() << "inconsistent number of private variables and "
+                             "privatizer op symbols, private vars: "
+                          << numPrivateVars
+                          << " vs. privatizer op symbols: " << numPrivatizers;
+
+  for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
+    Type varType = std::get<0>(privateVarInfo).getType();
+    SymbolRefAttr privatizerSym =
+        std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
+    PrivateClauseOp privatizerOp =
+        SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
+                                                              privatizerSym);
+
+    if (privatizerOp == nullptr)
+      return op.emitError() << "failed to lookup privatizer op with symbol: '"
+                            << privatizerSym << "'";
+
+    Type privatizerType = privatizerOp.getType();
+
+    if (varType != privatizerType)
+      return op.emitError()
+             << "type mismatch between a "
+             << (privatizerOp.getDataSharingType() ==
+                         DataSharingClauseType::Private
+                     ? "private"
+                     : "firstprivate")
+             << " variable and its privatizer op, var type: " << varType
+             << " vs. privatizer op type: " << privatizerType;
+  }
+
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
   if (getAllocateVars().size() != getAllocatorsVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
+
+  if (failed(verifyPrivateVarList(*this)))
+    return failure();
+
   return verifyReductionVarList(*this, getReductions(), getReductionVars());
 }
 
@@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,
 
   // Parse an optional reduction clause
   llvm::SmallVector<OpAsmParser::Argument> privates;
-  bool hasReduction = succeeded(
-      parseReductionClause(parser, region, reductionOperands, reductionTypes,
-                           reductionSymbols, privates));
+  bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
+                      succeeded(parseClauseWithRegionArgs(
+                          parser, region, reductionOperands, reductionTypes,
+                          reductionSymbols, privates));
 
   if (parser.parseKeyword("for"))
     return failure();
@@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
   if (reductionSymbols) {
     auto reductionArgs =
         region.front().getArguments().drop_front(loopVarTypes.size());
-    printReductionClause(p, op, reductionArgs, reductionOperands,
-                         reductionTypes, reductionSymbols);
+    printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
+                              reductionOperands, reductionTypes,
+                              reductionSymbols);
   }
 
   p << " for ";
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 7965d6dc284203..d9261b89e24e3a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
 ^bb0(%arg0: f32):
   omp.yield(%arg0 : f32)
 }
+
+// -----
+
+func.func @private_type_mismatch(%arg0: index) {
+// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
+  omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+    omp.terminator
+  }
+
+  return
+}
+
+omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+}
+
+// -----
+
+func.func @firstprivate_type_mismatch(%arg0: index) {
+  // expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
+  omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+    omp.terminator
+  }
+
+  return
+}
+
+omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+} copy {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+}
+
+// -----
+
+func.func @undefined_privatizer(%arg0: index) {
+  // expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
+  omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+func.func @undefined_privatizer(%arg0: !llvm.ptr) {
+  // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
+    ^bb0(%arg2: !llvm.ptr):
+      omp.terminator
+    }) : (!llvm.ptr) -> ()
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops-2.mlir b/mlir/test/Dialect/OpenMP/ops-2.mlir
new file mode 100644
index 00000000000000..9340e0b5fdf971
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/ops-2.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: parallel_op_privatizers
+func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+  // CHECK: omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr)
+  omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
+    %0 = llvm.load %arg2 : !llvm.ptr -> i32
+    %1 = llvm.load %arg3 : !llvm.ptr -> i32
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
+omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
+// CHECK: ^bb0(%arg0: {{.*}}):
+^bb0(%arg0: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+}
+
+// CHECK: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
+omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
+// CHECK: ^bb0(%arg0: {{.*}}):
+^bb0(%arg0: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+// CHECK: } copy {
+} copy {
+// CHECK: ^bb0(%arg0: {{.*}}, %arg1: {{.*}}):
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+  omp.yield(%arg0 : !llvm.ptr)
+}
+
+// CHECK-LABEL: parallel_op_reduction_and_private
+func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
+  // CHECK: omp.parallel
+  // CHECK-SAME: reduction(
+  // CHECK-SAME: @add_f32 %[[reduc_var:[0-9a-z]+]] -> %[[reduc_arg:[0-9a-z]+]] : !llvm.ptr,
+  // CHECK-SAME: @add_f32 %[[reduc_var2:[0-9a-z]+]] -> %[[reduc_arg2:[0-9a-z]+]] : !llvm.ptr)
+  //
+  // CHECK-SAME: private(
+  // CHECK-SAME: @x.privatizer %[[priv_var:[0-9a-z]+]] -> %[[priv_arg:[0-9a-z]+]] : !llvm.ptr,
+  // CHECK-SAME: @y.privatizer %[[priv_var2:[0-9a-z]+]] -> %[[priv_arg2:[0-9a-z]+]] : !llvm.ptr)
+  omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
+               private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
+    // CHECK: llvm.load %[[priv_arg]]
+    %0 = llvm.load %priv_arg : !llvm.ptr -> f32
+    // CHECK: llvm.load %[[priv_arg2]]
+    %1 = llvm.load %priv_arg2 : !llvm.ptr -> f32
+    // CHECK: llvm.load %[[reduc_arg]]
+    %2 = llvm.load %reduc_arg : !llvm.ptr -> f32
+    // CHECK: llvm.load %[[reduc_arg2]]
+    %3 = llvm.load %reduc_arg2 : !llvm.ptr -> f32
+    omp.terminator
+  }
+  return
+}
+
+omp.reduction.declare @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> f32
+  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+  omp.yield
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d92a554caf77ca..c92ef895306ed3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,1,1,1,0>} : (i32, memref<i32>, memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 0,1,1,1,0,0>} : (i32, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,1,1,0>} : (i1, memref<i32>, memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,1,1,0,0>} : (i1, memref<i32>, memref<i32>) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,1,1,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/roundtrip.mlir b/mlir/test/Dialect/OpenMP/roundtrip.mlir
deleted file mode 100644
index 2553442638ee84..00000000000000
--- a/mlir/test/Dialect/OpenMP/roundtrip.mlir
+++ /dev/null
@@ -1,21 +0,0 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
-
-// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
-omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
-// CHECK: ^bb0(%arg0: {{.*}}):
-^bb0(%arg0: !llvm.ptr):
-  omp.yield(%arg0 : !llvm.ptr)
-}
-
-// CHECK: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
-omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
-// CHECK: ^bb0(%arg0: {{.*}}):
-^bb0(%arg0: !llvm.ptr):
-  omp.yield(%arg0 : !llvm.ptr)
-// CHECK: } copy {
-} copy {
-// CHECK: ^bb0(%arg0: {{.*}}, %arg1: {{.*}}):
-^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
-  omp.yield(%arg0 : !llvm.ptr)
-}
-

>From 441e95906e7018090b08aa2f880f1cca57d1ae6a Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 14 Feb 2024 04:24:06 -0600
Subject: [PATCH 2/2] [flang][OpenMP][MLIR] Basic support for delayed
 privatization code-gen

Adds basic support for emitting delayed privatizers from flang. So far,
only types of symbols are supported (i.e. scalars), support for more
complicated types will be added later. This also makes sure that
reduction and delayed privatization work properly together by merging the
body-gen callbacks for both in case both clauses are present on the
parallel construct.
---
 flang/include/flang/Lower/AbstractConverter.h |   4 +
 flang/lib/Lower/Bridge.cpp                    |   2 +-
 flang/lib/Lower/OpenMP.cpp                    | 235 ++++++++++++++++--
 .../delayed-privatization-firstprivate.f90    |  26 ++
 .../FIR/delayed-privatization-private.f90     |  38 +++
 .../OpenMP/FIR/parallel-reduction-private.f90 |  30 +++
 6 files changed, 307 insertions(+), 28 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/FIR/delayed-privatization-firstprivate.f90
 create mode 100644 flang/test/Lower/OpenMP/FIR/delayed-privatization-private.f90
 create mode 100644 flang/test/Lower/OpenMP/FIR/parallel-reduction-private.f90

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 796933a4eb5f68..55bc33e76e5f6e 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -16,6 +16,7 @@
 #include "flang/Common/Fortran.h"
 #include "flang/Lower/LoweringOptions.h"
 #include "flang/Lower/PFTDefs.h"
+#include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Semantics/symbol.h"
 #include "mlir/IR/Builders.h"
@@ -296,6 +297,9 @@ class AbstractConverter {
     return loweringOptions;
   }
 
+  virtual Fortran::lower::SymbolBox
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2d7f748cefa2d8..a7a5b826ef7a9e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
-  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
+  lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
     if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
       return v;
     return {};
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 9397af8b8bd05e..c95f9a955e54d4 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -40,6 +40,12 @@ static llvm::cl::opt<bool> treatIndexAsSection(
     llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
     llvm::cl::init(true));
 
+static llvm::cl::opt<bool> enableDelayedPrivatization(
+    "openmp-enable-delayed-privatization",
+    llvm::cl::desc(
+        "Emit `[first]private` variables as clauses on the MLIR ops."),
+    llvm::cl::init(false));
+
 using DeclareTargetCapturePair =
     std::pair<mlir::omp::DeclareTargetCaptureClause,
               Fortran::semantics::Symbol>;
@@ -147,6 +153,14 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
 //===----------------------------------------------------------------------===//
 
 class DataSharingProcessor {
+public:
+  struct DelayedPrivatizationInfo {
+    llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
+    llvm::SmallVector<mlir::Value> hostAddresses;
+    llvm::SmallVector<const Fortran::semantics::Symbol *> hostSymbols;
+  };
+
+private:
   bool hasLastPrivateOp;
   mlir::OpBuilder::InsertPoint lastPrivIP;
   mlir::OpBuilder::InsertPoint insPt;
@@ -161,6 +175,11 @@ class DataSharingProcessor {
   const Fortran::parser::OmpClauseList &opClauseList;
   Fortran::lower::pft::Evaluation &eval;
 
+  bool useDelayedPrivatization;
+  Fortran::lower::SymMap *symTable;
+
+  DelayedPrivatizationInfo delayedPrivatizationInfo;
+
   bool needBarrier();
   void collectSymbols(Fortran::semantics::Symbol::Flag flag);
   void collectOmpObjectListSymbol(
@@ -171,10 +190,13 @@ class DataSharingProcessor {
   void collectDefaultSymbols();
   void privatize();
   void defaultPrivatize();
+  void doPrivatize(const Fortran::semantics::Symbol *sym);
   void copyLastPrivatize(mlir::Operation *op);
   void insertLastPrivateCompare(mlir::Operation *op);
   void cloneSymbol(const Fortran::semantics::Symbol *sym);
-  void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
+  void
+  copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym,
+                         mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr);
   void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
                              mlir::OpBuilder::InsertPoint *lastPrivIP);
   void insertDeallocs();
@@ -182,10 +204,14 @@ class DataSharingProcessor {
 public:
   DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
                        const Fortran::parser::OmpClauseList &opClauseList,
-                       Fortran::lower::pft::Evaluation &eval)
+                       Fortran::lower::pft::Evaluation &eval,
+                       bool useDelayedPrivatization = false,
+                       Fortran::lower::SymMap *symTable = nullptr)
       : hasLastPrivateOp(false), converter(converter),
         firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
-        eval(eval) {}
+        eval(eval), useDelayedPrivatization(useDelayedPrivatization),
+        symTable(symTable) {}
+
   // Privatisation is split into two steps.
   // Step1 performs cloning of all privatisation clauses and copying for
   // firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +230,10 @@ class DataSharingProcessor {
     assert(!loopIV && "Loop iteration variable already set");
     loopIV = iv;
   }
+
+  const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
+    return delayedPrivatizationInfo;
+  }
 };
 
 void DataSharingProcessor::processStep1() {
@@ -250,9 +280,10 @@ void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
 }
 
 void DataSharingProcessor::copyFirstPrivateSymbol(
-    const Fortran::semantics::Symbol *sym) {
+    const Fortran::semantics::Symbol *sym,
+    mlir::OpBuilder::InsertPoint *copyAssignIP) {
   if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
-    converter.copyHostAssociateVar(*sym);
+    converter.copyHostAssociateVar(*sym, copyAssignIP);
 }
 
 void DataSharingProcessor::copyLastPrivateSymbol(
@@ -491,14 +522,10 @@ void DataSharingProcessor::privatize() {
   for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
     if (const auto *commonDet =
             sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
-      for (const auto &mem : commonDet->objects()) {
-        cloneSymbol(&*mem);
-        copyFirstPrivateSymbol(&*mem);
-      }
-    } else {
-      cloneSymbol(sym);
-      copyFirstPrivateSymbol(sym);
-    }
+      for (const auto &mem : commonDet->objects())
+        doPrivatize(&*mem);
+    } else
+      doPrivatize(sym);
   }
 }
 
@@ -522,11 +549,96 @@ void DataSharingProcessor::defaultPrivatize() {
         !sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
         !symbolsInNestedRegions.contains(sym) &&
         !symbolsInParentRegions.contains(sym) &&
-        !privatizedSymbols.contains(sym)) {
+        !privatizedSymbols.contains(sym))
+      doPrivatize(sym);
+  }
+}
+
+void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
+  if (!useDelayedPrivatization) {
+    cloneSymbol(sym);
+    copyFirstPrivateSymbol(sym);
+    return;
+  }
+
+  Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
+  assert(hsb && "Host symbol box not found");
+
+  mlir::Type symType = hsb.getAddr().getType();
+  mlir::Location symLoc = hsb.getAddr().getLoc();
+  std::string privatizerName = sym->name().ToString() + ".privatizer";
+  bool isFirstPrivate =
+      sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
+
+  mlir::omp::PrivateClauseOp privatizerOp = [&]() {
+    auto moduleOp = firOpBuilder.getModule();
+
+    auto uniquePrivatizerName = fir::getTypeAsString(
+        symType, converter.getKindMap(),
+        sym->name().ToString() +
+            (isFirstPrivate ? "_firstprivate" : "_private"));
+
+    if (auto existingPrivatizer =
+            moduleOp.lookupSymbol<mlir::omp::PrivateClauseOp>(
+                uniquePrivatizerName))
+      return existingPrivatizer;
+
+    auto ip = firOpBuilder.saveInsertionPoint();
+    firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
+                                   moduleOp.getBodyRegion().front().begin());
+    auto result = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
+        symLoc, uniquePrivatizerName, symType,
+        isFirstPrivate ? mlir::omp::DataSharingClauseType ::FirstPrivate
+                       : mlir::omp::DataSharingClauseType::Private);
+
+    symTable->pushScope();
+
+    // Populate the `alloc` region.
+    {
+      mlir::Region &allocRegion = result.getAllocRegion();
+      mlir::Block *allocEntryBlock = firOpBuilder.createBlock(
+          &allocRegion, /*insertPt=*/{}, symType, symLoc);
+
+      firOpBuilder.setInsertionPointToEnd(allocEntryBlock);
+      symTable->addSymbol(*sym, allocRegion.getArgument(0));
+      symTable->pushScope();
       cloneSymbol(sym);
-      copyFirstPrivateSymbol(sym);
+      firOpBuilder.create<mlir::omp::YieldOp>(
+          hsb.getAddr().getLoc(),
+          symTable->shallowLookupSymbol(*sym).getAddr());
+      symTable->popScope();
     }
-  }
+
+    // Populate the `copy` region if this is a `firstprivate`.
+    if (isFirstPrivate) {
+      mlir::Region &copyRegion = result.getCopyRegion();
+      // The first block argument corresponds to the original/host value, while
+      // the second block argument corresponds to the privatized value.
+      mlir::Block *copyEntryBlock = firOpBuilder.createBlock(
+          &copyRegion, /*insertPt=*/{}, {symType, symType}, {symLoc, symLoc});
+      firOpBuilder.setInsertionPointToEnd(copyEntryBlock);
+      symTable->addSymbol(*sym, copyRegion.getArgument(0),
+                          /*force=*/true);
+      symTable->pushScope();
+      symTable->addSymbol(*sym, copyRegion.getArgument(1));
+      auto ip = firOpBuilder.saveInsertionPoint();
+      copyFirstPrivateSymbol(sym, &ip);
+
+      firOpBuilder.create<mlir::omp::YieldOp>(
+          hsb.getAddr().getLoc(),
+          symTable->shallowLookupSymbol(*sym).getAddr());
+      symTable->popScope();
+    }
+
+    symTable->popScope();
+    firOpBuilder.restoreInsertionPoint(ip);
+    return result;
+  }();
+
+  delayedPrivatizationInfo.privatizers.push_back(
+      mlir::SymbolRefAttr::get(privatizerOp));
+  delayedPrivatizationInfo.hostAddresses.push_back(hsb.getAddr());
+  delayedPrivatizationInfo.hostSymbols.push_back(sym);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2585,6 +2697,7 @@ genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
+              Fortran::lower::SymMap &symTable,
               Fortran::semantics::SemanticsContext &semaCtx,
               Fortran::lower::pft::Evaluation &eval, bool genNested,
               mlir::Location currentLocation,
@@ -2617,8 +2730,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   auto reductionCallback = [&](mlir::Operation *op) {
     llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
                                            currentLocation);
-    auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
-                                                         reductionTypes, locs);
+    auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
+                                                          reductionTypes, locs);
     for (auto [arg, prv] :
          llvm::zip_equal(reductionSymbols, block->getArguments())) {
       converter.bindSymbol(*arg, prv);
@@ -2626,13 +2739,78 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     return reductionSymbols;
   };
 
-  return genOpWithBody<mlir::omp::ParallelOp>(
+  OpWithBodyGenInfo genInfo =
       OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
           .setGenNested(genNested)
           .setOuterCombined(outerCombined)
           .setClauses(&clauseList)
           .setReductions(&reductionSymbols, &reductionTypes)
-          .setGenRegionEntryCb(reductionCallback),
+          .setGenRegionEntryCb(reductionCallback);
+
+  if (!enableDelayedPrivatization) {
+    return genOpWithBody<mlir::omp::ParallelOp>(
+        genInfo,
+        /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
+        numThreadsClauseOperand, allocateOperands, allocatorOperands,
+        reductionVars,
+        reductionDeclSymbols.empty()
+            ? nullptr
+            : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                   reductionDeclSymbols),
+        procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
+        /*privatizers=*/nullptr);
+  }
+
+  bool privatize = !outerCombined;
+  DataSharingProcessor dsp(converter, clauseList, eval,
+                           /*useDelayedPrivatization=*/true, &symTable);
+
+  if (privatize)
+    dsp.processStep1();
+
+  const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
+
+  auto genRegionEntryCB = [&](mlir::Operation *op) {
+    auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
+
+    llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
+                                                    currentLocation);
+
+    auto privateVars = parallelOp.getPrivateVars();
+    auto &region = parallelOp.getRegion();
+
+    llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
+    privateVarTypes.reserve(privateVars.size());
+    llvm::transform(privateVars, std::back_inserter(privateVarTypes),
+                    [](mlir::Value v) { return v.getType(); });
+
+    llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
+    privateVarLocs.reserve(privateVars.size());
+    llvm::transform(privateVars, std::back_inserter(privateVarLocs),
+                    [](mlir::Value v) { return v.getLoc(); });
+
+    converter.getFirOpBuilder().createBlock(&region, /*insertPt=*/{},
+                                            privateVarTypes, privateVarLocs);
+
+    llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
+        reductionSymbols;
+    allSymbols.append(delayedPrivatizationInfo.hostSymbols);
+    for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
+      converter.bindSymbol(*arg, prv);
+    }
+
+    return allSymbols;
+  };
+
+  // TODO Merge with the reduction CB.
+  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
+
+  llvm::SmallVector<mlir::Attribute> privatizers(
+      delayedPrivatizationInfo.privatizers.begin(),
+      delayedPrivatizationInfo.privatizers.end());
+
+  return genOpWithBody<mlir::omp::ParallelOp>(
+      genInfo,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -2640,8 +2818,11 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                  reductionDeclSymbols),
-      procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
-      /*privatizers=*/nullptr);
+      procBindKindAttr, delayedPrivatizationInfo.hostAddresses,
+      delayedPrivatizationInfo.privatizers.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 privatizers));
 }
 
 static mlir::omp::SectionOp
@@ -3635,7 +3816,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genParallelOp(converter, semaCtx, eval, /*genNested=*/false,
+      genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
                     currentLocation, loopOpClauseList,
                     /*outerCombined=*/true);
     }
@@ -3724,8 +3905,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                        currentLocation);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-                  beginClauseList);
+    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
+                  currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_single:
     genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
@@ -3782,7 +3963,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
           .test(directive.v)) {
     bool outerCombined =
         directive.v != llvm::omp::Directive::OMPD_target_parallel;
-    genParallelOp(converter, semaCtx, eval, /*genNested=*/false,
+    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
                   currentLocation, beginClauseList, outerCombined);
     combinedDirective = true;
   }
@@ -3865,7 +4046,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
 
   // Parallel wrapper of PARALLEL SECTIONS construct
   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
-    genParallelOp(converter, semaCtx, eval,
+    genParallelOp(converter, symTable, semaCtx, eval,
                   /*genNested=*/false, currentLocation, sectionsClauseList,
                   /*outerCombined=*/true);
   } else {
diff --git a/flang/test/Lower/OpenMP/FIR/delayed-privatization-firstprivate.f90 b/flang/test/Lower/OpenMP/FIR/delayed-privatization-firstprivate.f90
new file mode 100644
index 00000000000000..21ffc1d008e313
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/delayed-privatization-firstprivate.f90
@@ -0,0 +1,26 @@
+! Test delayed privatization for the `firstprivate` clause.
+
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --openmp-enable-delayed-privatization -o - %s 2>&1 | FileCheck %s
+
+subroutine delayed_privatization_firstprivate
+  implicit none
+  integer :: var1
+
+!$OMP PARALLEL FIRSTPRIVATE(var1)
+  var1 = 10
+!$OMP END PARALLEL
+end subroutine
+
+! CHECK-LABEL: omp.private {type = firstprivate}
+! CHECK-SAME: @[[VAR1_PRIVATIZER_SYM:.*]] : !fir.ref<i32> alloc {
+! CHECK: } copy {
+! CHECK: ^bb0(%[[PRIV_ORIG_ARG:.*]]: !fir.ref<i32>, %[[PRIV_PRIV_ARG:.*]]: !fir.ref<i32>):
+! CHECK:    %[[ORIG_VAL:.*]] = fir.load %[[PRIV_ORIG_ARG]] : !fir.ref<i32>
+! CHECK:    fir.store %[[ORIG_VAL]] to %[[PRIV_PRIV_ARG]] : !fir.ref<i32>
+! CHECK:    omp.yield(%[[PRIV_PRIV_ARG]] : !fir.ref<i32>)
+! CHECK: }
+
+! CHECK-LABEL: @_QPdelayed_privatization_firstprivate
+! CHECK: omp.parallel private(@[[VAR1_PRIVATIZER_SYM]] %{{.*}} -> %{{.*}} : !fir.ref<i32>) {
+! CHECK: omp.terminator
+
diff --git a/flang/test/Lower/OpenMP/FIR/delayed-privatization-private.f90 b/flang/test/Lower/OpenMP/FIR/delayed-privatization-private.f90
new file mode 100644
index 00000000000000..91eaa092f11d69
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/delayed-privatization-private.f90
@@ -0,0 +1,38 @@
+! Test delayed privatization for the `private` clause.
+
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --openmp-enable-delayed-privatization -o - %s 2>&1 | FileCheck %s
+
+subroutine delayed_privatization_private
+  implicit none
+  integer :: var1
+
+!$OMP PARALLEL PRIVATE(var1)
+  var1 = 10
+!$OMP END PARALLEL
+end subroutine
+
+! CHECK-LABEL: omp.private {type = private}
+! CHECK-SAME: @[[PRIVATIZER_SYM:.*]] : !fir.ref<i32> alloc {
+! CHECK-NEXT: ^bb0(%[[PRIV_ARG:.*]]: !fir.ref<i32>):
+! CHECK-NEXT:   %[[PRIV_ALLOC:.*]] = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatization_privateEvar1"}
+! CHECK-NEXT:   omp.yield(%[[PRIV_ALLOC]] : !fir.ref<i32>)
+
+! CHECK-LABEL: @_QPdelayed_privatization_private
+! CHECK: %[[ORIG_ALLOC:.*]] = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatization_privateEvar1"}
+! CHECK: omp.parallel private(@[[PRIVATIZER_SYM]] %[[ORIG_ALLOC]] -> %[[PAR_ARG:.*]] : !fir.ref<i32>) {
+! CHECK: fir.store %{{.*}} to %[[PAR_ARG]] : !fir.ref<i32>
+! CHECK: omp.terminator
+
+! Test that the same privatizer is used if a variable with the same type and
+! name was previously privatized.
+subroutine delayed_privatization_private_same_type_and_name
+  implicit none
+  integer :: var1
+
+!$OMP PARALLEL PRIVATE(var1)
+  var1 = 10
+!$OMP END PARALLEL
+end subroutine
+
+! CHECK-LABEL: @_QPdelayed_privatization_private_same_type_and_name
+! CHECK: omp.parallel private(@[[PRIVATIZER_SYM]] %[[ORIG_ALLOC]] -> %[[PAR_ARG:.*]] : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-private.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-private.f90
new file mode 100644
index 00000000000000..d72093c8eb73ef
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-private.f90
@@ -0,0 +1,30 @@
+! Test that reductions and delayed privatization work properly together. Since
+! both types of clauses add block arguments to the OpenMP region, we make sure
+! that the block arguments are added in the proper order (reductions first and
+! then delayed privatization).
+
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --openmp-enable-delayed-privatization -o - %s 2>&1 | FileCheck %s
+
+subroutine red_and_delayed_private
+    integer :: red
+    integer :: prv
+
+    red = 0
+    prv = 10
+
+    !$omp parallel reduction(+:red) private(prv)
+    red = red + 1
+    prv = 20
+    !$omp end parallel
+end subroutine
+
+! CHECK-LABEL: omp.private {type = private}
+! CHECK-SAME: @[[PRIVATIZER_SYM:.*]] : !fir.ref<i32> alloc {
+
+! CHECK-LABEL: omp.reduction.declare
+! CHECK-SAME: @[[REDUCTION_SYM:.*]] : i32 init
+
+! CHECK-LABEL: _QPred_and_delayed_private
+! CHECK: omp.parallel
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {



More information about the flang-commits mailing list