[flang-commits] [flang] [flang] [cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (PR #181038)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 11 14:47:10 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Zhen Wang (wangzpgi)

<details>
<summary>Changes</summary>

Switch CUDA generic resolution to compare the per-argument CUDA matching distances lexicographically (in dummy-argument order) instead of summing them. This avoids a spurious ambiguity error when two specifics have the same total distance but one of them matches better on an earlier argument.

---
Full diff: https://github.com/llvm/llvm-project/pull/181038.diff


2 Files Affected:

- (modified) flang/lib/Semantics/expression.cpp (+48-19) 
- (added) flang/test/Semantics/cuf17.cuf (+30) 


``````````diff
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b0235469c9de6..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2799,6 +2799,31 @@ static bool CheckCompatibleArguments(
 
 static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
+struct CudaMatchingDistance { // CUDA matching distance of a specific.
+  std::vector<int> perArg;    // Per-dummy distances, in dummy order.
+  bool isInfinite{false};     // Set if some argument cannot match.
+};
+
+// Three-way comparison of two CUDA matching distances: negative when x is the
+// better (smaller) match, positive when y is, and zero on a tie. Any finite
+// distance beats an infinite one; finite distances are compared argument by
+// argument, earlier arguments first. The two vectors can differ in length when
+// the specifics have different dummy counts (e.g. OPTIONAL dummies), so rather
+// than CHECK equal sizes, use std::vector's ordering (a prefix sorts first).
+static int CompareCudaMatchingDistance(
+    const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
+  if (x.isInfinite != y.isInfinite) {
+    return x.isInfinite ? 1 : -1;
+  }
+  if (x.isInfinite) {
+    return 0; // Neither matches; perArg may be incomplete, so don't compare.
+  }
+  if (x.perArg < y.perArg) {
+    return -1;
+  }
+  return y.perArg < x.perArg ? 1 : 0;
+}
+
 // Compute the matching distance as described in section 3.2.3 of the CUDA
 // Fortran references.
 static int GetMatchingDistance(const common::LanguageFeatureControl &features,
@@ -2882,20 +2907,23 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
   return cudaInfMatchingValue;
 }
 
-static int ComputeCudaMatchingDistance(
+static CudaMatchingDistance ComputeCudaMatchingDistance(
     const common::LanguageFeatureControl &features,
     const characteristics::Procedure &procedure,
     const ActualArguments &actuals) {
   const auto &dummies{procedure.dummyArguments};
   CHECK(dummies.size() == actuals.size());
-  int distance{0};
+  CudaMatchingDistance distance; // Per-dummy distances, in dummy order.
+  distance.perArg.reserve(dummies.size());
   for (std::size_t i{0}; i < dummies.size(); ++i) {
     const characteristics::DummyArgument &dummy{dummies[i]};
     const std::optional<ActualArgument> &actual{actuals[i]};
     int d{GetMatchingDistance(features, dummy, actual)};
-    if (d == cudaInfMatchingValue)
-      return d;
-    distance += d;
+    if (d == cudaInfMatchingValue) {
+      distance.isInfinite = true;
+      return distance; // perArg holds only the arguments before i.
+    }
+    distance.perArg.push_back(d);
   }
   return distance;
 }
@@ -2967,7 +2995,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
   const Symbol *nonElemental{nullptr}; // matching non-elemental specific
   const auto *genericDetails{ultimate.detailsIf<semantics::GenericDetails>()};
   if (genericDetails && !explicitIntrinsic) {
-    int crtMatchingDistance{cudaInfMatchingValue};
+    std::optional<CudaMatchingDistance> crtMatchingDistance;
     for (const Symbol &specific0 : genericDetails->specificProcs()) {
       const Symbol &specific1{BypassGeneric(specific0)};
       if (isSubroutine != !IsFunction(specific1)) {
@@ -2992,23 +3020,25 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
                 context_, false /* no integer conversions */) &&
             CheckCompatibleArguments(
                 *procedure, localActuals, foldingContext_)) {
+          CudaMatchingDistance d{ComputeCudaMatchingDistance(
+              context_.languageFeatures(), *procedure, localActuals)};
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            int d{ComputeCudaMatchingDistance(
-                context_.languageFeatures(), *procedure, localActuals)};
-            if (d != crtMatchingDistance) {
-              if (d > crtMatchingDistance) {
+            if (crtMatchingDistance) {
+              int cmp{CompareCudaMatchingDistance(d, *crtMatchingDistance)};
+              if (cmp > 0) {
                 continue;
               }
+              if (cmp == 0) {
+                // 16.9.144(6): a bare NULL() is not allowed as an actual
+                // argument to a generic procedure if the specific procedure
+                // cannot be unambiguously distinguished
+                // Underspecified external procedure actual arguments can
+                // also lead to ambiguity.
+                return {nullptr, true /* due to ambiguity */, std::move(tried)};
+              }
               // Matching distance is smaller than the previously matched
               // specific. Let it go through so the current procedure is picked.
-            } else {
-              // 16.9.144(6): a bare NULL() is not allowed as an actual
-              // argument to a generic procedure if the specific procedure
-              // cannot be unambiguously distinguished
-              // Underspecified external procedure actual arguments can
-              // also lead to ambiguity.
-              return {nullptr, true /* due to ambiguity */, std::move(tried)};
             }
           }
           if (!procedure->IsElemental()) {
@@ -3017,8 +3047,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
           } else {
             elemental = specific;
           }
-          crtMatchingDistance = ComputeCudaMatchingDistance(
-              context_.languageFeatures(), *procedure, localActuals);
+          crtMatchingDistance = std::move(d);
         }
       }
     }
diff --git a/flang/test/Semantics/cuf17.cuf b/flang/test/Semantics/cuf17.cuf
new file mode 100644
index 0000000000000..4309510bba6cd
--- /dev/null
+++ b/flang/test/Semantics/cuf17.cuf
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+! Both specifics are candidates for the call below; comparing CUDA matching
+! distances argument by argument must resolve it to sub_managed_device.
+module matching_two_args
+  interface two_args
+    module procedure sub_managed_device
+    module procedure sub_unified_unified
+  end interface
+contains
+  subroutine sub_managed_device(a, b)
+    integer, managed :: a(:)
+    integer, device :: b(:)
+  end
+
+  subroutine sub_unified_unified(a, b)
+    integer, unified :: a(:)
+    integer, unified :: b(:)
+  end
+end module
+
+program test
+  use matching_two_args
+  integer, managed, allocatable :: a(:)
+  integer, managed, allocatable :: b(:)
+  allocate(a(10), b(10))
+  call two_args(a, b)
+end
+
+! CHECK: fir.call @_QMmatching_two_argsPsub_managed_device
+! CHECK-NOT: fir.call @_QMmatching_two_argsPsub_unified_unified

``````````

</details>


https://github.com/llvm/llvm-project/pull/181038


More information about the flang-commits mailing list