[flang-commits] [flang] [llvm] [flang][cuda] Add support for derived-type component with managed/unified attributes (PR #177409)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 22 09:47:26 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Derived-type components that have the `ALLOCATABLE` or `POINTER` attribute as well as the CUDA `MANAGED` or `UNIFIED` attribute need to have a specific allocator index set in the descriptor so the allocation is done correctly. Without this, the allocation is done in host memory and will trigger an illegal read or write if the component is used on the device. The correct allocator index was set some time ago for the `DEVICE` attribute, but the `MANAGED` and `UNIFIED` attributes need the same mechanism.

Since `Component::Genre` has plenty of room, I opted to add specific genres for allocatable and pointer components with either the managed or unified attribute.
@<!-- -->klausler Let me know if you would prefer another solution. I was thinking about a separate field but I wanted to avoid wasting some bytes. 

---
Full diff: https://github.com/llvm/llvm-project/pull/177409.diff


8 Files Affected:

- (modified) flang-rt/include/flang-rt/runtime/type-info.h (+5-1) 
- (modified) flang-rt/lib/runtime/assign.cpp (+5-1) 
- (modified) flang-rt/lib/runtime/copy.cpp (+4) 
- (modified) flang-rt/lib/runtime/derived.cpp (+19-5) 
- (modified) flang-rt/lib/runtime/type-info.cpp (+26-7) 
- (modified) flang/lib/Semantics/runtime-type-info.cpp (+16) 
- (modified) flang/module/__fortran_type_info.f90 (+4-1) 
- (modified) flang/test/Lower/CUDA/cuda-allocatable-device.cuf (+23-1) 


``````````diff
diff --git a/flang-rt/include/flang-rt/runtime/type-info.h b/flang-rt/include/flang-rt/runtime/type-info.h
index a6528312750a0..89f70c1a23d51 100644
--- a/flang-rt/include/flang-rt/runtime/type-info.h
+++ b/flang-rt/include/flang-rt/runtime/type-info.h
@@ -57,7 +57,11 @@ class Component {
     Allocatable = 3,
     Automatic = 4,
     PointerDevice = 5,
-    AllocatableDevice = 6
+    AllocatableDevice = 6,
+    PointerManaged = 7,
+    AllocatableManaged = 8,
+    PointerUnified = 9,
+    AllocatableUnified = 10
   };
 
   RT_API_ATTRS const Descriptor &name() const { return name_.descriptor(); }
diff --git a/flang-rt/lib/runtime/assign.cpp b/flang-rt/lib/runtime/assign.cpp
index dd5d4b945881e..d3b9d565c620d 100644
--- a/flang-rt/lib/runtime/assign.cpp
+++ b/flang-rt/lib/runtime/assign.cpp
@@ -649,7 +649,9 @@ RT_API_ATTRS int DerivedAssignTicket<IS_COMPONENTWISE>::Continue(
       }
       break;
     case typeInfo::Component::Genre::Pointer:
-    case typeInfo::Component::Genre::PointerDevice: {
+    case typeInfo::Component::Genre::PointerDevice:
+    case typeInfo::Component::Genre::PointerManaged:
+    case typeInfo::Component::Genre::PointerUnified: {
       std::size_t componentByteSize{
           this->component_->SizeInBytes(this->instance_)};
       if (IS_COMPONENTWISE && toIsContiguous_ && fromIsContiguous_) {
@@ -682,6 +684,8 @@ RT_API_ATTRS int DerivedAssignTicket<IS_COMPONENTWISE>::Continue(
     } break;
     case typeInfo::Component::Genre::Allocatable:
     case typeInfo::Component::Genre::AllocatableDevice:
+    case typeInfo::Component::Genre::AllocatableManaged:
+    case typeInfo::Component::Genre::AllocatableUnified:
     case typeInfo::Component::Genre::Automatic: {
       auto *toDesc{reinterpret_cast<Descriptor *>(
           this->instance_.template Element<char>(this->subscripts_) +
diff --git a/flang-rt/lib/runtime/copy.cpp b/flang-rt/lib/runtime/copy.cpp
index 8b7db61b014e1..d1d35d2bbd686 100644
--- a/flang-rt/lib/runtime/copy.cpp
+++ b/flang-rt/lib/runtime/copy.cpp
@@ -170,6 +170,10 @@ RT_API_ATTRS void CopyElement(const Descriptor &to, const SubscriptValue toAt[],
           if (component->genre() == typeInfo::Component::Genre::Allocatable ||
               component->genre() ==
                   typeInfo::Component::Genre::AllocatableDevice ||
+              component->genre() ==
+                  typeInfo::Component::Genre::AllocatableManaged ||
+              component->genre() ==
+                  typeInfo::Component::Genre::AllocatableUnified ||
               component->genre() == typeInfo::Component::Genre::Automatic) {
             Descriptor &toDesc{
                 *reinterpret_cast<Descriptor *>(toPtr + component->offset())};
diff --git a/flang-rt/lib/runtime/derived.cpp b/flang-rt/lib/runtime/derived.cpp
index 7fc426b9efc9a..8aa6be510a13a 100644
--- a/flang-rt/lib/runtime/derived.cpp
+++ b/flang-rt/lib/runtime/derived.cpp
@@ -65,7 +65,9 @@ RT_API_ATTRS int InitializeTicket::Continue(WorkQueue &workQueue) {
   for (; !Componentwise::IsComplete(); SkipToNextComponent()) {
     char *rawComponent{rawInstance + component_->offset()};
     if (component_->genre() == typeInfo::Component::Genre::Allocatable ||
-        component_->genre() == typeInfo::Component::Genre::AllocatableDevice) {
+        component_->genre() == typeInfo::Component::Genre::AllocatableDevice ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableManaged ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableUnified) {
       Descriptor &allocDesc{*reinterpret_cast<Descriptor *>(rawComponent)};
       component_->EstablishDescriptor(
           allocDesc, instance_, workQueue.terminator());
@@ -79,7 +81,9 @@ RT_API_ATTRS int InitializeTicket::Continue(WorkQueue &workQueue) {
         Fortran::runtime::memcpy(rawComponent, init, bytes);
       }
     } else if (component_->genre() == typeInfo::Component::Genre::Pointer ||
-        component_->genre() == typeInfo::Component::Genre::PointerDevice) {
+        component_->genre() == typeInfo::Component::Genre::PointerDevice ||
+        component_->genre() == typeInfo::Component::Genre::PointerManaged ||
+        component_->genre() == typeInfo::Component::Genre::PointerUnified) {
       // Data pointers without explicit initialization are established
       // so that they are valid right-hand side targets of pointer
       // assignment statements.
@@ -164,7 +168,9 @@ RT_API_ATTRS int InitializeClone(const Descriptor &clone,
 RT_API_ATTRS int InitializeCloneTicket::Continue(WorkQueue &workQueue) {
   while (!IsComplete()) {
     if (component_->genre() == typeInfo::Component::Genre::Allocatable ||
-        component_->genre() == typeInfo::Component::Genre::AllocatableDevice) {
+        component_->genre() == typeInfo::Component::Genre::AllocatableDevice ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableManaged ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableUnified) {
       Descriptor &origDesc{*instance_.ElementComponent<Descriptor>(
           subscripts_, component_->offset())};
       if (origDesc.IsAllocated()) {
@@ -343,7 +349,11 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
   while (!IsComplete()) {
     if ((component_->genre() == typeInfo::Component::Genre::Allocatable ||
             component_->genre() ==
-                typeInfo::Component::Genre::AllocatableDevice) &&
+                typeInfo::Component::Genre::AllocatableDevice ||
+            component_->genre() ==
+                typeInfo::Component::Genre::AllocatableManaged ||
+            component_->genre() ==
+                typeInfo::Component::Genre::AllocatableUnified) &&
         component_->category() == TypeCategory::Derived) {
       // Component may be polymorphic or unlimited polymorphic. Need to use the
       // dynamic type to check whether finalization is needed.
@@ -366,6 +376,8 @@ RT_API_ATTRS int FinalizeTicket::Continue(WorkQueue &workQueue) {
       }
     } else if (component_->genre() == typeInfo::Component::Genre::Allocatable ||
         component_->genre() == typeInfo::Component::Genre::AllocatableDevice ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableManaged ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableUnified ||
         component_->genre() == typeInfo::Component::Genre::Automatic) {
       if (const typeInfo::DerivedType *compType{component_->derivedType()};
           compType && !compType->noFinalizationNeeded()) {
@@ -449,7 +461,9 @@ RT_API_ATTRS int DestroyTicket::Continue(WorkQueue &workQueue) {
   while (!IsComplete()) {
     const auto *componentDerived{component_->derivedType()};
     if (component_->genre() == typeInfo::Component::Genre::Allocatable ||
-        component_->genre() == typeInfo::Component::Genre::AllocatableDevice) {
+        component_->genre() == typeInfo::Component::Genre::AllocatableDevice ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableManaged ||
+        component_->genre() == typeInfo::Component::Genre::AllocatableUnified) {
       if (fixedStride_ &&
           (!componentDerived || componentDerived->noDestructionNeeded())) {
         // common fast path, just deallocate in every element
diff --git a/flang-rt/lib/runtime/type-info.cpp b/flang-rt/lib/runtime/type-info.cpp
index 1157dda09c412..99bf48b5def36 100644
--- a/flang-rt/lib/runtime/type-info.cpp
+++ b/flang-rt/lib/runtime/type-info.cpp
@@ -95,16 +95,25 @@ RT_API_ATTRS std::size_t Component::SizeInBytes(
 RT_API_ATTRS void Component::EstablishDescriptor(Descriptor &descriptor,
     const Descriptor &container, Terminator &terminator) const {
   ISO::CFI_attribute_t attribute{static_cast<ISO::CFI_attribute_t>(
-      genre_ == Genre::Allocatable || genre_ == Genre::AllocatableDevice
+      genre_ == Genre::Allocatable || genre_ == Genre::AllocatableDevice ||
+              genre_ == Genre::AllocatableManaged ||
+              genre_ == Genre::AllocatableUnified
           ? CFI_attribute_allocatable
-          : genre_ == Genre::Pointer || genre_ == Genre::PointerDevice
+          : genre_ == Genre::Pointer || genre_ == Genre::PointerDevice ||
+              genre_ == Genre::PointerManaged || genre_ == Genre::PointerUnified
           ? CFI_attribute_pointer
           : CFI_attribute_other)};
   TypeCategory cat{category()};
-  unsigned allocatorIdx{
-      genre_ == Genre::AllocatableDevice || genre_ == Genre::PointerDevice
-          ? kDeviceAllocatorPos
-          : kDefaultAllocator};
+  unsigned allocatorIdx{kDefaultAllocator};
+  if (genre_ == Genre::AllocatableDevice || genre_ == Genre::PointerDevice) {
+    allocatorIdx = kDeviceAllocatorPos;
+  } else if (genre_ == Genre::AllocatableManaged ||
+      genre_ == Genre::PointerManaged) {
+    allocatorIdx = kManagedAllocatorPos;
+  } else if (genre_ == Genre::AllocatableUnified ||
+      genre_ == Genre::PointerUnified) {
+    allocatorIdx = kUnifiedAllocatorPos;
+  }
   if (cat == TypeCategory::Character) {
     std::size_t lengthInChars{0};
     if (auto length{characterLen_.GetValue(&container)}) {
@@ -128,7 +137,9 @@ RT_API_ATTRS void Component::EstablishDescriptor(Descriptor &descriptor,
         cat, kind_, nullptr, rank_, nullptr, attribute, false, allocatorIdx);
   }
   if (rank_ && genre_ != Genre::Allocatable && genre_ != Genre::Pointer &&
-      genre_ != Genre::AllocatableDevice && genre_ != Genre::PointerDevice) {
+      genre_ != Genre::AllocatableDevice && genre_ != Genre::PointerDevice &&
+      genre_ != Genre::AllocatableManaged && genre_ != Genre::PointerManaged &&
+      genre_ != Genre::AllocatableUnified && genre_ != Genre::PointerUnified) {
     const typeInfo::Value *boundValues{bounds()};
     RUNTIME_CHECK(terminator, boundValues != nullptr);
     auto byteStride{static_cast<SubscriptValue>(descriptor.ElementBytes())};
@@ -281,10 +292,18 @@ FILE *Component::Dump(FILE *f) const {
     std::fputs("    Pointer          ", f);
   } else if (genre_ == Genre::PointerDevice) {
     std::fputs("    PointerDevice    ", f);
+  } else if (genre_ == Genre::PointerManaged) {
+    std::fputs("    PointerManaged   ", f);
+  } else if (genre_ == Genre::PointerUnified) {
+    std::fputs("    PointerUnified   ", f);
   } else if (genre_ == Genre::Allocatable) {
     std::fputs("    Allocatable.     ", f);
   } else if (genre_ == Genre::AllocatableDevice) {
     std::fputs("    AllocatableDevice", f);
+  } else if (genre_ == Genre::AllocatableManaged) {
+    std::fputs("    AllocatableManaged", f);
+  } else if (genre_ == Genre::AllocatableUnified) {
+    std::fputs("    AllocatableUnified", f);
   } else if (genre_ == Genre::Automatic) {
     std::fputs("    Automatic        ", f);
   } else {
diff --git a/flang/lib/Semantics/runtime-type-info.cpp b/flang/lib/Semantics/runtime-type-info.cpp
index 8f92fda65685a..b6c712ecdae20 100644
--- a/flang/lib/Semantics/runtime-type-info.cpp
+++ b/flang/lib/Semantics/runtime-type-info.cpp
@@ -773,6 +773,10 @@ evaluate::StructureConstructor RuntimeTableBuilder::DescribeComponent(
       symbol, foldingContext)};
   bool isDevice{object.cudaDataAttr() &&
       *object.cudaDataAttr() == common::CUDADataAttr::Device};
+  bool isManaged{object.cudaDataAttr() &&
+      *object.cudaDataAttr() == common::CUDADataAttr::Managed};
+  bool isUnified{object.cudaDataAttr() &&
+      *object.cudaDataAttr() == common::CUDADataAttr::Unified};
   CHECK(typeAndShape.has_value());
   auto dyType{typeAndShape->type()};
   int rank{typeAndShape->Rank()};
@@ -888,6 +892,12 @@ evaluate::StructureConstructor RuntimeTableBuilder::DescribeComponent(
     if (isDevice) {
       AddValue(values, componentSchema_, "genre"s,
           GetEnumValue("allocatabledevice"));
+    } else if (isManaged) {
+      AddValue(values, componentSchema_, "genre"s,
+          GetEnumValue("allocatablemanaged"));
+    } else if (isUnified) {
+      AddValue(values, componentSchema_, "genre"s,
+          GetEnumValue("allocatableunified"));
     } else {
       AddValue(values, componentSchema_, "genre"s, GetEnumValue("allocatable"));
     }
@@ -895,6 +905,12 @@ evaluate::StructureConstructor RuntimeTableBuilder::DescribeComponent(
     if (isDevice) {
       AddValue(
           values, componentSchema_, "genre"s, GetEnumValue("pointerdevice"));
+    } else if (isManaged) {
+      AddValue(
+          values, componentSchema_, "genre"s, GetEnumValue("pointermanaged"));
+    } else if (isUnified) {
+      AddValue(
+          values, componentSchema_, "genre"s, GetEnumValue("pointerunified"));
     } else {
       AddValue(values, componentSchema_, "genre"s, GetEnumValue("pointer"));
     }
diff --git a/flang/module/__fortran_type_info.f90 b/flang/module/__fortran_type_info.f90
index ae8eeef4a55e8..a8d9959feb872 100644
--- a/flang/module/__fortran_type_info.f90
+++ b/flang/module/__fortran_type_info.f90
@@ -75,7 +75,10 @@
   end type
 
   enum, bind(c) ! Component::Genre
-    enumerator :: Data = 1, Pointer = 2, Allocatable = 3, Automatic = 4, PointerDevice = 5, AllocatableDevice = 6
+    enumerator :: Data = 1, Pointer = 2, Allocatable = 3, Automatic = 4
+    enumerator :: PointerDevice = 5, AllocatableDevice = 6
+    enumerator :: PointerManaged = 7, AllocatableManaged = 8
+    enumerator :: PointerUnified = 9, AllocatableUnified = 10
   end enum
 
   enum, bind(c) ! common::TypeCategory
diff --git a/flang/test/Lower/CUDA/cuda-allocatable-device.cuf b/flang/test/Lower/CUDA/cuda-allocatable-device.cuf
index 57c588e5beafa..428aec073fbb6 100644
--- a/flang/test/Lower/CUDA/cuda-allocatable-device.cuf
+++ b/flang/test/Lower/CUDA/cuda-allocatable-device.cuf
@@ -6,17 +6,39 @@ module m
     real(kind=8), pointer, dimension(:), device :: pd
   end type
 
+  type managed_array
+    real(kind=8), allocatable, dimension(:), managed :: ad
+    real(kind=8), pointer, dimension(:), managed :: pd
+  end type
+
+  type unified_array
+    real(kind=8), allocatable, dimension(:), unified :: ad
+    real(kind=8), pointer, dimension(:), unified :: pd
+  end type
+
   type(device_array), allocatable :: da(:)
+  type(managed_array), allocatable :: ma(:)
+  type(unified_array), allocatable :: ua(:)
 end module
 
 ! CHECK-LABEL: fir.global linkonce_odr @_QMmE.c.device_array
 ! CHECK: fir.insert_value %{{.*}}, %c6{{.*}}, ["genre"
 ! CHECK: fir.insert_value %{{.*}}, %c5{{.*}}, ["genre"
 
+! CHECK-LABEL: fir.global linkonce_odr @_QMmE.c.managed_array
+! CHECK: fir.insert_value %{{.*}}, %c8{{.*}}, ["genre"
+! CHECK: fir.insert_value %{{.*}}, %c7{{.*}}, ["genre"
+
+! CHECK-LABEL: fir.global linkonce_odr @_QMmE.c.unified_array
+! CHECK: fir.insert_value %{{.*}}, %c10{{.*}}, ["genre"
+! CHECK: fir.insert_value %{{.*}}, %c9{{.*}}, ["genre"
+
 program main
   use m
   type(device_array) :: local
+  type(managed_array) :: local_ma
+  type(unified_array) :: local_ua
 end
 
 ! CHECK-LABEL: func.func @_QQmain()
-! CHECK: fir.call @_FortranAInitialize
+! CHECK-COUNT-3: fir.call @_FortranAInitialize

``````````

</details>


https://github.com/llvm/llvm-project/pull/177409


More information about the flang-commits mailing list