[flang-commits] [flang] [flang][cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (PR #181038)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Thu Feb 12 18:45:55 PST 2026
https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/181038
>From 6ea1d125fef749b84f2ec982b3d20073ba2ab9f9 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 9 Feb 2026 14:51:47 -0800
Subject: [PATCH 1/5] cuda matching distance
---
flang/lib/Semantics/expression.cpp | 76 ++++++++++++++++++++++--------
1 file changed, 57 insertions(+), 19 deletions(-)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b0235469c9de6..b4dac18b17a10 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2799,6 +2799,40 @@ static bool CheckCompatibleArguments(
static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
+struct CudaMatchingDistance {
+ std::vector<int> perArg;
+ bool isInfinite{false};
+};
+
+static int CompareCudaMatchingDistance(
+ const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
+ if (x.isInfinite != y.isInfinite) {
+ return x.isInfinite ? 1 : -1;
+ }
+ if (x.isInfinite) {
+ return 0;
+ }
+ std::size_t n{x.perArg.size()};
+ if (n > y.perArg.size()) {
+ n = y.perArg.size();
+ }
+ for (std::size_t i{0}; i < n; ++i) {
+ if (x.perArg[i] < y.perArg[i]) {
+ return -1;
+ }
+ if (x.perArg[i] > y.perArg[i]) {
+ return 1;
+ }
+ }
+ if (x.perArg.size() < y.perArg.size()) {
+ return -1;
+ }
+ if (x.perArg.size() > y.perArg.size()) {
+ return 1;
+ }
+ return 0;
+}
+
// Compute the matching distance as described in section 3.2.3 of the CUDA
// Fortran references.
static int GetMatchingDistance(const common::LanguageFeatureControl &features,
@@ -2882,20 +2916,23 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
return cudaInfMatchingValue;
}
-static int ComputeCudaMatchingDistance(
+static CudaMatchingDistance ComputeCudaMatchingDistance(
const common::LanguageFeatureControl &features,
const characteristics::Procedure &procedure,
const ActualArguments &actuals) {
const auto &dummies{procedure.dummyArguments};
CHECK(dummies.size() == actuals.size());
- int distance{0};
+ CudaMatchingDistance distance;
+ distance.perArg.reserve(dummies.size());
for (std::size_t i{0}; i < dummies.size(); ++i) {
const characteristics::DummyArgument &dummy{dummies[i]};
const std::optional<ActualArgument> &actual{actuals[i]};
int d{GetMatchingDistance(features, dummy, actual)};
- if (d == cudaInfMatchingValue)
- return d;
- distance += d;
+ if (d == cudaInfMatchingValue) {
+ distance.isInfinite = true;
+ return distance;
+ }
+ distance.perArg.push_back(d);
}
return distance;
}
@@ -2967,7 +3004,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const auto *genericDetails{ultimate.detailsIf<semantics::GenericDetails>()};
if (genericDetails && !explicitIntrinsic) {
- int crtMatchingDistance{cudaInfMatchingValue};
+ std::optional<CudaMatchingDistance> crtMatchingDistance;
for (const Symbol &specific0 : genericDetails->specificProcs()) {
const Symbol &specific1{BypassGeneric(specific0)};
if (isSubroutine != !IsFunction(specific1)) {
@@ -2992,23 +3029,25 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
context_, false /* no integer conversions */) &&
CheckCompatibleArguments(
*procedure, localActuals, foldingContext_)) {
+ CudaMatchingDistance d{ComputeCudaMatchingDistance(
+ context_.languageFeatures(), *procedure, localActuals)};
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- int d{ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals)};
- if (d != crtMatchingDistance) {
- if (d > crtMatchingDistance) {
+ if (crtMatchingDistance) {
+ int cmp{CompareCudaMatchingDistance(d, *crtMatchingDistance)};
+ if (cmp > 0) {
continue;
}
+ if (cmp == 0) {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */, std::move(tried)};
+ }
// Matching distance is smaller than the previously matched
// specific. Let it go through so the current procedure is picked.
- } else {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */, std::move(tried)};
}
}
if (!procedure->IsElemental()) {
@@ -3017,8 +3056,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
} else {
elemental = specific;
}
- crtMatchingDistance = ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals);
+ crtMatchingDistance = std::move(d);
}
}
}
>From 2e0c6d04e9756564dc1510fb48f1c458aba399a7 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 13:18:42 -0800
Subject: [PATCH 2/5] simplify comparator
---
flang/lib/Semantics/expression.cpp | 19 +++++--------------
flang/test/Semantics/cuf17.cuf | 30 ++++++++++++++++++++++++++++++
2 files changed, 35 insertions(+), 14 deletions(-)
create mode 100644 flang/test/Semantics/cuf17.cuf
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b4dac18b17a10..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2812,22 +2812,13 @@ static int CompareCudaMatchingDistance(
if (x.isInfinite) {
return 0;
}
- std::size_t n{x.perArg.size()};
- if (n > y.perArg.size()) {
- n = y.perArg.size();
- }
- for (std::size_t i{0}; i < n; ++i) {
- if (x.perArg[i] < y.perArg[i]) {
- return -1;
- }
- if (x.perArg[i] > y.perArg[i]) {
- return 1;
- }
- }
- if (x.perArg.size() < y.perArg.size()) {
+ CHECK(x.perArg.size() == y.perArg.size());
+ if (std::lexicographical_compare(
+ x.perArg.begin(), x.perArg.end(), y.perArg.begin(), y.perArg.end())) {
return -1;
}
- if (x.perArg.size() > y.perArg.size()) {
+ if (std::lexicographical_compare(
+ y.perArg.begin(), y.perArg.end(), x.perArg.begin(), x.perArg.end())) {
return 1;
}
return 0;
diff --git a/flang/test/Semantics/cuf17.cuf b/flang/test/Semantics/cuf17.cuf
new file mode 100644
index 0000000000000..4309510bba6cd
--- /dev/null
+++ b/flang/test/Semantics/cuf17.cuf
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+module matching_two_args
+ interface two_args
+ module procedure sub_managed_device
+ module procedure sub_unified_unified
+ end interface
+contains
+ subroutine sub_managed_device(a, b)
+ integer, managed :: a(:)
+ integer, device :: b(:)
+ end
+
+ subroutine sub_unified_unified(a, b)
+ integer, unified :: a(:)
+ integer, unified :: b(:)
+ end
+end module
+
+program test
+ use matching_two_args
+ integer, managed, allocatable :: a(:)
+ integer, managed, allocatable :: b(:)
+
+ allocate(a(10), b(10))
+ call two_args(a, b)
+end
+
+! CHECK: fir.call @_QMmatching_two_argsPsub_managed_device
+! CHECK-NOT: @_QMmatching_two_argsPsub_unified_unified
>From 726987066d5f2ffdb04027450d1a09d6cf6f6892 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 15:08:24 -0800
Subject: [PATCH 3/5] use SmallVector
---
flang/lib/Semantics/expression.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index a0bcd6bd8a39e..be648ba4c5389 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -25,6 +25,7 @@
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Support/Fortran.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <functional>
@@ -2800,7 +2801,7 @@ static bool CheckCompatibleArguments(
static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
struct CudaMatchingDistance {
- std::vector<int> perArg;
+ llvm::SmallVector<int> perArg;
bool isInfinite{false};
};
>From e99f333050881194fd1456d83b5800460ce0026c Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 19:55:08 -0800
Subject: [PATCH 4/5] revert SmallVector back to std::vector
---
flang/lib/Semantics/expression.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index be648ba4c5389..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -25,7 +25,6 @@
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Support/Fortran.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <functional>
@@ -2801,7 +2800,7 @@ static bool CheckCompatibleArguments(
static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
struct CudaMatchingDistance {
- llvm::SmallVector<int> perArg;
+ std::vector<int> perArg;
bool isInfinite{false};
};
>From 07945be01fecf8935cf818b75c261ffe4bd6af6a Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Thu, 12 Feb 2026 18:45:02 -0800
Subject: [PATCH 5/5] add comments
---
flang/lib/Semantics/expression.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index a0bcd6bd8a39e..a4b32e4674e1d 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2804,6 +2804,10 @@ struct CudaMatchingDistance {
bool isInfinite{false};
};
+// Compare two CUDA matching distances lexicographically, argument by argument.
+// This is needed to differentiate specific procedures whose per-argument
+// distances sum to the same total, allowing the compiler to select the best
+// match based on the earliest argument at which the distances differ.
static int CompareCudaMatchingDistance(
const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
if (x.isInfinite != y.isInfinite) {
More information about the flang-commits
mailing list