[flang-commits] [flang] [flang][cuda] Move cuf.set_allocator_idx after derived-type init (PR #148936)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Jul 15 12:21:27 PDT 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/148936

Derived-type initialization overwrites the component descriptor. Place the `cuf.set_allocator_idx` operation after the initialization is performed.

>From 054a5a37a98acb09852018e4443b1d624e577f6b Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 15 Jul 2025 12:19:50 -0700
Subject: [PATCH] [flang][cuda] Move cuf.set_allocator_idx after derived-type
 init

---
 flang/lib/Lower/ConvertVariable.cpp          | 167 +++++++++++--------
 flang/test/Lower/CUDA/cuda-set-allocator.cuf |  14 +-
 2 files changed, 105 insertions(+), 76 deletions(-)

diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index ffe456de56630..bfd4782c764fc 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
       return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
                                                  indices);
 
-    if (!cuf::isCUDADeviceContext(builder.getRegion())) {
-      mlir::Value alloc = builder.create<cuf::AllocOp>(
-          loc, ty, nm, symNm, dataAttr, lenParams, indices);
-      if (const auto *details{
-              ultimateSymbol
-                  .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
-        const Fortran::semantics::DeclTypeSpec *type{details->type()};
-        const Fortran::semantics::DerivedTypeSpec *derived{
-            type ? type->AsDerived() : nullptr};
-        if (derived) {
-          Fortran::semantics::UltimateComponentIterator components{*derived};
-          auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
-
-          llvm::SmallVector<mlir::Value> coordinates;
-          for (const auto &sym : components) {
-            if (Fortran::semantics::IsDeviceAllocatable(sym)) {
-              unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
-              mlir::Type fieldTy;
-              std::vector<mlir::Value> coordinates;
-
-              if (fieldIdx != std::numeric_limits<unsigned>::max()) {
-                // Field found in the base record type.
-                auto fieldName = recTy.getTypeList()[fieldIdx].first;
-                fieldTy = recTy.getTypeList()[fieldIdx].second;
-                mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
-                    loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
-                    recTy,
-                    /*typeParams=*/mlir::ValueRange{});
-                coordinates.push_back(fieldIndex);
-              } else {
-                // Field not found in base record type, search in potential
-                // record type components.
-                for (auto component : recTy.getTypeList()) {
-                  if (auto childRecTy =
-                          mlir::dyn_cast<fir::RecordType>(component.second)) {
-                    fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
-                    if (fieldIdx != std::numeric_limits<unsigned>::max()) {
-                      mlir::Value parentFieldIndex =
-                          builder.create<fir::FieldIndexOp>(
-                              loc, fir::FieldType::get(childRecTy.getContext()),
-                              component.first, recTy,
-                              /*typeParams=*/mlir::ValueRange{});
-                      coordinates.push_back(parentFieldIndex);
-                      auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
-                      fieldTy = childRecTy.getTypeList()[fieldIdx].second;
-                      mlir::Value childFieldIndex =
-                          builder.create<fir::FieldIndexOp>(
-                              loc, fir::FieldType::get(fieldTy.getContext()),
-                              fieldName, childRecTy,
-                              /*typeParams=*/mlir::ValueRange{});
-                      coordinates.push_back(childFieldIndex);
-                      break;
-                    }
-                  }
-                }
-              }
-
-              if (coordinates.empty())
-                TODO(loc, "device resident component in complex derived-type "
-                          "hierarchy");
-
-              mlir::Value comp = builder.create<fir::CoordinateOp>(
-                  loc, builder.getRefType(fieldTy), alloc, coordinates);
-              cuf::DataAttributeAttr dataAttr =
-                  Fortran::lower::translateSymbolCUFDataAttribute(
-                      builder.getContext(), sym);
-              builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
-            }
-          }
-        }
-      }
-      return alloc;
-    }
+    if (!cuf::isCUDADeviceContext(builder.getRegion()))
+      return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
+                                          lenParams, indices);
   }
 
   // Let the builder do all the heavy lifting.
@@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
   return res;
 }
 
+/// Device allocatable components in a derived type don't have the correct
+/// allocator index in their descriptor when they are created. After
+/// initialization, cuf.set_allocator_idx operations are inserted to set the
+/// correct allocator index for each device component.
+static void
+initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &symbol,
+                                   Fortran::lower::SymMap &symMap) {
+  if (const auto *details{
+          symbol.GetUltimate()
+              .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
+    const Fortran::semantics::DeclTypeSpec *type{details->type()};
+    const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
+                                                            : nullptr};
+    if (derived) {
+      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+      mlir::Location loc = converter.getCurrentLocation();
+
+      fir::ExtendedValue exv =
+          converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap);
+      auto recTy = mlir::dyn_cast<fir::RecordType>(
+          fir::unwrapRefType(fir::getBase(exv).getType()));
+      assert(recTy && "expected fir::RecordType");
+
+      llvm::SmallVector<mlir::Value> coordinates;
+      Fortran::semantics::UltimateComponentIterator components{*derived};
+      for (const auto &sym : components) {
+        if (Fortran::semantics::IsDeviceAllocatable(sym)) {
+          unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
+          mlir::Type fieldTy;
+          std::vector<mlir::Value> coordinates;
+
+          if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+            // Field found in the base record type.
+            auto fieldName = recTy.getTypeList()[fieldIdx].first;
+            fieldTy = recTy.getTypeList()[fieldIdx].second;
+            mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
+                loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
+                recTy,
+                /*typeParams=*/mlir::ValueRange{});
+            coordinates.push_back(fieldIndex);
+          } else {
+            // Field not found in base record type, search in potential
+            // record type components.
+            for (auto component : recTy.getTypeList()) {
+              if (auto childRecTy =
+                      mlir::dyn_cast<fir::RecordType>(component.second)) {
+                fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
+                if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+                  mlir::Value parentFieldIndex =
+                      builder.create<fir::FieldIndexOp>(
+                          loc, fir::FieldType::get(childRecTy.getContext()),
+                          component.first, recTy,
+                          /*typeParams=*/mlir::ValueRange{});
+                  coordinates.push_back(parentFieldIndex);
+                  auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
+                  fieldTy = childRecTy.getTypeList()[fieldIdx].second;
+                  mlir::Value childFieldIndex =
+                      builder.create<fir::FieldIndexOp>(
+                          loc, fir::FieldType::get(fieldTy.getContext()),
+                          fieldName, childRecTy,
+                          /*typeParams=*/mlir::ValueRange{});
+                  coordinates.push_back(childFieldIndex);
+                  break;
+                }
+              }
+            }
+          }
+
+          if (coordinates.empty())
+            TODO(loc, "device resident component in complex derived-type "
+                      "hierarchy");
+
+          mlir::Value comp = builder.create<fir::CoordinateOp>(
+              loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates);
+          cuf::DataAttributeAttr dataAttr =
+              Fortran::lower::translateSymbolCUFDataAttribute(
+                  builder.getContext(), sym);
+          builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
+        }
+      }
+    }
+  }
+}
+
 /// Must \p var be default initialized at runtime when entering its scope.
 static bool
 mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
@@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
   if (mustBeDefaultInitializedAtRuntime(var))
     Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
                                                symMap);
+  if (converter.getFoldingContext().languageFeatures().IsEnabled(
+          Fortran::common::LanguageFeature::CUDA))
+    initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
   auto *builder = &converter.getFirOpBuilder();
   if (needCUDAAlloc(var.getSymbol()) &&
       !cuf::isCUDADeviceContext(builder->getRegion())) {
@@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
   if (mustBeDefaultInitializedAtRuntime(var))
     Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
                                                symMap);
+  if (converter.getFoldingContext().languageFeatures().IsEnabled(
+          Fortran::common::LanguageFeature::CUDA))
+    initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
 }
 
 //===--------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-set-allocator.cuf b/flang/test/Lower/CUDA/cuda-set-allocator.cuf
index bf74e012a639d..85715edf26942 100644
--- a/flang/test/Lower/CUDA/cuda-set-allocator.cuf
+++ b/flang/test/Lower/CUDA/cuda-set-allocator.cuf
@@ -12,10 +12,18 @@ contains
   end subroutine
 
 ! CHECK-LABEL: func.func @_QMm1Psub1()
-! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
-! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
+! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
+! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit)
+! CHECK: fir.copy 
+! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 ! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
-! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 ! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
 
 end module
+
+
+
+    
+  



More information about the flang-commits mailing list