[flang-commits] [flang] [flang][CUDA] Keep host literals from using unified-memory generic distance (PR #201257)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Tue Jun 2 20:52:01 PDT 2026
https://github.com/wangzpgi created https://github.com/llvm/llvm-project/pull/201257
Fix CUDA generic resolution under `-gpu=mem:unified` so unattributed literals and expression temporaries are not treated as unified-memory actuals.
Previously, a host scalar literal such as `1.0` could score as compatible with a `DEVICE` dummy and incorrectly select the device-scalar overload. This caused the `cudafor_cutensor` `cp27` test to pass a host stack address to a device helper and fail at runtime. The fix applies the unified/managed memory distance columns only to symbol-backed actuals.
From 6db2c1b33e0105a72e2c1593c57898f283d533f6 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 2 Jun 2026 20:48:35 -0700
Subject: [PATCH] Keep host literals from using unified-memory generic distance
---
flang/lib/Semantics/expression.cpp | 13 +++++++++--
flang/test/Semantics/cuf28.cuf | 37 ++++++++++++++++++++++++++++++
2 files changed, 48 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Semantics/cuf28.cuf
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index ceca9436e2672..fb2015c5b1945 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2870,9 +2870,11 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
CHECK(!(isCudaUnified && isCudaManaged) && "expect only one enabled.");
std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
+ bool actualCanUseCudaMemoryMode{false};
if (actual) {
if (auto *expr{actual->UnwrapExpr()}) {
if (evaluate::IsVariable(*expr)) {
+ actualCanUseCudaMemoryMode = true;
// Match check-call.cpp: walk the whole designator so e.g. b%a picks up
// ATTRIBUTES(DEVICE) from the base b when the component a has no CUDA
// attribute (OpenACC use_device(b) + doit(b%a)), not only from the
@@ -2886,6 +2888,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
}
}
} else if (const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)}) {
+ actualCanUseCudaMemoryMode = true;
const Symbol &resolved{
semantics::ResolveAssociations(*actualLastSymbol)};
if (const auto *actualObject{
@@ -2916,7 +2919,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
if (!dummyDataAttr) {
if (!actualDataAttr) {
- if (isCudaUnified || isCudaManaged) {
+ if ((isCudaUnified || isCudaManaged) && actualCanUseCudaMemoryMode) {
return 3;
}
return 0;
@@ -2928,7 +2931,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
}
} else if (*dummyDataAttr == common::CUDADataAttr::Device) {
if (!actualDataAttr) {
- if (isCudaUnified || isCudaManaged) {
+ if ((isCudaUnified || isCudaManaged) && actualCanUseCudaMemoryMode) {
return 2;
}
return cudaInfMatchingValue;
@@ -2940,6 +2943,9 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
}
} else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
if (!actualDataAttr) {
+ if (!actualCanUseCudaMemoryMode) {
+ return cudaInfMatchingValue;
+ }
return isCudaUnified ? 1 : isCudaManaged ? 0 : cudaInfMatchingValue;
}
if (*actualDataAttr == common::CUDADataAttr::Device) {
@@ -2951,6 +2957,9 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
}
} else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
if (!actualDataAttr) {
+ if (!actualCanUseCudaMemoryMode) {
+ return cudaInfMatchingValue;
+ }
return isCudaUnified ? 0 : isCudaManaged ? 1 : cudaInfMatchingValue;
}
if (*actualDataAttr == common::CUDADataAttr::Device) {
diff --git a/flang/test/Semantics/cuf28.cuf b/flang/test/Semantics/cuf28.cuf
new file mode 100644
index 0000000000000..8b1514cd908bd
--- /dev/null
+++ b/flang/test/Semantics/cuf28.cuf
@@ -0,0 +1,37 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=unified %s -o - | FileCheck %s
+
+! Under -gpu=mem:unified, an unattributed host scalar literal is an exact
+! match for an unattributed dummy and must not select a device-dummy overload.
+! A symbol-backed host variable remains compatible with a device dummy.
+module matching_host_scalar_unified
+ interface pick
+ module procedure host_scalar
+ module procedure device_scalar
+ end interface
+ interface device_fallback
+ module procedure device_scalar
+ end interface
+contains
+ subroutine host_scalar(a, x)
+ real(8), device :: a(:)
+ real(4) :: x
+ end
+
+ subroutine device_scalar(a, x)
+ real(8), device :: a(:)
+ real(4), device :: x
+ end
+end module
+
+program test
+ use matching_host_scalar_unified
+ real(8), device :: a(1)
+ real(4) :: x
+
+ call pick(a, 1.0)
+ call device_fallback(a, x)
+end
+
+! CHECK-LABEL: func.func @_QQmain()
+! CHECK: fir.call @_QMmatching_host_scalar_unifiedPhost_scalar
+! CHECK: fir.call @_QMmatching_host_scalar_unifiedPdevice_scalar
More information about the flang-commits
mailing list