[flang-commits] [flang] [flang] [cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (PR #181038)
via flang-commits
flang-commits at lists.llvm.org
Wed Feb 11 14:47:10 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Zhen Wang (wangzpgi)
<details>
<summary>Changes</summary>
Switch CUDA generic resolution to compare per-argument CUDA matching distances lexicographically instead of summing them. This prevents spurious ambiguity errors when different specific procedures yield the same total distance.
---
Full diff: https://github.com/llvm/llvm-project/pull/181038.diff
2 Files Affected:
- (modified) flang/lib/Semantics/expression.cpp (+48-19)
- (added) flang/test/Semantics/cuf17.cuf (+30)
``````````diff
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b0235469c9de6..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2799,6 +2799,31 @@ static bool CheckCompatibleArguments(
static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
+// CUDA matching distance of one specific procedure against the actual
+// arguments of a generic call, kept per argument so ties between specifics
+// can be broken argument by argument rather than by a summed total.
+struct CudaMatchingDistance {
+ // perArg[i] is the matching distance for dummy argument i, in dummy order.
+ std::vector<int> perArg;
+ // Set when some argument's distance is cudaInfMatchingValue; in that case
+ // perArg holds only the distances computed before that argument.
+ bool isInfinite{false};
+};
+
+// Three-way comparison of two CUDA matching distances.
+// Returns -1 if x is the better (smaller) match, 1 if x is worse, and 0 on a
+// tie. Any finite distance beats an infinite one, two infinite distances tie,
+// and finite distances are ordered lexicographically by argument position,
+// so the first argument whose distances differ decides the result.
+static int CompareCudaMatchingDistance(
+ const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
+ if (x.isInfinite != y.isInfinite) {
+ return x.isInfinite ? 1 : -1;
+ }
+ if (x.isInfinite) {
+ return 0;
+ }
+ // NOTE(review): perArg.size() equals each specific's dummy-argument count,
+ // so this assumes every matching specific has the same number of dummies.
+ // If two specifics can both match while differing in dummy count (e.g. an
+ // extra OPTIONAL dummy in a generic accepted because of CUDA attributes),
+ // this CHECK aborts; std::lexicographical_compare itself accepts ranges of
+ // different lengths. TODO confirm whether such generics are possible.
+ CHECK(x.perArg.size() == y.perArg.size());
+ if (std::lexicographical_compare(
+ x.perArg.begin(), x.perArg.end(), y.perArg.begin(), y.perArg.end())) {
+ return -1;
+ }
+ if (std::lexicographical_compare(
+ y.perArg.begin(), y.perArg.end(), x.perArg.begin(), x.perArg.end())) {
+ return 1;
+ }
+ return 0;
+}
+
// Compute the matching distance as described in section 3.2.3 of the CUDA
// Fortran references.
static int GetMatchingDistance(const common::LanguageFeatureControl &features,
@@ -2882,20 +2907,23 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
return cudaInfMatchingValue;
}
+// Computes the CUDA matching distance of each (dummy, actual) argument pair
+// of `procedure`, in dummy-argument order. Stops at the first pair whose
+// distance is cudaInfMatchingValue and marks the whole result as infinite.
+// Requires `actuals` to be already arranged one-to-one with the dummies.
-static int ComputeCudaMatchingDistance(
+static CudaMatchingDistance ComputeCudaMatchingDistance(
const common::LanguageFeatureControl &features,
const characteristics::Procedure &procedure,
const ActualArguments &actuals) {
const auto &dummies{procedure.dummyArguments};
CHECK(dummies.size() == actuals.size());
- int distance{0};
+ CudaMatchingDistance distance;
+ distance.perArg.reserve(dummies.size());
for (std::size_t i{0}; i < dummies.size(); ++i) {
const characteristics::DummyArgument &dummy{dummies[i]};
const std::optional<ActualArgument> &actual{actuals[i]};
int d{GetMatchingDistance(features, dummy, actual)};
- if (d == cudaInfMatchingValue)
- return d;
- distance += d;
+ if (d == cudaInfMatchingValue) {
+ distance.isInfinite = true;
+ return distance;
+ }
+ distance.perArg.push_back(d);
}
return distance;
}
@@ -2967,7 +2995,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const auto *genericDetails{ultimate.detailsIf<semantics::GenericDetails>()};
if (genericDetails && !explicitIntrinsic) {
- int crtMatchingDistance{cudaInfMatchingValue};
+ std::optional<CudaMatchingDistance> crtMatchingDistance;
for (const Symbol &specific0 : genericDetails->specificProcs()) {
const Symbol &specific1{BypassGeneric(specific0)};
if (isSubroutine != !IsFunction(specific1)) {
@@ -2992,23 +3020,25 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
context_, false /* no integer conversions */) &&
CheckCompatibleArguments(
*procedure, localActuals, foldingContext_)) {
+ CudaMatchingDistance d{ComputeCudaMatchingDistance(
+ context_.languageFeatures(), *procedure, localActuals)};
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- int d{ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals)};
- if (d != crtMatchingDistance) {
- if (d > crtMatchingDistance) {
+ if (crtMatchingDistance) {
+ int cmp{CompareCudaMatchingDistance(d, *crtMatchingDistance)};
+ if (cmp > 0) {
continue;
}
+ if (cmp == 0) {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */, std::move(tried)};
+ }
// Matching distance is smaller than the previously matched
// specific. Let it go through so the current procedure is picked.
- } else {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */, std::move(tried)};
}
}
if (!procedure->IsElemental()) {
@@ -3017,8 +3047,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
} else {
elemental = specific;
}
- crtMatchingDistance = ComputeCudaMatchingDistance(
- context_.languageFeatures(), *procedure, localActuals);
+ crtMatchingDistance = std::move(d);
}
}
}
diff --git a/flang/test/Semantics/cuf17.cuf b/flang/test/Semantics/cuf17.cuf
new file mode 100644
index 0000000000000..4309510bba6cd
--- /dev/null
+++ b/flang/test/Semantics/cuf17.cuf
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Generic interface whose two specifics take the same integer rank-1
+! assumed-shape dummies and differ only in the CUDA data attributes
+! (managed/device vs. unified/unified) of those dummies.
+module matching_two_args
+ interface two_args
+ module procedure sub_managed_device
+ module procedure sub_unified_unified
+ end interface
+contains
+ subroutine sub_managed_device(a, b)
+ integer, managed :: a(:)
+ integer, device :: b(:)
+ end
+
+ subroutine sub_unified_unified(a, b)
+ integer, unified :: a(:)
+ integer, unified :: b(:)
+ end
+end module
+
+! Calls the generic with two managed actuals. The CHECK lines below expect
+! resolution to sub_managed_device, not sub_unified_unified and not an
+! ambiguity error (presumably the two specifics tie when per-argument
+! distances are summed -- TODO confirm against GetMatchingDistance).
+program test
+ use matching_two_args
+ integer, managed, allocatable :: a(:)
+ integer, managed, allocatable :: b(:)
+
+ allocate(a(10), b(10))
+ call two_args(a, b)
+end
+
+! CHECK: fir.call @_QMmatching_two_argsPsub_managed_device
+! CHECK-NOT: @_QMmatching_two_argsPsub_unified_unified
``````````
</details>
https://github.com/llvm/llvm-project/pull/181038
More information about the flang-commits mailing list