[flang] [mlir] [llvm] [flang][OpenMP][RFC] Add support for COPYPRIVATE (PR #73128)

Leandro Lupori via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 09:02:04 PST 2023


https://github.com/luporl updated https://github.com/llvm/llvm-project/pull/73128

>From 59685602bab72edb3946542e925454714b14ee73 Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Wed, 20 Dec 2023 11:05:47 -0300
Subject: [PATCH] [flang][mlir][OpenMP] Add support for COPYPRIVATE

Add initial handling of OpenMP COPYPRIVATE clause in Flang.

MLIR's omp.single operation was modified to support an optional
CopyPrivateVarList, which consists of (variable, copy function)
pairs. When present, each thread's copy of a listed variable is
updated with the value from the thread that executed the single
region, using the corresponding function to perform the copy.

When lowering COPYPRIVATE, Flang generates the copy function
needed by each variable and builds the corresponding
CopyPrivateVarList. Translation to LLVM IR is done in OMPIRBuilder,
which calls createCopyPrivate() for each variable in the list;
this emits calls to __kmpc_copyprivate.

Fixes https://github.com/llvm/llvm-project/issues/63933
---
 flang/include/flang/Lower/AbstractConverter.h |   3 +
 flang/lib/Lower/Bridge.cpp                    | 137 ++++++-----
 flang/lib/Lower/OpenMP.cpp                    | 212 +++++++++++++++++-
 flang/lib/Semantics/resolve-directives.cpp    |   3 +-
 flang/test/Lower/OpenMP/Todo/copyprivate.f90  |  13 --
 flang/test/Lower/OpenMP/copyprivate.f90       |  48 ++++
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   6 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  24 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  10 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 105 ++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  20 +-
 11 files changed, 498 insertions(+), 83 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/copyprivate.f90
 create mode 100644 flang/test/Lower/OpenMP/copyprivate.f90

diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index b91303387f3d71..2d289ec6a984f9 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -118,6 +118,9 @@ class AbstractConverter {
       const Fortran::semantics::Symbol &sym,
       mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) = 0;
 
+  virtual void copyVar(mlir::Location loc, mlir::Value dst,
+                       mlir::Value src) = 0;
+
   /// For a given symbol, check if it is present in the inner-most
   /// level of the symbol map.
   virtual bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) = 0;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 6ca910d2696742..0b9d25c8025dc8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -740,6 +740,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         });
   }
 
+  void copyVar(mlir::Location loc, mlir::Value dst,
+               mlir::Value src) override final {
+    copyVarHLFIR(loc, dst, src);
+  }
+
   void copyHostAssociateVar(
       const Fortran::semantics::Symbol &sym,
       mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) override final {
@@ -774,64 +779,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       rhs_sb = &hsb;
     }
 
-    mlir::Location loc = genLocation(sym.name());
-
-    if (lowerToHighLevelFIR()) {
-      hlfir::Entity lhs{lhs_sb->getAddr()};
-      hlfir::Entity rhs{rhs_sb->getAddr()};
-      // Temporary_lhs is set to true in hlfir.assign below to avoid user
-      // assignment to be used and finalization to be called on the LHS.
-      // This may or may not be correct but mimics the current behaviour
-      // without HLFIR.
-      auto copyData = [&](hlfir::Entity l, hlfir::Entity r) {
-        // Dereference RHS and load it if trivial scalar.
-        r = hlfir::loadTrivialScalar(loc, *builder, r);
-        builder->create<hlfir::AssignOp>(
-            loc, r, l,
-            /*isWholeAllocatableAssignment=*/false,
-            /*keepLhsLengthInAllocatableAssignment=*/false,
-            /*temporary_lhs=*/true);
-      };
-      if (lhs.isAllocatable()) {
-        // Deep copy allocatable if it is allocated.
-        // Note that when allocated, the RHS is already allocated with the LHS
-        // shape for copy on entry in createHostAssociateVarClone.
-        // For lastprivate, this assumes that the RHS was not reallocated in
-        // the OpenMP region.
-        lhs = hlfir::derefPointersAndAllocatables(loc, *builder, lhs);
-        mlir::Value addr = hlfir::genVariableRawAddress(loc, *builder, lhs);
-        mlir::Value isAllocated = builder->genIsNotNullAddr(loc, addr);
-        builder->genIfThen(loc, isAllocated)
-            .genThen([&]() {
-              // Copy the DATA, not the descriptors.
-              copyData(lhs, rhs);
-            })
-            .end();
-      } else if (lhs.isPointer()) {
-        // Set LHS target to the target of RHS (do not copy the RHS
-        // target data into the LHS target storage).
-        auto loadVal = builder->create<fir::LoadOp>(loc, rhs);
-        builder->create<fir::StoreOp>(loc, loadVal, lhs);
-      } else {
-        // Non ALLOCATABLE/POINTER variable. Simple DATA copy.
-        copyData(lhs, rhs);
-      }
-    } else {
-      fir::ExtendedValue lhs = symBoxToExtendedValue(*lhs_sb);
-      fir::ExtendedValue rhs = symBoxToExtendedValue(*rhs_sb);
-      mlir::Type symType = genType(sym);
-      if (auto seqTy = symType.dyn_cast<fir::SequenceType>()) {
-        Fortran::lower::StatementContext stmtCtx;
-        Fortran::lower::createSomeArrayAssignment(*this, lhs, rhs, localSymbols,
-                                                  stmtCtx);
-        stmtCtx.finalizeAndReset();
-      } else if (lhs.getBoxOf<fir::CharBoxValue>()) {
-        fir::factory::CharacterExprHelper{*builder, loc}.createAssign(lhs, rhs);
-      } else {
-        auto loadVal = builder->create<fir::LoadOp>(loc, fir::getBase(rhs));
-        builder->create<fir::StoreOp>(loc, loadVal, fir::getBase(lhs));
-      }
-    }
+    copyVar(sym, *lhs_sb, *rhs_sb);
 
     if (copyAssignIP && copyAssignIP->isSet() &&
         sym.test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) {
@@ -1089,6 +1037,79 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return true;
   }
 
+  void copyVar(const Fortran::semantics::Symbol &sym,
+               const Fortran::lower::SymbolBox &lhs_sb,
+               const Fortran::lower::SymbolBox &rhs_sb) {
+    mlir::Location loc = genLocation(sym.name());
+    if (lowerToHighLevelFIR())
+      copyVarHLFIR(loc, lhs_sb.getAddr(), rhs_sb.getAddr());
+    else
+      copyVarFIR(loc, sym, lhs_sb, rhs_sb);
+  }
+
+  void copyVarHLFIR(mlir::Location loc, mlir::Value dst, mlir::Value src) {
+    assert(lowerToHighLevelFIR());
+    hlfir::Entity lhs{dst};
+    hlfir::Entity rhs{src};
+    // Temporary_lhs is set to true in hlfir.assign below to avoid user
+    // assignment to be used and finalization to be called on the LHS.
+    // This may or may not be correct but mimics the current behaviour
+    // without HLFIR.
+    auto copyData = [&](hlfir::Entity l, hlfir::Entity r) {
+      // Dereference RHS and load it if trivial scalar.
+      r = hlfir::loadTrivialScalar(loc, *builder, r);
+      builder->create<hlfir::AssignOp>(
+          loc, r, l,
+          /*isWholeAllocatableAssignment=*/false,
+          /*keepLhsLengthInAllocatableAssignment=*/false,
+          /*temporary_lhs=*/true);
+    };
+    if (lhs.isAllocatable()) {
+      // Deep copy allocatable if it is allocated.
+      // Note that when allocated, the RHS is already allocated with the LHS
+      // shape for copy on entry in createHostAssociateVarClone.
+      // For lastprivate, this assumes that the RHS was not reallocated in
+      // the OpenMP region.
+      lhs = hlfir::derefPointersAndAllocatables(loc, *builder, lhs);
+      mlir::Value addr = hlfir::genVariableRawAddress(loc, *builder, lhs);
+      mlir::Value isAllocated = builder->genIsNotNullAddr(loc, addr);
+      builder->genIfThen(loc, isAllocated)
+          .genThen([&]() {
+            // Copy the DATA, not the descriptors.
+            copyData(lhs, rhs);
+          })
+          .end();
+    } else if (lhs.isPointer()) {
+      // Set LHS target to the target of RHS (do not copy the RHS
+      // target data into the LHS target storage).
+      auto loadVal = builder->create<fir::LoadOp>(loc, rhs);
+      builder->create<fir::StoreOp>(loc, loadVal, lhs);
+    } else {
+      // Non ALLOCATABLE/POINTER variable. Simple DATA copy.
+      copyData(lhs, rhs);
+    }
+  }
+
+  void copyVarFIR(mlir::Location loc, const Fortran::semantics::Symbol &sym,
+                  const Fortran::lower::SymbolBox &lhs_sb,
+                  const Fortran::lower::SymbolBox &rhs_sb) {
+    assert(!lowerToHighLevelFIR());
+    fir::ExtendedValue lhs = symBoxToExtendedValue(lhs_sb);
+    fir::ExtendedValue rhs = symBoxToExtendedValue(rhs_sb);
+    mlir::Type symType = genType(sym);
+    if (auto seqTy = symType.dyn_cast<fir::SequenceType>()) {
+      Fortran::lower::StatementContext stmtCtx;
+      Fortran::lower::createSomeArrayAssignment(*this, lhs, rhs, localSymbols,
+                                                stmtCtx);
+      stmtCtx.finalizeAndReset();
+    } else if (lhs.getBoxOf<fir::CharBoxValue>()) {
+      fir::factory::CharacterExprHelper{*builder, loc}.createAssign(lhs, rhs);
+    } else {
+      auto loadVal = builder->create<fir::LoadOp>(loc, fir::getBase(rhs));
+      builder->create<fir::StoreOp>(loc, loadVal, fir::getBase(lhs));
+    }
+  }
+
   /// Map a block argument to a result or dummy symbol. This is not the
   /// definitive mapping. The specification expression have not been lowered
   /// yet. The final mapping will be done using this pre-mapping in
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 764d2175c0a962..87a2eae7189ed9 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -562,6 +562,10 @@ class ClauseProcessor {
   processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
                   llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
   bool processCopyin() const;
+  bool processCopyPrivate(
+      mlir::Location currentLocation,
+      llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
+      llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
   bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
   bool
@@ -880,6 +884,156 @@ createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
   return decl;
 }
 
+/// Scans a FIR type, recording an encoded name, char length, shape and boxing.
+class TypeInfo {
+public:
+  TypeInfo(mlir::Location loc, mlir::Type ty) : loc(loc) {
+    name = typeScan(ty);
+  }
+
+  // Returns a textual encoding of the type (e.g. "i32", "A1_10_i32") meant to
+  // contain only characters that are valid in identifiers.
+  const std::string &getName() const { return name; }
+
+  // Character length, if a character type was scanned (may be unknownLen()).
+  std::optional<fir::CharacterType::LenType> getCharLength() const {
+    return charLen;
+  }
+
+  // Array extents, if an array type was scanned (unknown extents kept as-is).
+  const llvm::SmallVector<int64_t> &getShape() const { return shape; }
+
+  // True if an allocatable/pointer fir.box was seen while scanning.
+  bool isBox() const { return inBox; }
+
+private:
+  // Scans the type, filling the members below, and returns a unique name.
+  std::string typeScan(mlir::Type type);
+
+  mlir::Location loc;
+  std::string name;
+  std::optional<fir::CharacterType::LenType> charLen;
+  llvm::SmallVector<int64_t> shape;
+  bool inBox = false;
+};
+
+std::string TypeInfo::typeScan(mlir::Type ty) {
+  std::ostringstream ss;
+
+  auto unexpectedType = [&] {
+    std::string errmsg;
+    llvm::raw_string_ostream rss(errmsg);
+    rss << "Unexpected type: " << ty;
+    fir::emitFatalError(loc, errmsg);
+  };
+
+  if (auto aty = mlir::dyn_cast<fir::SequenceType>(ty)) {
+    // array -> A<rank>(_<extent>)+_<eleTy>
+    assert(shape.empty() && !aty.getShape().empty());
+    shape = llvm::SmallVector<int64_t>(aty.getShape());
+    ss << "A" << aty.getShape().size();
+    for (auto extent : aty.getShape()) {
+      assert((extent >= 0 || extent == aty.getUnknownExtent()) &&
+             "Unexpected array extent");
+      if (extent == aty.getUnknownExtent())
+        ss << "_u";
+      else
+        ss << "_" << extent;
+    }
+    ss << "_" << typeScan(aty.getEleTy());
+  } else if (auto dty = mlir::dyn_cast<fir::RecordType>(ty)) {
+    ss << "D" << dty.getName().str();
+  } else if (auto bty = mlir::dyn_cast<fir::BoxType>(ty)) {
+    inBox = true;
+    // allocatable (box<heap<...>>)
+    if (auto hty = mlir::dyn_cast<fir::HeapType>(bty.getEleTy()))
+      ss << "H" << typeScan(hty.getEleTy());
+    // pointer (box<ptr<...>>)
+    else if (auto pty = mlir::dyn_cast<fir::PointerType>(bty.getEleTy()))
+      ss << "P" << typeScan(pty.getEleTy());
+    else
+      unexpectedType();
+  } else if (auto sty = mlir::dyn_cast<fir::CharacterType>(ty)) {
+    // character -> s<kind>l<len>
+    fir::CharacterType::LenType len = sty.getLen();
+    assert((len >= 0 || len == fir::CharacterType::unknownLen()) &&
+           "Unexpected character length");
+    charLen = len;
+    ss << "s" << sty.getFKind() << "l";
+    if (len == fir::CharacterType::unknownLen())
+      ss << "u";
+    else
+      ss << len;
+  } else if (auto cty = mlir::dyn_cast<fir::ComplexType>(ty)) {
+    ss << "c" << cty.getFKind();
+  } else if (auto lty = mlir::dyn_cast<fir::LogicalType>(ty)) {
+    ss << "l" << lty.getFKind();
+  } else if (ty.isIntOrIndexOrFloat()) {
+    if (ty.isIntOrIndex())
+      ss << "i";
+    else
+      ss << "f";
+    ss << ty.getIntOrFloatBitWidth();
+  } else {
+    unexpectedType();
+  }
+  return ss.str();
+}
+
+// Create, or reuse if one with the same type name already exists, a private
+// function that copies arg1 (src) into arg0 (dst), both of type varType.
+static mlir::func::FuncOp
+createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
+               mlir::Type varType, fir::FortranVariableFlagsEnum varAttrs) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::ModuleOp module = builder.getModule();
+  // Strip the reference; TypeInfo emits a fatal error for unsupported types.
+  TypeInfo typeInfo(loc, fir::unwrapRefType(varType));
+  std::string copyFuncName = std::string("_copy_") + typeInfo.getName();
+
+  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+    return decl;
+
+  // create function
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+  mlir::func::FuncOp funcOp =
+      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+                      {loc, loc});
+  builder.setInsertionPointToStart(&funcOp.getRegion().back());
+  // generate body
+  fir::FortranVariableFlagsAttr attrs;
+  if (varAttrs != fir::FortranVariableFlagsEnum::None)
+    attrs = fir::FortranVariableFlagsAttr::get(builder.getContext(), varAttrs);
+  llvm::SmallVector<mlir::Value> typeparams;
+  // A deferred (unknown) length must not be materialized as a constant.
+  if (auto len = typeInfo.getCharLength();
+      len && *len != fir::CharacterType::unknownLen())
+    typeparams.push_back(builder.createIntegerConstant(
+        loc, builder.getCharacterLengthType(), *len));
+  mlir::Value shape;
+  if (!typeInfo.isBox() && !typeInfo.getShape().empty()) {
+    llvm::SmallVector<mlir::Value> extents;
+    for (auto extent : typeInfo.getShape())
+      extents.push_back(
+          builder.createIntegerConstant(loc, builder.getIndexType(), extent));
+    shape = builder.create<fir::ShapeOp>(loc, extents);
+  }
+  auto declDst = builder.create<hlfir::DeclareOp>(loc, funcOp.getArgument(0),
+                                                  copyFuncName + "_dst", shape,
+                                                  typeparams, attrs);
+  auto declSrc = builder.create<hlfir::DeclareOp>(loc, funcOp.getArgument(1),
+                                                  copyFuncName + "_src", shape,
+                                                  typeparams, attrs);
+  converter.copyVar(loc, declDst.getBase(), declSrc.getBase());
+  builder.create<mlir::func::ReturnOp>(loc);
+  return funcOp;
+}
+
 /// Creates an OpenMP reduction declaration and inserts it into the provided
 /// symbol table. The declaration has a constant initializer with the neutral
 /// value `initValue`, and the reduction combiner carried over from `reduce`.
@@ -1634,6 +1788,46 @@ bool ClauseProcessor::processCopyin() const {
   return hasCopyin;
 }
 
+bool ClauseProcessor::processCopyPrivate(
+    mlir::Location currentLocation,
+    llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
+    llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
+  auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
+    mlir::Value symVal = converter.getSymbolAddress(*sym);
+    auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
+    if (!declOp)
+      fir::emitFatalError(currentLocation,
+                          "COPYPRIVATE is supported only in HLFIR mode");
+    symVal = declOp.getBase();
+    fir::FortranVariableFlagsEnum attrs = fir::FortranVariableFlagsEnum::None;
+    if (declOp.getFortranAttrs().has_value())
+      attrs = *declOp.getFortranAttrs();
+    copyPrivateVars.push_back(symVal);
+    mlir::func::FuncOp funcOp =
+        createCopyFunc(currentLocation, converter, symVal.getType(), attrs);
+    copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
+  };
+
+  bool hasCopyPrivate = findRepeatableClause<ClauseTy::Copyprivate>(
+      [&](const ClauseTy::Copyprivate *copyPrivateClause,
+          const Fortran::parser::CharBlock &) {
+        const Fortran::parser::OmpObjectList &ompObjectList =
+            copyPrivateClause->v;
+        for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
+          Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+          if (const auto *commonDetails =
+                  sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
+            for (const auto &mem : commonDetails->objects())
+              addCopyPrivateVar(&*mem);
+            continue; // Common block expanded; go on to the next list item.
+          }
+          addCopyPrivateVar(sym);
+        }
+      });
+
+  return hasCopyPrivate;
+}
+
 bool ClauseProcessor::processDepend(
     llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
     llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
@@ -2311,18 +2505,25 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
             const Fortran::parser::OmpClauseList &beginClauseList,
             const Fortran::parser::OmpClauseList &endClauseList) {
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
+  llvm::SmallVector<mlir::Value> copyPrivateVars;
+  llvm::SmallVector<mlir::Attribute> copyPrivateFuncs;
   mlir::UnitAttr nowaitAttr;
 
   ClauseProcessor cp(converter, beginClauseList);
   cp.processAllocate(allocatorOperands, allocateOperands);
-  cp.processTODO<Fortran::parser::OmpClause::Copyprivate>(
-      currentLocation, llvm::omp::Directive::OMPD_single);
 
-  ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
+  ClauseProcessor ecp(converter, endClauseList);
+  ecp.processNowait(nowaitAttr);
+  ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs);
 
   return genOpWithBody<mlir::omp::SingleOp>(
       converter, eval, currentLocation, /*outerCombined=*/false,
-      &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr);
+      &beginClauseList, allocateOperands, allocatorOperands, copyPrivateVars,
+      copyPrivateFuncs.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 copyPrivateFuncs),
+      nowaitAttr);
 }
 
 static mlir::omp::TaskOp
@@ -3105,7 +3306,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
 
   for (const auto &clause : endClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
-    if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
+    if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::Copyprivate>(&clause.u))
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index da6c865ad56a3b..e59cdd7b439b4a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -2382,7 +2382,8 @@ void OmpAttributeVisitor::CheckDataCopyingClause(
       // either 'private' or 'threadprivate' in enclosing context.
       if (!checkSymbol->test(Symbol::Flag::OmpThreadprivate) &&
           !(HasSymbolInEnclosingScope(symbol, currScope()) &&
-              symbol.test(Symbol::Flag::OmpPrivate))) {
+              (symbol.test(Symbol::Flag::OmpPrivate) ||
+                  symbol.test(Symbol::Flag::OmpFirstPrivate)))) {
         context_.Say(name.source,
             "COPYPRIVATE variable '%s' is not PRIVATE or THREADPRIVATE in "
             "outer context"_err_en_US,
diff --git a/flang/test/Lower/OpenMP/Todo/copyprivate.f90 b/flang/test/Lower/OpenMP/Todo/copyprivate.f90
deleted file mode 100644
index 0d871427ce60ff..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/copyprivate.f90
+++ /dev/null
@@ -1,13 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: OpenMP Block construct clause
-subroutine sb
-  integer, save :: a
-  !$omp threadprivate(a)
-  !$omp parallel
-  !$omp single
-  a = 3
-  !$omp end single copyprivate(a)
-  !$omp end parallel
-end subroutine
diff --git a/flang/test/Lower/OpenMP/copyprivate.f90 b/flang/test/Lower/OpenMP/copyprivate.f90
new file mode 100644
index 00000000000000..0d2740c77896f0
--- /dev/null
+++ b/flang/test/Lower/OpenMP/copyprivate.f90
@@ -0,0 +1,48 @@
+! Test COPYPRIVATE.
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-DAG: func private @_copy_f32(%{{.*}}: !fir.ref<f32>, %{{.*}}: !fir.ref<f32>)
+!CHECK-DAG: func private @_copy_A1_10_i32(%{{.*}}: !fir.ref<!fir.array<10xi32>>, %{{.*}}: !fir.ref<!fir.array<10xi32>>)
+
+!CHECK-LABEL: func private @_copy_i32(
+!CHECK-SAME:                  %[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>) {
+!CHECK-NEXT:    %[[DST:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_copy_i32_dst"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK-NEXT:    %[[SRC:.*]]:2 = hlfir.declare %[[ARG1]] {uniq_name = "_copy_i32_src"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK-NEXT:    %[[SRC_VAL:.*]] = fir.load %[[SRC]]#0 : !fir.ref<i32>
+!CHECK-NEXT:    hlfir.assign %[[SRC_VAL]] to %[[DST]]#0 temporary_lhs : i32, !fir.ref<i32>
+!CHECK-NEXT:    return
+!CHECK-NEXT:  }
+
+!CHECK-LABEL: func @_QPtest_scalar
+!CHECK:         omp.parallel
+!CHECK:           %[[I:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_scalarEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:           %[[J:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_scalarEj"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:           %[[K:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_scalarEk"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+!CHECK:           omp.single copyprivate(%[[I]]#0 -> @_copy_i32 : !fir.ref<i32>, %[[J]]#0 -> @_copy_i32 : !fir.ref<i32>, %[[K]]#0 -> @_copy_f32 : !fir.ref<f32>)
+subroutine test_scalar()
+  integer, save :: i, j
+  !$omp threadprivate(i, j)
+  real :: k
+
+  k = 33.3
+  !$omp parallel firstprivate(k)
+  !$omp single
+  i = 11
+  j = 22
+  !$omp end single copyprivate(i, j, k)
+  !$omp end parallel
+end subroutine
+
+!CHECK-LABEL: func @_QPtest_array
+!CHECK:         omp.parallel
+!CHECK:           %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFtest_arrayEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+!CHECK:           omp.single copyprivate(%[[A]]#0 -> @_copy_A1_10_i32 : !fir.ref<!fir.array<10xi32>>)
+subroutine test_array()
+  integer :: a(10)
+
+  !$omp parallel private(a)
+  !$omp single
+  a = 100
+  !$omp end single copyprivate(a)
+  !$omp end parallel
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index abbef03d02cb10..09a4e7f130e609 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1819,12 +1819,16 @@ class OpenMPIRBuilder {
   /// \param FiniCB Callback to finalize variable copies.
   /// \param IsNowait If false, a barrier is emitted.
   /// \param DidIt Local variable used as a flag to indicate 'single' thread
+  /// \param CPVars copyprivate variables (only used when \p DidIt is set).
+  /// \param CPFuncs copy functions, one per entry (same index) in \p CPVars.
   ///
   /// \returns The insertion position *after* the single call.
   InsertPointTy createSingle(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, bool IsNowait,
-                             llvm::Value *DidIt);
+                             llvm::Value *DidIt,
+                             const SmallVector<llvm::Value *> &CPVars = {},
+                             const SmallVector<llvm::Function *> &CPFuncs = {});
 
   /// Generator for '#omp master'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ce428f78dc843e..91036928cfeedc 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3992,7 +3992,9 @@ OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
-    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
+    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt,
+    const SmallVector<llvm::Value *> &CPVars,
+    const SmallVector<llvm::Function *> &CPFuncs) {
 
   if (!updateToLocation(Loc))
     return Loc.IP;
@@ -4015,17 +4017,33 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
 
+  auto FiniCBWrapper = [&](InsertPointTy IP) {
+    FiniCB(IP);
+
+    if (DidIt)
+      Builder.CreateStore(Builder.getInt32(1), DidIt);
+  };
+
   // generates the following:
   // if (__kmpc_single()) {
   //		.... single region ...
   // 		__kmpc_end_single
   // }
+  // __kmpc_copyprivate
   // __kmpc_barrier
 
-  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
+  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
                        /*Conditional*/ true,
                        /*hasFinalize*/ true);
-  if (!IsNowait)
+
+  if (DidIt && !CPVars.empty()) {
+    for (auto [CPVar, CPFunc] : llvm::zip_equal(CPVars, CPFuncs))
+      // NOTE BufSize is currently unused, so just pass 0.
+      createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
+                        /*BufSize=*/ConstantInt::get(Int64, 0), CPVar,
+                        CPFunc, DidIt);
+    // NOTE __kmpc_copyprivate already inserts a barrier
+  } else if (!IsNowait)
     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                   /* CheckCancelFlag */ false);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index b9989b335a2aef..1d249a3d7c00b9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -386,10 +386,16 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
     master thread), in the context of its implicit task. The other threads
     in the team, which do not execute the block, wait at an implicit barrier
     at the end of the single construct unless a nowait clause is specified.
+
+    If copyprivate variables and functions are specified, then each thread's
+    copy of a listed variable is updated with the value from the thread that
+    executed the single region, using the corresponding copy function.
   }];
 
   let arguments = (ins Variadic<AnyType>:$allocate_vars,
                        Variadic<AnyType>:$allocators_vars,
+                       Variadic<OpenMP_PointerLikeType>:$copyprivate_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$copyprivate_funcs,
                        UnitAttr:$nowait);
 
   let regions = (region AnyRegion:$region);
@@ -401,6 +407,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
                 $allocators_vars, type($allocators_vars)
               ) `)`
           |`nowait` $nowait
+          |`copyprivate` `(`
+              custom<CopyPrivateVarList>(
+                $copyprivate_vars, type($copyprivate_vars), $copyprivate_funcs
+              ) `)`
     ) $region attr-dict
   }];
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e69cd0d386bd2..015620eb0a623b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -512,6 +512,108 @@ static LogicalResult verifyReductionVarList(Operation *op,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for CopyPrivateVarList
+//===----------------------------------------------------------------------===//
+
+/// Parses a comma-separated list of copyprivate entries, each of the form
+///   copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
+/// filling `operands`, `types` and the `copyPrivateSymbols` array attribute.
+static ParseResult parseCopyPrivateVarList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+    SmallVectorImpl<Type> &types, ArrayAttr &copyPrivateSymbols) {
+  SmallVector<Attribute> funcSymbols;
+  auto parseEntry = [&]() -> ParseResult {
+    SymbolRefAttr funcSymbol;
+    if (parser.parseOperand(operands.emplace_back()) ||
+        parser.parseArrow() || parser.parseAttribute(funcSymbol) ||
+        parser.parseColonType(types.emplace_back()))
+      return failure();
+    funcSymbols.push_back(funcSymbol);
+    return success();
+  };
+  if (failed(parser.parseCommaSeparatedList(parseEntry)))
+    return failure();
+  copyPrivateSymbols = ArrayAttr::get(parser.getContext(), funcSymbols);
+  return success();
+}
+
+/// Prints copyprivate entries as `%var -> @func : type`, comma-separated.
+static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op,
+                                    OperandRange copyPrivateVars,
+                                    TypeRange copyPrivateTypes,
+                                    std::optional<ArrayAttr> copyPrivateFuncs) {
+  assert(copyPrivateFuncs.has_value() || copyPrivateVars.empty());
+  const char *separator = "";
+  for (unsigned idx = 0, end = copyPrivateVars.size(); idx != end; ++idx) {
+    p << separator << copyPrivateVars[idx] << " -> "
+      << (*copyPrivateFuncs)[idx] << " : " << copyPrivateTypes[idx];
+    separator = ", ";
+  }
+}
+
+/// Verifies the copyprivate clause. Each variable must be paired with a
+/// symbol naming a func.func or llvm.func that takes exactly two arguments,
+/// both of the variable's type. Argument types are read from the function
+/// type rather than from the entry block, so external declarations (which
+/// have no body and hence no block arguments) are accepted too.
+static LogicalResult
+verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars,
+                         std::optional<ArrayAttr> copyPrivateFuncs) {
+  // Without copyprivate variables, no copy function list may be present.
+  if (copyPrivateVars.empty()) {
+    if (copyPrivateFuncs)
+      return op->emitOpError() << "unexpected copyPrivate functions";
+    return success();
+  }
+
+  // Otherwise there must be exactly one copy function per variable.
+  if (!copyPrivateFuncs || copyPrivateFuncs->size() != copyPrivateVars.size())
+    return op->emitOpError() << "expected as many copyPrivate functions as "
+                                "copyPrivate variables";
+
+  // Check every variable against the copy function it is paired with.
+  for (auto [var, funcAttr] : llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
+    auto symbolRef = llvm::cast<SymbolRefAttr>(funcAttr);
+
+    // Only func.func and llvm.func qualify as copy functions. Their argument
+    // types come from the function type, which is available even when the
+    // function is a declaration without an entry block.
+    Operation *symbol = SymbolTable::lookupNearestSymbolFrom(op, symbolRef);
+    ArrayRef<Type> argTypes;
+    if (auto mlirFuncOp = llvm::dyn_cast_or_null<mlir::func::FuncOp>(symbol))
+      argTypes = mlirFuncOp.getArgumentTypes();
+    else if (auto llvmFuncOp =
+                 llvm::dyn_cast_or_null<mlir::LLVM::LLVMFuncOp>(symbol))
+      argTypes = llvmFuncOp.getArgumentTypes();
+    else
+      return op->emitOpError() << "expected symbol reference " << symbolRef
+                               << " to point to a copy function";
+
+    // A copy function takes exactly two arguments...
+    if (argTypes.size() != 2)
+      return op->emitOpError()
+             << "expected copy function " << symbolRef << " to have 2 operands";
+
+    // ...which must share a single type...
+    Type argTy = argTypes[0];
+    if (argTy != argTypes[1])
+      return op->emitOpError() << "expected copy function " << symbolRef
+                               << " arguments to have the same type";
+
+    // ...that matches the type of the copyprivate variable.
+    Type varType = var.getType();
+    if (argTy != varType)
+      return op->emitOpError()
+             << "expected copy function arguments' type (" << argTy
+             << ") to be the same as copyprivate variable's type (" << varType
+             << ")";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for DependVarList
 //===----------------------------------------------------------------------===//
@@ -1079,7 +1181,8 @@ LogicalResult SingleOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  return success();
+  return verifyCopyPrivateVarList(*this, getCopyprivateVars(),
+                                  getCopyprivateFuncs());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..293a6f69616b79 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -656,8 +656,26 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
                         moduleTranslation, bodyGenStatus);
   };
   auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  // Handle copyprivate
+  Operation::operand_range cpVars = singleOp.getCopyprivateVars();
+  std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateFuncs();
+  llvm::SmallVector<llvm::Value *> llvmCPVars;
+  llvm::SmallVector<llvm::Function *> llvmCPFuncs;
+  for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
+    llvmCPVars.push_back(moduleTranslation.lookupValue(cpVars[i]));
+    auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
+        singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
+    llvmCPFuncs.push_back(
+        moduleTranslation.lookupFunction(llvmFuncOp.getName()));
+  }
+  llvm::Value *didIt = nullptr;
+  if (!llvmCPVars.empty())
+    didIt = builder.CreateAlloca(llvm::Type::getInt32Ty(builder.getContext()));
+
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
-      ompLoc, bodyCB, finiCB, singleOp.getNowait(), /*DidIt=*/nullptr));
+      ompLoc, bodyCB, finiCB, singleOp.getNowait(), didIt, llvmCPVars,
+      llvmCPFuncs));
   return bodyGenStatus;
 }
 



More information about the llvm-commits mailing list