[flang-commits] [flang] 792ac8e - [flang][cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (#181038)
via flang-commits
flang-commits at lists.llvm.org
Thu Feb 12 20:04:11 PST 2026
Author: Zhen Wang
Date: 2026-02-13T04:04:07Z
New Revision: 792ac8e94cdb1d8685d52cced32fcd513c0aec57
URL: https://github.com/llvm/llvm-project/commit/792ac8e94cdb1d8685d52cced32fcd513c0aec57
DIFF: https://github.com/llvm/llvm-project/commit/792ac8e94cdb1d8685d52cced32fcd513c0aec57.diff
LOG: [flang][cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (#181038)
Switch CUDA generic resolution to compare per‑argument CUDA matching
distances lexicographically instead of summing them. This prevents
ambiguous matches when different specifics yield the same total
distance.
Added:
flang/test/Semantics/cuf17.cuf
Modified:
flang/lib/Semantics/expression.cpp
Removed:
################################################################################
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b0235469c9de6..81603b8c20ce3 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2799,6 +2799,36 @@ static bool CheckCompatibleArguments(
static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
+struct CudaMatchingDistance {
+ std::vector<int> perArg;
+ bool isInfinite{false};
+};
+
+// Compare CUDA matching distances using lexicographical comparison of
+// per-argument distances. This is needed to differentiate procedures that would
+// have similar total distance when summing the per-argument weights, allowing
+// the compiler to select the best match based on argument-by-argument
+// comparison.
+static int CompareCudaMatchingDistance(
+ const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
+ if (x.isInfinite != y.isInfinite) {
+ return x.isInfinite ? 1 : -1;
+ }
+ if (x.isInfinite) {
+ return 0;
+ }
+ CHECK(x.perArg.size() == y.perArg.size());
+ if (std::lexicographical_compare(
+ x.perArg.begin(), x.perArg.end(), y.perArg.begin(), y.perArg.end())) {
+ return -1;
+ }
+ if (std::lexicographical_compare(
+ y.perArg.begin(), y.perArg.end(), x.perArg.begin(), x.perArg.end())) {
+ return 1;
+ }
+ return 0;
+}
+
// Compute the matching distance as described in section 3.2.3 of the CUDA
// Fortran references.
static int GetMatchingDistance(const common::LanguageFeatureControl &features,
@@ -2882,20 +2912,23 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
return cudaInfMatchingValue;
}
-static int ComputeCudaMatchingDistance(
+static CudaMatchingDistance ComputeCudaMatchingDistance(
const common::LanguageFeatureControl &features,
const characteristics::Procedure &procedure,
const ActualArguments &actuals) {
const auto &dummies{procedure.dummyArguments};
CHECK(dummies.size() == actuals.size());
- int distance{0};
+ CudaMatchingDistance distance;
+ distance.perArg.reserve(dummies.size());
for (std::size_t i{0}; i < dummies.size(); ++i) {
const characteristics::DummyArgument &dummy{dummies[i]};
const std::optional<ActualArgument> &actual{actuals[i]};
int d{GetMatchingDistance(features, dummy, actual)};
- if (d == cudaInfMatchingValue)
- return d;
- distance += d;
+ if (d == cudaInfMatchingValue) {
+ distance.isInfinite = true;
+ return distance;
+ }
+ distance.perArg.push_back(d);
}
return distance;
}
@@ -2967,7 +3000,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const auto *genericDetails{ultimate.detailsIf<semantics::GenericDetails>()};
if (genericDetails && !explicitIntrinsic) {
- int crtMatchingDistance{cudaInfMatchingValue};
+ std::optional<CudaMatchingDistance> crtMatchingDistance;
for (const Symbol &specific0 : genericDetails->specificProcs()) {
const Symbol &specific1{BypassGeneric(specific0)};
if (isSubroutine != !IsFunction(specific1)) {
@@ -2992,23 +3025,25 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
context_, false /* no integer conversions */) &&
CheckCompatibleArguments(
*procedure, localActuals, foldingContext_)) {
+ CudaMatchingDistance d{ComputeCudaMatchingDistance(
+ context_.languageFeatures(), *procedure, localActuals)};
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- int d{ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals)};
- if (d != crtMatchingDistance) {
- if (d > crtMatchingDistance) {
+ if (crtMatchingDistance) {
+ int cmp{CompareCudaMatchingDistance(d, *crtMatchingDistance)};
+ if (cmp > 0) {
continue;
}
+ if (cmp == 0) {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */, std::move(tried)};
+ }
// Matching distance is smaller than the previously matched
// specific. Let it go through so the current procedure is picked.
- } else {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */, std::move(tried)};
}
}
if (!procedure->IsElemental()) {
@@ -3017,8 +3052,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
} else {
elemental = specific;
}
- crtMatchingDistance = ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals);
+ crtMatchingDistance = std::move(d);
}
}
}
diff --git a/flang/test/Semantics/cuf17.cuf b/flang/test/Semantics/cuf17.cuf
new file mode 100644
index 0000000000000..4309510bba6cd
--- /dev/null
+++ b/flang/test/Semantics/cuf17.cuf
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+module matching_two_args
+ interface two_args
+ module procedure sub_managed_device
+ module procedure sub_unified_unified
+ end interface
+contains
+ subroutine sub_managed_device(a, b)
+ integer, managed :: a(:)
+ integer, device :: b(:)
+ end
+
+ subroutine sub_unified_unified(a, b)
+ integer, unified :: a(:)
+ integer, unified :: b(:)
+ end
+end module
+
+program test
+ use matching_two_args
+ integer, managed, allocatable :: a(:)
+ integer, managed, allocatable :: b(:)
+
+ allocate(a(10), b(10))
+ call two_args(a, b)
+end
+
+! CHECK: fir.call @_QMmatching_two_argsPsub_managed_device
+! CHECK-NOT: @_QMmatching_two_argsPsub_unified_unified
More information about the flang-commits
mailing list