[flang-commits] [flang] [flang][cuda] Handle Constant and Shared attributes in CUDA generic matching distance (PR #201451)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 3 13:55:31 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Zhen Wang (wangzpgi)

<details>
<summary>Changes</summary>

The CUDA generic-resolution matching-distance function did not handle actual arguments with the `Constant` or `Shared` data attribute. Although both attributes denote device memory, no branch matched them, so the computed distance fell through to infinity (`cudaInfMatchingValue`). Under `-gpu=mem:unified`, this produced spurious ambiguity errors whenever several specifics (e.g. host-to-device and device-to-device overloads) were all viable candidates with the same, tied, infinite distance.

Treat `Constant` and `Shared` actuals exactly like `Device` actuals (i.e. use the Device column of the matching-distance table), since all three reside in device memory.

---
Full diff: https://github.com/llvm/llvm-project/pull/201451.diff


2 Files Affected:

- (modified) flang/lib/Semantics/expression.cpp (+19-10) 
- (added) flang/test/Semantics/cuf-constant-generic-unified.cuf (+41) 


``````````diff
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index ceca9436e2672..50869a3c870ef 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2850,12 +2850,14 @@ static int CompareCudaMatchingDistance(
 //
 //                       Actual argument attribute
 //                 None                              ACC      gpu=    gpu=
-//   Dummy attr   (Host)  Device  Managed Unified  use_dev  unified  managed
-//   ----------+--------+-------+--------+-------+--------+--------+--------+
-//   None(host)|    0   |  INF  |   3    |   3   |   3    |   3    |   3    |
-//   Device    |   INF  |   0   |   2    |   2   |   0    |   2    |   2    |
-//   Managed   |   INF  |  INF  |   0    |   1   |  INF   |   1    |   0    |
-//   Unified   |   INF  |  INF  |   1    |   0   |  INF   |   0    |   1    |
+//   Dummy attr   (Host)  Device  Managed  Unified  use_dev  unified  managed
+//   ----------+--------+-------+--------+--------+--------+--------+--------+
+//   None(host)|    0   |  INF  |   3    |   3    |   3    |   3    |   3    |
+//   Device    |   INF  |   0   |   2    |   2    |   0    |   2    |   2    |
+//   Managed   |   INF  |  INF  |   0    |   1    |  INF   |   1    |   0    |
+//   Unified   |   INF  |  INF  |   1    |   0    |  INF   |   0    |   1    |
+//
+// Constant and Shared actuals use the Device column (all are device memory).
 //
 // In addition: a dummy declared TYPE(*) (assumed-size/rank opaque buffer)
 // is "CUDA address space agnostic" and accepts any attributed actual at a
@@ -2914,13 +2916,20 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     return 3;
   }
 
+  auto actualIsDeviceMemory{[&]() {
+    return actualDataAttr &&
+        (*actualDataAttr == common::CUDADataAttr::Device ||
+            *actualDataAttr == common::CUDADataAttr::Constant ||
+            *actualDataAttr == common::CUDADataAttr::Shared);
+  }};
+
   if (!dummyDataAttr) {
     if (!actualDataAttr) {
       if (isCudaUnified || isCudaManaged) {
         return 3;
       }
       return 0;
-    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+    } else if (actualIsDeviceMemory()) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
         *actualDataAttr == common::CUDADataAttr::Unified) {
@@ -2932,7 +2941,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
         return 2;
       }
       return cudaInfMatchingValue;
-    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+    } else if (actualIsDeviceMemory()) {
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
         *actualDataAttr == common::CUDADataAttr::Unified) {
@@ -2942,7 +2951,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     if (!actualDataAttr) {
       return isCudaUnified ? 1 : isCudaManaged ? 0 : cudaInfMatchingValue;
     }
-    if (*actualDataAttr == common::CUDADataAttr::Device) {
+    if (actualIsDeviceMemory()) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 0;
@@ -2953,7 +2962,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     if (!actualDataAttr) {
       return isCudaUnified ? 0 : isCudaManaged ? 1 : cudaInfMatchingValue;
     }
-    if (*actualDataAttr == common::CUDADataAttr::Device) {
+    if (actualIsDeviceMemory()) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 1;
diff --git a/flang/test/Semantics/cuf-constant-generic-unified.cuf b/flang/test/Semantics/cuf-constant-generic-unified.cuf
new file mode 100644
index 0000000000000..e61dbb0c5b70b
--- /dev/null
+++ b/flang/test/Semantics/cuf-constant-generic-unified.cuf
@@ -0,0 +1,41 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=unified %s -o - | FileCheck %s
+
+! Under -gpu=mem:unified, a CONSTANT actual argument (device memory)
+! must resolve to the device-attributed specific in a generic that also
+! has a host-typed specific. Previously, the Constant CUDA data
+! attribute was unhandled in the generic matching distance computation,
+! causing ambiguity errors.
+
+module m
+  interface gen
+    module procedure sub_host
+    module procedure sub_device
+  end interface
+contains
+  subroutine sub_host(x)
+    integer :: x
+  end subroutine
+  subroutine sub_device(x)
+    integer, device :: x
+  end subroutine
+end module
+
+module mconst
+  integer, constant :: cvar
+end module
+
+subroutine caller(host_val)
+  use m
+  use mconst
+  implicit none
+  integer, intent(in) :: host_val
+
+  ! constant actual -> device specific (constant is device memory)
+  call gen(cvar)
+  ! host actual under unified -> device specific (per cuf14.cuf rules)
+  call gen(host_val)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPcaller
+! CHECK: fir.call @_QMmPsub_device
+! CHECK: fir.call @_QMmPsub_device

``````````

</details>


https://github.com/llvm/llvm-project/pull/201451


More information about the flang-commits mailing list