[flang-commits] [flang] [flang][cuda] Compute matching distance in generic resolution (PR #90774)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed May 1 14:25:48 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/90774

>From 07cd30a898e4177a291ee5e6b157c658bcdc51bc Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 May 2024 13:32:55 -0700
Subject: [PATCH 1/2] [flang][cuda] Compute matching distance in generic
 resolution

---
 flang/lib/Semantics/expression.cpp | 107 +++++++++++++++++++++++++++--
 flang/test/Semantics/cuf13.cuf     |  30 +++++---
 2 files changed, 123 insertions(+), 14 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b8396209fc6854..0ba0871f530169 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,89 @@ static bool CheckCompatibleArguments(
   return true;
 }
 
+// Compute the matching distance as described in section 3.2.3 of the CUDA
+// Fortran references.
+static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+    const std::optional<ActualArgument> &actual) {
+  std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
+  if (actual) {
+    if (auto *expr{actual->UnwrapExpr()}) {
+      const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
+      if (actualLastSymbol) {
+        // ResolveAssociations() returns a reference, so the resolved symbol
+        // is never null and needs no further check.
+        const auto &resolved{semantics::ResolveAssociations(*actualLastSymbol)};
+        if (const auto *actualObject{
+                resolved.detailsIf<semantics::ObjectEntityDetails>()}) {
+          actualDataAttr = actualObject->cudaDataAttr();
+        }
+      }
+    }
+  }
+  // Only a dummy data object contributes a CUDA data attribute here.
+  common::visit(common::visitors{
+                    [&](const characteristics::DummyDataObject &object) {
+                      dummyDataAttr = object.cudaDataAttr;
+                    },
+                    [&](const auto &) {},
+                },
+      dummy.u);
+  // 0 is an exact match; unlisted pairs (e.g. a Pinned actual) get max().
+  if (!dummyDataAttr) {
+    if (!actualDataAttr) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+        *actualDataAttr == common::CUDADataAttr::Unified) {
+      return 3;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+        *actualDataAttr == common::CUDADataAttr::Unified) {
+      return 2;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
+    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+      return 1;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
+    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+      return 1;
+    } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+      return 0;
+    }
+  }
+  return std::numeric_limits<int>::max();
+}
+
+static int ComputeCudaMatchingDistance(
+    const characteristics::Procedure &procedure,
+    const ActualArguments &actuals) {
+  const auto &dummies{procedure.dummyArguments};
+  CHECK(dummies.size() == actuals.size());
+  int distance = 0;
+  for (std::size_t i{0}; i < dummies.size(); ++i) {
+    const characteristics::DummyArgument &dummy{dummies[i]};
+    const std::optional<ActualArgument> &actual{actuals[i]};
+    int d = GetMatchingDistance(dummy, actual);
+    if (d == std::numeric_limits<int>::max())
+      return d;
+    distance += d;
+  }
+  return distance; // Sum over all pairs; a smaller total is a closer match.
+}
+
 // Handles a forward reference to a module function from what must
 // be a specification expression.  Return false if the symbol is
 // an invalid forward reference.
@@ -2539,6 +2622,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
   const Symbol *elemental{nullptr}; // matching elemental specific proc
   const Symbol *nonElemental{nullptr}; // matching non-elemental specific
   const Symbol &ultimate{symbol.GetUltimate()};
+  int crtMatchingDistance{std::numeric_limits<int>::max()};
   // Check for a match with an explicit INTRINSIC
   if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
     parser::Messages buffer;
@@ -2575,12 +2659,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
             CheckCompatibleArguments(*procedure, localActuals)) {
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            // 16.9.144(6): a bare NULL() is not allowed as an actual
-            // argument to a generic procedure if the specific procedure
-            // cannot be unambiguously distinguished
-            // Underspecified external procedure actual arguments can
-            // also lead to ambiguity.
-            return {nullptr, true /* due to ambiguity */};
+            int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+            if (d != crtMatchingDistance) {
+              if (d > crtMatchingDistance) {
+                continue;
+              }
+              // Matching distance is smaller than the previously matched
+            // specific. Let it go through so the current procedure is picked.
+            } else {
+              // 16.9.144(6): a bare NULL() is not allowed as an actual
+              // argument to a generic procedure if the specific procedure
+              // cannot be unambiguously distinguished
+              // Underspecified external procedure actual arguments can
+              // also lead to ambiguity.
+              return {nullptr, true /* due to ambiguity */};
+            }
           }
           if (!procedure->IsElemental()) {
             // takes priority over elemental match
@@ -2588,6 +2681,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
           } else {
             elemental = &specific;
           }
+          crtMatchingDistance =
+              ComputeCudaMatchingDistance(*procedure, localActuals);
         }
       }
     }
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 6db829002fae67..0662a51ff5f22a 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -1,13 +1,11 @@
-! RUN: %python %S/test_errors.py %s %flang_fc1
+! RUN: %flang -fc1 -x cuda -fdebug-unparse %s | FileCheck %s
 
 module matching
   interface sub
     module procedure sub_host
     module procedure sub_device
-  end interface
-
-  interface subman
-    module procedure sub_host
+    module procedure sub_managed
+    module procedure sub_unified
   end interface
 
 contains
@@ -19,6 +17,13 @@ contains
     integer, device :: a(:)
   end
 
+  subroutine sub_managed(a)
+    integer, managed :: a(:) ! exact match (distance 0) for a managed actual
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:) ! exact match (distance 0) for a unified actual
+  end
 end module
 
 program m
@@ -26,12 +31,21 @@ program m
 
   integer, pinned, allocatable :: a(:)
   integer, managed, allocatable :: b(:)
+  integer, unified, allocatable :: u(:)
+  integer, device :: d(10)
   logical :: plog
   allocate(a(100), pinned = plog)
   allocate(b(200))
+  allocate(u(100))
 
-  call sub(a)
-
-  call subman(b)
+  call sub(a) ! Should resolve to sub_host
+  call sub(b) ! Should resolve to sub_managed
+  call sub(u) ! Should resolve to sub_unified
+  call sub(d) ! Should resolve to sub_device
 
 end
+
+! CHECK: CALL sub_host
+! CHECK: CALL sub_managed
+! CHECK: CALL sub_unified
+! CHECK: CALL sub_device

>From a408ca83296a616865506e790d2772f5cd4be1f6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 May 2024 14:25:36 -0700
Subject: [PATCH 2/2] Use braces and make max int a named constant

---
 flang/lib/Semantics/expression.cpp | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 0ba0871f530169..d109443c6d5c91 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,8 @@ static bool CheckCompatibleArguments(
   return true;
 }
 
+// "Infinite" matching distance: the worst rank a specific can receive.
+static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 // Compute the matching distance as described in section 3.2.3 of the CUDA
 // Fortran references.
 static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
@@ -2524,14 +2526,14 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
     if (!actualDataAttr) {
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
-      return std::numeric_limits<int>::max();
+      return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
         *actualDataAttr == common::CUDADataAttr::Unified) {
       return 3;
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
     if (!actualDataAttr) {
-      return std::numeric_limits<int>::max();
+      return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
@@ -2540,7 +2542,7 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
     if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
-      return std::numeric_limits<int>::max();
+      return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
@@ -2548,14 +2550,14 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
     if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
-      return std::numeric_limits<int>::max();
+      return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 1;
     } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
       return 0;
     }
   }
-  return std::numeric_limits<int>::max();
+  return cudaInfMatchingValue;
 }
 
 static int ComputeCudaMatchingDistance(
@@ -2563,12 +2565,12 @@ static int ComputeCudaMatchingDistance(
     const ActualArguments &actuals) {
   const auto &dummies{procedure.dummyArguments};
   CHECK(dummies.size() == actuals.size());
-  int distance = 0;
+  int distance{0};
   for (std::size_t i{0}; i < dummies.size(); ++i) {
     const characteristics::DummyArgument &dummy{dummies[i]};
     const std::optional<ActualArgument> &actual{actuals[i]};
-    int d = GetMatchingDistance(dummy, actual);
-    if (d == std::numeric_limits<int>::max())
+    int d{GetMatchingDistance(dummy, actual)};
+    if (d == cudaInfMatchingValue)
       return d;
     distance += d;
   }
@@ -2622,7 +2624,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
   const Symbol *elemental{nullptr}; // matching elemental specific proc
   const Symbol *nonElemental{nullptr}; // matching non-elemental specific
   const Symbol &ultimate{symbol.GetUltimate()};
-  int crtMatchingDistance{std::numeric_limits<int>::max()};
+  int crtMatchingDistance{cudaInfMatchingValue};
   // Check for a match with an explicit INTRINSIC
   if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
     parser::Messages buffer;



More information about the flang-commits mailing list