[flang-commits] [flang] [flang][mlir] Add flang to mlir lowering for groupprivate (PR #180934)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 10 02:54:17 PDT 2026


================
@@ -783,6 +783,115 @@ static void threadPrivatizeVars(lower::AbstractConverter &converter,
   }
 }
 
+// Map the device_type that semantic analysis stored on a symbol
+// (Fortran::common::OmpDeviceType) onto the matching
+// mlir::omp::DeclareTargetDeviceType enumerator, which is the form the
+// omp.groupprivate operation expects for its device_type attribute.
+// The switch is meant to cover every OmpDeviceType enumerator; falling out
+// of it means the input held a value outside that enumeration.
+static mlir::omp::DeclareTargetDeviceType
+toMLIRDeclareTargetDeviceType(Fortran::common::OmpDeviceType deviceType) {
+  using SemaKind = Fortran::common::OmpDeviceType;
+  using MlirKind = mlir::omp::DeclareTargetDeviceType;
+  switch (deviceType) {
+  case SemaKind::Any:
+    return MlirKind::any;
+  case SemaKind::Host:
+    return MlirKind::host;
+  case SemaKind::Nohost:
+    return MlirKind::nohost;
+  }
+  llvm_unreachable("invalid OmpDeviceType");
+}
+
+static void groupprivatizeVars(lower::AbstractConverter &converter,
+                               lower::pft::Evaluation &eval) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+  firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+
+  auto module = converter.getModuleOp();
+
+  // Create a groupprivate operation for the symbol.
+  auto genGroupprivateOp = [&](const semantics::Symbol &sym) -> mlir::Value {
+    std::string globalName = converter.mangleName(sym);
+    fir::GlobalOp global = module.lookupSymbol<fir::GlobalOp>(globalName);
+    if (!global) {
+      return mlir::Value();
+    }
+
+    // The device_type modifier was recorded on the symbol during semantic
+    // analysis.
+    mlir::omp::DeclareTargetDeviceType deviceTypeEnum =
+        mlir::omp::DeclareTargetDeviceType::any;
+    Fortran::common::visit(
+        [&](auto &&details) {
+          using TypeD = llvm::remove_cvref_t<decltype(details)>;
+          if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
+                                          TypeD>) {
+            if (auto dt = details.ompGroupprivateDeviceType())
+              deviceTypeEnum = toMLIRDeclareTargetDeviceType(*dt);
+          }
+        },
+        sym.GetUltimate().details());
+    mlir::omp::DeclareTargetDeviceTypeAttr deviceTypeAttr =
+        mlir::omp::DeclareTargetDeviceTypeAttr::get(firOpBuilder.getContext(),
+                                                    deviceTypeEnum);
+
+    // omp.groupprivate takes a flat symbol reference and returns
+    // the address of the per-team copy of the global variable.
+    return mlir::omp::GroupprivateOp::create(
+        firOpBuilder, currentLocation, global.resultType(), global.getSymbol(),
+        deviceTypeAttr);
+  };
+
+  llvm::SetVector<const semantics::Symbol *> groupprivateSyms;
+  converter.collectSymbolSet(eval, groupprivateSyms,
+                             semantics::Symbol::Flag::OmpGroupPrivate,
+                             /*collectSymbols=*/true,
+                             /*collectHostAssociatedSymbols=*/true);
+  std::set<semantics::SourceName> groupprivateSymNames;
+
+  // For a COMMON block, the GroupprivateOp is generated for the block itself
+  // instead of its members.
+  llvm::SetVector<const semantics::Symbol *> commonSyms;
+
+  for (std::size_t i = 0; i < groupprivateSyms.size(); i++) {
+    const semantics::Symbol *sym = groupprivateSyms[i];
+    mlir::Value symGroupprivateValue;
+    // The variable may be used more than once, and each reference has one
+    // symbol with the same name. Only do once for references of one variable.
+    if (groupprivateSymNames.find(sym->name()) != groupprivateSymNames.end())
+      continue;
+    groupprivateSymNames.insert(sym->name());
+
+    if (const semantics::Symbol *common =
+            semantics::FindCommonBlockContaining(sym->GetUltimate())) {
+      // Handle common block members: create groupprivate op for the entire
+      // common block, then compute member offset.
+      mlir::Value commonGroupprivateValue;
+      if (commonSyms.contains(common)) {
+        commonGroupprivateValue = converter.getSymbolAddress(*common);
+      } else {
+        commonGroupprivateValue = genGroupprivateOp(*common);
+        if (!commonGroupprivateValue)
+          continue;
+        converter.bindSymbol(*common, commonGroupprivateValue);
+        commonSyms.insert(common);
+      }
+      symGroupprivateValue = lower::genCommonBlockMember(
+          converter, currentLocation, sym->GetUltimate(),
+          commonGroupprivateValue, common->size());
+    } else {
+      symGroupprivateValue = genGroupprivateOp(*sym);
+    }
+
+    if (!symGroupprivateValue) {
+      continue;
+    }
----------------
skc7 wrote:

Updated. Thanks

https://github.com/llvm/llvm-project/pull/180934


More information about the flang-commits mailing list