[flang-commits] [flang] [flang][cuda] Resolve generic tie‑breaks using per‑arg CUDA distance (PR #181038)

Zhen Wang via flang-commits flang-commits at lists.llvm.org
Thu Feb 12 18:45:55 PST 2026


https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/181038

>From 6ea1d125fef749b84f2ec982b3d20073ba2ab9f9 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 9 Feb 2026 14:51:47 -0800
Subject: [PATCH 1/5] cuda matching distance

---
 flang/lib/Semantics/expression.cpp | 76 ++++++++++++++++++++++--------
 1 file changed, 57 insertions(+), 19 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b0235469c9de6..b4dac18b17a10 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2799,6 +2799,40 @@ static bool CheckCompatibleArguments(
 
 static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
+struct CudaMatchingDistance {
+  std::vector<int> perArg;
+  bool isInfinite{false};
+};
+
+static int CompareCudaMatchingDistance(
+    const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
+  if (x.isInfinite != y.isInfinite) {
+    return x.isInfinite ? 1 : -1;
+  }
+  if (x.isInfinite) {
+    return 0;
+  }
+  std::size_t n{x.perArg.size()};
+  if (n > y.perArg.size()) {
+    n = y.perArg.size();
+  }
+  for (std::size_t i{0}; i < n; ++i) {
+    if (x.perArg[i] < y.perArg[i]) {
+      return -1;
+    }
+    if (x.perArg[i] > y.perArg[i]) {
+      return 1;
+    }
+  }
+  if (x.perArg.size() < y.perArg.size()) {
+    return -1;
+  }
+  if (x.perArg.size() > y.perArg.size()) {
+    return 1;
+  }
+  return 0;
+}
+
 // Compute the matching distance as described in section 3.2.3 of the CUDA
 // Fortran references.
 static int GetMatchingDistance(const common::LanguageFeatureControl &features,
@@ -2882,20 +2916,23 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
   return cudaInfMatchingValue;
 }
 
-static int ComputeCudaMatchingDistance(
+static CudaMatchingDistance ComputeCudaMatchingDistance(
     const common::LanguageFeatureControl &features,
     const characteristics::Procedure &procedure,
     const ActualArguments &actuals) {
   const auto &dummies{procedure.dummyArguments};
   CHECK(dummies.size() == actuals.size());
-  int distance{0};
+  CudaMatchingDistance distance;
+  distance.perArg.reserve(dummies.size());
   for (std::size_t i{0}; i < dummies.size(); ++i) {
     const characteristics::DummyArgument &dummy{dummies[i]};
     const std::optional<ActualArgument> &actual{actuals[i]};
     int d{GetMatchingDistance(features, dummy, actual)};
-    if (d == cudaInfMatchingValue)
-      return d;
-    distance += d;
+    if (d == cudaInfMatchingValue) {
+      distance.isInfinite = true;
+      return distance;
+    }
+    distance.perArg.push_back(d);
   }
   return distance;
 }
@@ -2967,7 +3004,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
   const Symbol *nonElemental{nullptr}; // matching non-elemental specific
   const auto *genericDetails{ultimate.detailsIf<semantics::GenericDetails>()};
   if (genericDetails && !explicitIntrinsic) {
-    int crtMatchingDistance{cudaInfMatchingValue};
+    std::optional<CudaMatchingDistance> crtMatchingDistance;
     for (const Symbol &specific0 : genericDetails->specificProcs()) {
       const Symbol &specific1{BypassGeneric(specific0)};
       if (isSubroutine != !IsFunction(specific1)) {
@@ -2992,23 +3029,25 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
                 context_, false /* no integer conversions */) &&
             CheckCompatibleArguments(
                 *procedure, localActuals, foldingContext_)) {
+          CudaMatchingDistance d{ComputeCudaMatchingDistance(
+              context_.languageFeatures(), *procedure, localActuals)};
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            int d{ComputeCudaMatchingDistance(
-                context_.languageFeatures(), *procedure, localActuals)};
-            if (d != crtMatchingDistance) {
-              if (d > crtMatchingDistance) {
+            if (crtMatchingDistance) {
+              int cmp{CompareCudaMatchingDistance(d, *crtMatchingDistance)};
+              if (cmp > 0) {
                 continue;
               }
+              if (cmp == 0) {
+                // 16.9.144(6): a bare NULL() is not allowed as an actual
+                // argument to a generic procedure if the specific procedure
+                // cannot be unambiguously distinguished
+                // Underspecified external procedure actual arguments can
+                // also lead to ambiguity.
+                return {nullptr, true /* due to ambiguity */, std::move(tried)};
+              }
               // Matching distance is smaller than the previously matched
               // specific. Let it go through so the current procedure is picked.
-            } else {
-              // 16.9.144(6): a bare NULL() is not allowed as an actual
-              // argument to a generic procedure if the specific procedure
-              // cannot be unambiguously distinguished
-              // Underspecified external procedure actual arguments can
-              // also lead to ambiguity.
-              return {nullptr, true /* due to ambiguity */, std::move(tried)};
             }
           }
           if (!procedure->IsElemental()) {
@@ -3017,8 +3056,7 @@ auto ExpressionAnalyzer::ResolveGeneric(const Symbol &symbol,
           } else {
             elemental = specific;
           }
-          crtMatchingDistance = ComputeCudaMatchingDistance(
-              context_.languageFeatures(), *procedure, localActuals);
+          crtMatchingDistance = std::move(d);
         }
       }
     }

>From 2e0c6d04e9756564dc1510fb48f1c458aba399a7 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 13:18:42 -0800
Subject: [PATCH 2/5] simplify comparator

---
 flang/lib/Semantics/expression.cpp | 19 +++++--------------
 flang/test/Semantics/cuf17.cuf     | 30 ++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 14 deletions(-)
 create mode 100644 flang/test/Semantics/cuf17.cuf

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b4dac18b17a10..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2812,22 +2812,13 @@ static int CompareCudaMatchingDistance(
   if (x.isInfinite) {
     return 0;
   }
-  std::size_t n{x.perArg.size()};
-  if (n > y.perArg.size()) {
-    n = y.perArg.size();
-  }
-  for (std::size_t i{0}; i < n; ++i) {
-    if (x.perArg[i] < y.perArg[i]) {
-      return -1;
-    }
-    if (x.perArg[i] > y.perArg[i]) {
-      return 1;
-    }
-  }
-  if (x.perArg.size() < y.perArg.size()) {
+  CHECK(x.perArg.size() == y.perArg.size());
+  if (std::lexicographical_compare(
+          x.perArg.begin(), x.perArg.end(), y.perArg.begin(), y.perArg.end())) {
     return -1;
   }
-  if (x.perArg.size() > y.perArg.size()) {
+  if (std::lexicographical_compare(
+          y.perArg.begin(), y.perArg.end(), x.perArg.begin(), x.perArg.end())) {
     return 1;
   }
   return 0;
diff --git a/flang/test/Semantics/cuf17.cuf b/flang/test/Semantics/cuf17.cuf
new file mode 100644
index 0000000000000..4309510bba6cd
--- /dev/null
+++ b/flang/test/Semantics/cuf17.cuf
@@ -0,0 +1,30 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+module matching_two_args
+  interface two_args
+    module procedure sub_managed_device
+    module procedure sub_unified_unified
+  end interface
+contains
+  subroutine sub_managed_device(a, b)
+    integer, managed :: a(:)
+    integer, device :: b(:)
+  end
+
+  subroutine sub_unified_unified(a, b)
+    integer, unified :: a(:)
+    integer, unified :: b(:)
+  end
+end module
+
+program test
+  use matching_two_args
+  integer, managed, allocatable :: a(:)
+  integer, managed, allocatable :: b(:)
+
+  allocate(a(10), b(10))
+  call two_args(a, b)
+end
+
+! CHECK: fir.call @_QMmatching_two_argsPsub_managed_device
+! CHECK-NOT: @_QMmatching_two_argsPsub_unified_unified

>From 726987066d5f2ffdb04027450d1a09d6cf6f6892 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 15:08:24 -0800
Subject: [PATCH 3/5] use SmallVector

---
 flang/lib/Semantics/expression.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index a0bcd6bd8a39e..be648ba4c5389 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -25,6 +25,7 @@
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "flang/Support/Fortran.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <functional>
@@ -2800,7 +2801,7 @@ static bool CheckCompatibleArguments(
 static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
 struct CudaMatchingDistance {
-  std::vector<int> perArg;
+  llvm::SmallVector<int> perArg;
   bool isInfinite{false};
 };
 

>From e99f333050881194fd1456d83b5800460ce0026c Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 11 Feb 2026 19:55:08 -0800
Subject: [PATCH 4/5] revert SmallVector back to std::vector

---
 flang/lib/Semantics/expression.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index be648ba4c5389..a0bcd6bd8a39e 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -25,7 +25,6 @@
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "flang/Support/Fortran.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <functional>
@@ -2801,7 +2800,7 @@ static bool CheckCompatibleArguments(
 static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
 struct CudaMatchingDistance {
-  llvm::SmallVector<int> perArg;
+  std::vector<int> perArg;
   bool isInfinite{false};
 };
 

>From 07945be01fecf8935cf818b75c261ffe4bd6af6a Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Thu, 12 Feb 2026 18:45:02 -0800
Subject: [PATCH 5/5] add comments

---
 flang/lib/Semantics/expression.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index a0bcd6bd8a39e..a4b32e4674e1d 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2804,6 +2804,10 @@ struct CudaMatchingDistance {
   bool isInfinite{false};
 };
 
+// Compare CUDA matching distances lexicographically over the per-argument
+// distances. Summing the per-argument weights can give two procedures the
+// same total distance; comparing argument by argument breaks such ties so
+// the compiler can select the best match.
 static int CompareCudaMatchingDistance(
     const CudaMatchingDistance &x, const CudaMatchingDistance &y) {
   if (x.isInfinite != y.isInfinite) {



More information about the flang-commits mailing list