[flang] [llvm] [WIP][PoC][flang] Re-use OpenMP data environment clauses for locality spec (PR #128148)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 00:35:43 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

<details>
<summary>Changes</summary>

**This is a PoC on which to base a proper RFC later. It is not meant to be merged!**

Now that we have started working on mapping `do concurrent` loop nests to corresponding OpenMP constructs (and later to OpenACC), we have come across the following problem: how can we map `do concurrent`'s locality specifiers to their corresponding OpenMP/OpenACC data environment clauses?

This is not easy at the moment because locality specifiers are handled at the PFT-to-MLIR lowering level, which makes discovering the ops corresponding to them more difficult (or even impossible) during `do concurrent` to OpenMP mapping.

One way to handle this problem would be to use something similar to the delayed privatization that we have recently been working on for the OpenMP dialect. So at the MLIR level, the following `do concurrent` loop:
```fortran
subroutine foo
  implicit none
  integer :: i, local_var, local_init_var

  do concurrent (i=1:10) local(local_var) local_init(local_init_var)
    if (i < 5) then
      local_var = 42
    else
      local_init_var = 84
    end if
  end do
end subroutine
```
would look something like this:
```mlir
    %0 = fir.alloca i32 {bindc_name = "i"}
    %1:2 = hlfir.declare %0 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomploopEi"}
    %3:2 = hlfir.declare %2 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %4 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFomploopElocal_init_var"}
    %5:2 = hlfir.declare %4 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %6 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFomploopElocal_var"}
    %7:2 = hlfir.declare %6 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %c1_i32 = arith.constant 1 : i32
    %8 = fir.convert %c1_i32 : (i32) -> index
    %c10_i32 = arith.constant 10 : i32
    %9 = fir.convert %c10_i32 : (i32) -> index
    %c1 = arith.constant 1 : index
    // Instead of using "private" we can use "local".
    fir.do_loop %arg0 = %8 to %9 step %c1 unordered private(@local_privatizer %7#0 -> %arg1, @local_init_privatizer %5#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
      %10 = fir.convert %arg0 : (index) -> i32
      fir.store %10 to %1#1 : !fir.ref<i32>
      %11:2 = hlfir.declare %arg1 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      %12:2 = hlfir.declare %arg2 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      %13 = fir.load %1#0 : !fir.ref<i32>
      %c5_i32 = arith.constant 5 : i32
      %14 = arith.cmpi slt, %13, %c5_i32 : i32
      fir.if %14 {
        %c42_i32 = arith.constant 42 : i32
        hlfir.assign %c42_i32 to %11#0 : i32, !fir.ref<i32>
      } else {
        %c84_i32 = arith.constant 84 : i32
        hlfir.assign %c84_i32 to %12#0 : i32, !fir.ref<i32>
      }
    }
```
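
As a rough illustration of the semantics the two privatizers are meant to carry, here is a minimal standalone C++ sketch (no MLIR involved). It assumes that `local` maps to fresh, uninitialized per-iteration storage and that `local_init` maps to fresh storage copied from the outer variable, analogous to a `firstprivate` copy region. The names `IterationView` and `runDoConcurrent` are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <vector>

// Hypothetical record of what one iteration observed in its own copies.
struct IterationView {
  int localVar;
  int localInitVar;
};

// Models `do concurrent (i=lower:upper) local(local_var)
// local_init(local_init_var)`. The outer variables are taken by const
// reference to make it explicit that no iteration writes to them.
std::vector<IterationView> runDoConcurrent(int lower, int upper,
                                           const int &outerLocalVar,
                                           const int &outerLocalInitVar) {
  assert(lower <= upper && "sketch assumes a non-empty iteration range");
  (void)outerLocalVar; // `local` never reads the outer value.

  std::vector<IterationView> views;
  for (int i = lower; i <= upper; ++i) {
    // `local`: fresh per-iteration storage. Fortran leaves it undefined on
    // entry; it is zeroed here only so the sketch is deterministic.
    int localVar = 0;
    // `local_init`: fresh storage initialized from the outer variable,
    // analogous to the `copy` region of a firstprivate privatizer.
    int localInitVar = outerLocalInitVar;

    if (i < 5)
      localVar = 42;
    else
      localInitVar = 84;

    views.push_back({localVar, localInitVar});
  }
  return views;
}
```

Each iteration works on its own copies, so the outer variables keep their values after the loop. This is the property that lets iterations run in any order.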

To that end, it would be nice to:
1. Extract the table-gen records we already have for OpenMP into a separate "Data Environment" dialect.
2. Use the records in that dialect for both OpenMP and `do concurrent` (and later for OpenACC).
3. Hopefully do this for both local/private-related clauses/specifiers and reductions.

This is a PoC to validate that idea. For now it only reuses the OpenMP records, just to showcase what this looks like for `do concurrent`. The PoC contains a sample to test the currently prototyped functionality.
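
To make the textual form of the clause concrete, the following standalone C++ sketch builds the ` private(...)` string in the same shape the sample uses: a comma-separated list of `@privatizer %operand -> %region_arg` entries, then ` : `, then the comma-separated types. It is a plain string model, not the actual `fir::DoLoopOp::print` code, and `PrivateEntry` and `formatPrivateClause` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical textual description of one privatized variable.
struct PrivateEntry {
  std::string privatizerSym; // e.g. "@local_privatizer"
  std::string operand;       // e.g. "%7#0"
  std::string regionArg;     // e.g. "%arg1"
  std::string type;          // e.g. "!fir.ref<i32>"
};

// Builds " private(<sym> <operand> -> <arg>, ... : <type>, ...)".
// Returns an empty string when there is nothing to privatize, so the
// clause is omitted entirely in that case.
std::string formatPrivateClause(const std::vector<PrivateEntry> &entries) {
  if (entries.empty())
    return "";

  std::string out = " private(";
  for (std::size_t i = 0; i < entries.size(); ++i) {
    const PrivateEntry &entry = entries[i];
    assert(!entry.privatizerSym.empty() && entry.privatizerSym[0] == '@' &&
           "privatizer must be written as a symbol reference");
    if (i != 0)
      out += ", ";
    out += entry.privatizerSym + " " + entry.operand + " -> " + entry.regionArg;
  }

  out += " : ";
  for (std::size_t i = 0; i < entries.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += entries[i].type;
  }
  out += ")";
  return out;
}
```

Keeping the symbol, the outer operand, and the region argument together in one entry mirrors how the clause reads in the IR: each privatizer is paired with the value it privatizes and the block argument that stands for the private copy.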

Current status of the PoC:
- [x] Extend `fir.do_loop` to reuse OpenMP clause table-gen records
- [x] Parsing and printing for `fir.do_loop` with `private` specifiers
- [x] Basic lowering of `fir.do_loop` `local` specifiers
- [x] Basic lowering of `fir.do_loop`'s `local_init` specifier
- [ ] PFT to MLIR lowering using the MLIR locality specifiers.

Each of the checked items above has a corresponding self-contained commit to demo the needed changes in that part of the pipeline.
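
The accessors added to `fir.do_loop` in this PoC slice the loop body's block arguments in a fixed order: the induction variable first, then the loop-carried values, then the privatized values. The standalone C++ sketch below models that layout with plain strings, along with the six-entry operand segment sizes the parser sets (lb, ub, step, reduce operands, iter operands, private operands). `LoopArgLayout`, `slice`, `regionIterArgs`, `regionPrivateArgs`, and `operandSegmentSizes` are hypothetical helpers, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Args = std::vector<std::string>;

// Hypothetical description of how many block arguments each group owns.
struct LoopArgLayout {
  std::size_t numIterArgs;
  std::size_t numPrivateArgs;
};

// Returns the elements [start, start + count) of args.
Args slice(const Args &args, std::size_t start, std::size_t count) {
  assert(start + count <= args.size() && "slice out of range");
  auto first = args.begin() + static_cast<std::ptrdiff_t>(start);
  return Args(first, first + static_cast<std::ptrdiff_t>(count));
}

// Loop-carried values follow the induction variable at index 0.
Args regionIterArgs(const Args &blockArgs, const LoopArgLayout &layout) {
  return slice(blockArgs, 1, layout.numIterArgs);
}

// Privatized values follow the loop-carried values.
Args regionPrivateArgs(const Args &blockArgs, const LoopArgLayout &layout) {
  return slice(blockArgs, 1 + layout.numIterArgs, layout.numPrivateArgs);
}

// Segment sizes in operand order: lb, ub, step, reduce, iter, private.
std::vector<int> operandSegmentSizes(std::size_t numReduceOperands,
                                     const LoopArgLayout &layout) {
  return {1,
          1,
          1,
          static_cast<int>(numReduceOperands),
          static_cast<int>(layout.numIterArgs),
          static_cast<int>(layout.numPrivateArgs)};
}
```

Using an explicit count for each slice, rather than dropping a prefix and taking everything that remains, is what allows a new trailing group such as the private arguments to be appended without changing what the existing accessors return.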

---
Full diff: https://github.com/llvm/llvm-project/pull/128148.diff


5 Files Affected:

- (added) do_loop_with_local_and_local_init.mlir (+49) 
- (modified) flang/include/flang/Optimizer/Dialect/CMakeLists.txt (+2-2) 
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+29-8) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+89-17) 
- (modified) flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp (+57) 


``````````diff
diff --git a/do_loop_with_local_and_local_init.mlir b/do_loop_with_local_and_local_init.mlir
new file mode 100644
index 0000000000000..06510b4433f1a
--- /dev/null
+++ b/do_loop_with_local_and_local_init.mlir
@@ -0,0 +1,49 @@
+// For testing:
+// 1. parsing/printing (roundtripping): `fir-opt do_loop_with_local_and_local_init.mlir -o roundtrip.mlir`
+// 2. Lowering locality specs during CFG: `fir-opt --cfg-conversion do_loop_with_local_and_local_init.mlir -o after_cfg_lowering.mlir`
+
+// TODO I will add both of the above steps as proper tests when the PoC is complete.
+module attributes {dlti.dl_spec = #dlti.dl_spec<i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 21.0.0 (/home/kaergawy/git/aomp20.0/llvm-project/flang c8cf5a644886bb8dd3ad19be6e3b916ffcbd222c)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+  omp.private {type = private} @local_privatizer : i32
+
+  omp.private {type = firstprivate} @local_init_privatizer : i32 copy {
+  ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
+      %0 = fir.load %arg0 : !fir.ref<i32>
+      fir.store %0 to %arg1 : !fir.ref<i32>
+      omp.yield(%arg1 : !fir.ref<i32>)
+  }
+
+  func.func @_QPomploop() {
+    %0 = fir.alloca i32 {bindc_name = "i"}
+    %1:2 = hlfir.declare %0 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomploopEi"}
+    %3:2 = hlfir.declare %2 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %4 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFomploopElocal_init_var"}
+    %5:2 = hlfir.declare %4 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %6 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFomploopElocal_var"}
+    %7:2 = hlfir.declare %6 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %c1_i32 = arith.constant 1 : i32
+    %8 = fir.convert %c1_i32 : (i32) -> index
+    %c10_i32 = arith.constant 10 : i32
+    %9 = fir.convert %c10_i32 : (i32) -> index
+    %c1 = arith.constant 1 : index
+    fir.do_loop %arg0 = %8 to %9 step %c1 unordered private(@local_privatizer %7#0 -> %arg1, @local_init_privatizer %5#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
+      %10 = fir.convert %arg0 : (index) -> i32
+      fir.store %10 to %1#1 : !fir.ref<i32>
+      %12:2 = hlfir.declare %arg1 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %14:2 = hlfir.declare %arg2 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %16 = fir.load %1#0 : !fir.ref<i32>
+      %c5_i32 = arith.constant 5 : i32
+      %17 = arith.cmpi slt, %16, %c5_i32 : i32
+      fir.if %17 {
+        %c42_i32 = arith.constant 42 : i32
+        hlfir.assign %c42_i32 to %12#0 : i32, !fir.ref<i32>
+      } else {
+        %c84_i32 = arith.constant 84 : i32
+        hlfir.assign %c84_i32 to %14#0 : i32, !fir.ref<i32>
+      }
+    }
+    return
+  }
+}
diff --git a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
index 73f388cbab6c9..da14fcd25a8d3 100644
--- a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
@@ -16,8 +16,8 @@ mlir_tablegen(FIRAttr.cpp.inc -gen-attrdef-defs)
 set(LLVM_TARGET_DEFINITIONS FIROps.td)
 mlir_tablegen(FIROps.h.inc -gen-op-decls)
 mlir_tablegen(FIROps.cpp.inc -gen-op-defs)
-mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls)
-mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs)
+mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls -typedefs-dialect=fir)
+mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs -typedefs-dialect=fir)
 add_public_tablegen_target(FIROpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS FortranVariableInterface.td)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553d..34647263d6cc7 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -16,6 +16,7 @@
 
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
+include "mlir/Dialect/OpenMP/OpenMPClauses.td"
 include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -2171,7 +2172,7 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 
-  let arguments = (ins
+  defvar opArgs = (ins
     Index:$lowerBound,
     Index:$upperBound,
     Index:$step,
@@ -2182,6 +2183,8 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
     OptionalAttr<ArrayAttr>:$reduceAttrs,
     OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
   );
+
+  let arguments = !con(opArgs, OpenMP_PrivateClause.arguments);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
@@ -2193,24 +2196,38 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
       CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
       CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
       CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
-      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
+      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes,
+      CArg<"mlir::ValueRange", "std::nullopt">:$private_vars,
+      CArg<"mlir::ArrayRef<mlir::Attribute>", "{}">:$private_syms
+      )>
   ];
 
-  let extraClassDeclaration = [{
-    mlir::Value getInductionVar() { return getBody()->getArgument(0); }
+  defvar opExtraClassDeclaration = [{
     mlir::OpBuilder getBodyBuilder() {
       return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
     }
+
+    /// Region argument accessors.
+    mlir::Value getInductionVar() { return getBody()->getArgument(0); }
     mlir::Block::BlockArgListType getRegionIterArgs() {
-      return getBody()->getArguments().drop_front();
+      // 1 for skipping the induction variable.
+      return getBody()->getArguments().slice(1, getNumIterOperands());
+    }
+    mlir::Block::BlockArgListType getRegionPrivateArgs() {
+     return getBody()->getArguments().slice(1 + getNumIterOperands(),
+                                            numPrivateBlockArgs());
     }
+
+    /// Operation operand accessors.
     mlir::Operation::operand_range getIterOperands() {
       return getOperands()
-          .drop_front(getNumControlOperands() + getNumReduceOperands());
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumIterOperands());
     }
     llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
       return getOperation()->getOpOperands()
-          .drop_front(getNumControlOperands() + getNumReduceOperands());
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumIterOperands());
     }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2219,7 +2236,7 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
 
     /// Number of region arguments for loop-carried values
     unsigned getNumRegionIterArgs() {
-      return getBody()->getNumArguments() - 1;
+      return getNumIterOperands();
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
@@ -2258,6 +2275,10 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
                            unsigned resultNum);
     mlir::Value blockArgToSourceOp(unsigned blockArgNum);
   }];
+
+  let extraClassDeclaration =
+    !strconcat(opExtraClassDeclaration, "\n",
+               OpenMP_PrivateClause.extraClassDeclaration);
 }
 
 def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 7e50622db08c9..c729414cd2393 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2478,14 +2478,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
                           bool finalCountValue, mlir::ValueRange iterArgs,
                           mlir::ValueRange reduceOperands,
                           llvm::ArrayRef<mlir::Attribute> reduceAttrs,
-                          llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+                          llvm::ArrayRef<mlir::NamedAttribute> attributes,
+                          mlir::ValueRange privateVars,
+                          mlir::ArrayRef<mlir::Attribute> privateSyms) {
   result.addOperands({lb, ub, step});
   result.addOperands(reduceOperands);
   result.addOperands(iterArgs);
   result.addAttribute(getOperandSegmentSizeAttr(),
                       builder.getDenseI32ArrayAttr(
                           {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
-                           static_cast<int32_t>(iterArgs.size())}));
+                           static_cast<int32_t>(iterArgs.size()), 0}));
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
     result.addAttribute(getFinalValueAttrName(result.name),
@@ -2561,8 +2563,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
 
   // Parse the optional initial iteration arguments.
   llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
-  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   llvm::SmallVector<mlir::Type> argTypes;
+
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   bool prependCount = false;
   regionArgs.push_back(inductionVariable);
 
@@ -2587,15 +2590,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     prependCount = true;
   }
 
-  // Set the operandSegmentSizes attribute
-  result.addAttribute(getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(
-                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
-                           static_cast<int32_t>(iterOperands.size())}));
-
-  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-    return mlir::failure();
-
   // Induction variable.
   if (prependCount)
     result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
@@ -2604,15 +2598,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     argTypes.push_back(indexType);
   // Loop carried variables
   argTypes.append(result.types.begin(), result.types.end());
-  // Parse the body region.
-  auto *body = result.addRegion();
+
   if (regionArgs.size() != argTypes.size())
     return parser.emitError(
         parser.getNameLoc(),
         "mismatch in number of loop-carried values and defined values");
+
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
+  if (succeeded(parser.parseOptionalKeyword("private"))) {
+    std::size_t oldArgTypesSize = argTypes.size();
+    if (failed(parser.parseLParen()))
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
+            return mlir::failure();
+
+          if (parser.parseOperand(privateOperands.emplace_back()) ||
+              parser.parseArrow() ||
+              parser.parseArgument(regionArgs.emplace_back()))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (failed(parser.parseColon()))
+      return mlir::failure();
+
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseType(argTypes.emplace_back())))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (regionArgs.size() != argTypes.size())
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of private arg and types");
+
+    if (failed(parser.parseRParen()))
+      return mlir::failure();
+
+    for (auto operandType : llvm::zip_equal(
+             privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
+        return mlir::failure();
+
+    llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
+                                                   privateSymbolVec.end());
+    result.addAttribute(getPrivateSymsAttrName(result.name),
+                        builder.getArrayAttr(symbolAttrs));
+  }
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return mlir::failure();
+
+  // Set the operandSegmentSizes attribute
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterOperands.size()),
+                           static_cast<int32_t>(privateOperands.size())}));
+
   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
     regionArgs[i].type = argTypes[i];
 
+  // Parse the body region.
+  auto *body = result.addRegion();
   if (parser.parseRegion(*body, regionArgs))
     return mlir::failure();
 
@@ -2706,9 +2762,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " -> " << getResultTypes();
     printBlockTerminators = true;
   }
-  p.printOptionalAttrDictWithKeyword(
-      (*this)->getAttrs(),
-      {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
+
+  if (numPrivateBlockArgs() > 0) {
+    p << " private(";
+    llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
+                                          getPrivateVars(),
+                                          getRegionPrivateArgs()),
+                          p, [&](auto it) {
+                            p << std::get<0>(it) << " " << std::get<1>(it)
+                              << " -> " << std::get<2>(it);
+                          });
+    p << " : ";
+    llvm::interleaveComma(getPrivateVars(), p,
+                          [&](auto it) { p << it.getType(); });
+    p << ")";
+  }
+
+  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
+                                     {"unordered", "finalValue", "reduceAttrs",
+                                      "operandSegmentSizes", "private_syms"});
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 printBlockTerminators);
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index b09bbf6106dbb..88779e6ebd977 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -32,6 +32,19 @@ using namespace fir;
 using namespace mlir;
 
 namespace {
+/// Looks up from the operation from and returns the PrivateClauseOp with
+/// name symbolName
+///
+/// TODO Copied from OpenMPToLLVMIRTranslation.cpp, move to a shared location.
+/// Maybe a static function on the `PrivateClauseOp`.
+static omp::PrivateClauseOp findPrivatizer(Operation *from,
+                                           SymbolRefAttr symbolName) {
+  omp::PrivateClauseOp privatizer =
+      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
+                                                                 symbolName);
+  assert(privatizer && "privatizer not found in the symbol table");
+  return privatizer;
+}
 
 // Conversion of fir control ops to more primitive control-flow.
 //
@@ -57,6 +70,50 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
     auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
         rewriter.getContext(), flags);
 
+    // Handle privatization
+    if (!loop.getPrivateVars().empty()) {
+      mlir::OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&loop.getRegion().front());
+      std::optional<ArrayAttr> privateSyms = loop.getPrivateSyms();
+
+      for (auto [privateVar, privateArg, privatizerSym] :
+           llvm::zip_equal(loop.getPrivateVars(), loop.getRegionPrivateArgs(),
+                           *privateSyms)) {
+        SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privatizerSym);
+        omp::PrivateClauseOp privatizer = findPrivatizer(loop, privatizerName);
+
+        mlir::Value localAlloc =
+            rewriter.create<fir::AllocaOp>(loop.getLoc(), privatizer.getType());
+
+        if (privatizer.getDataSharingType() ==
+            omp::DataSharingClauseType::FirstPrivate) {
+          mlir::Block *beforeLocalInit = rewriter.getInsertionBlock();
+          mlir::Block *afterLocalInit = rewriter.splitBlock(
+              rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+          rewriter.cloneRegionBefore(privatizer.getCopyRegion(),
+                                     afterLocalInit);
+          mlir::Block* copyRegionFront = beforeLocalInit->getNextNode();
+          mlir::Block* copyRegionBack = afterLocalInit->getPrevNode();
+
+          rewriter.setInsertionPoint(beforeLocalInit, beforeLocalInit->end());
+          rewriter.create<mlir::cf::BranchOp>(
+              loc, copyRegionFront,
+              llvm::SmallVector<mlir::Value>{privateVar, privateArg});
+
+          rewriter.eraseOp(copyRegionBack->getTerminator());
+          rewriter.setInsertionPoint(copyRegionBack, copyRegionBack->end());
+          rewriter.create<mlir::cf::BranchOp>(loc, afterLocalInit);
+        }
+
+        rewriter.replaceAllUsesWith(privateArg, localAlloc);
+      }
+
+      loop.getRegion().front().eraseArguments(1 + loop.getNumRegionIterArgs(),
+                                              loop.numPrivateBlockArgs());
+      loop.getPrivateVarsMutable().clear();
+      loop.setPrivateSymsAttr(nullptr);
+    }
+
     // Create the start and end blocks that will wrap the DoLoopOp with an
     // initalizer and an end point
     auto *initBlock = rewriter.getInsertionBlock();

``````````

</details>


https://github.com/llvm/llvm-project/pull/128148

