[flang-commits] [flang] [flang][cuda] Compute matching distance in generic resolution (PR #90774)

via flang-commits flang-commits at lists.llvm.org
Wed May 1 13:36:15 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

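Implement the CUDA Fortran matching distance for generic resolution, as described here: https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data

When several specific procedures match a generic call, the generic now resolves to the specific with the smallest matching distance.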

---
Full diff: https://github.com/llvm/llvm-project/pull/90774.diff


2 Files Affected:

- (modified) flang/lib/Semantics/expression.cpp (+101-6) 
- (modified) flang/test/Semantics/cuf13.cuf (+22-8) 


``````````diff
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b8396209fc6854..0ba0871f530169 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,89 @@ static bool CheckCompatibleArguments(
   return true;
 }
 
+// Compute the matching distance as described in section 3.2.3 of the CUDA
+// Fortran Programming Guide.
+static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+    const std::optional<ActualArgument> &actual) {
+  std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
+  if (actual) {
+    if (auto *expr{actual->UnwrapExpr()}) {
+      const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
+      if (actualLastSymbol) {
+        actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
+        if (const auto *actualObject{actualLastSymbol
+                    ? actualLastSymbol
+                          ->detailsIf<semantics::ObjectEntityDetails>()
+                    : nullptr}) {
+          actualDataAttr = actualObject->cudaDataAttr();
+        }
+      }
+    }
+  }
+
+  common::visit(common::visitors{
+                    [&](const characteristics::DummyDataObject &object) {
+                      dummyDataAttr = object.cudaDataAttr;
+                    },
+                    [&](const auto &) {},
+                },
+      dummy.u);
+
+  if (!dummyDataAttr) {
+    if (!actualDataAttr) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+        *actualDataAttr == common::CUDADataAttr::Unified) {
+      return 3;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+        *actualDataAttr == common::CUDADataAttr::Unified) {
+      return 2;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
+    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+      return 0;
+    } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+      return 1;
+    }
+  } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
+    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+      return std::numeric_limits<int>::max();
+    } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+      return 1;
+    } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+      return 0;
+    }
+  }
+  return std::numeric_limits<int>::max();
+}
+
+static int ComputeCudaMatchingDistance(
+    const characteristics::Procedure &procedure,
+    const ActualArguments &actuals) {
+  const auto &dummies{procedure.dummyArguments};
+  CHECK(dummies.size() == actuals.size());
+  int distance = 0;
+  for (std::size_t i{0}; i < dummies.size(); ++i) {
+    const characteristics::DummyArgument &dummy{dummies[i]};
+    const std::optional<ActualArgument> &actual{actuals[i]};
+    int d = GetMatchingDistance(dummy, actual);
+    if (d == std::numeric_limits<int>::max())
+      return d;
+    distance += d;
+  }
+  return distance;
+}
+
 // Handles a forward reference to a module function from what must
 // be a specification expression.  Return false if the symbol is
 // an invalid forward reference.
@@ -2539,6 +2622,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
   const Symbol *elemental{nullptr}; // matching elemental specific proc
   const Symbol *nonElemental{nullptr}; // matching non-elemental specific
   const Symbol &ultimate{symbol.GetUltimate()};
+  int crtMatchingDistance{std::numeric_limits<int>::max()};
   // Check for a match with an explicit INTRINSIC
   if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
     parser::Messages buffer;
@@ -2575,12 +2659,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
             CheckCompatibleArguments(*procedure, localActuals)) {
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            // 16.9.144(6): a bare NULL() is not allowed as an actual
-            // argument to a generic procedure if the specific procedure
-            // cannot be unambiguously distinguished
-            // Underspecified external procedure actual arguments can
-            // also lead to ambiguity.
-            return {nullptr, true /* due to ambiguity */};
+            int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+            if (d != crtMatchingDistance) {
+              if (d > crtMatchingDistance) {
+                continue;
+              }
+              // Matching distance is smaller than the previously matched
+              // specific. Let it go through so the current procedure is picked.
+            } else {
+              // 16.9.144(6): a bare NULL() is not allowed as an actual
+              // argument to a generic procedure if the specific procedure
+              // cannot be unambiguously distinguished
+              // Underspecified external procedure actual arguments can
+              // also lead to ambiguity.
+              return {nullptr, true /* due to ambiguity */};
+            }
           }
           if (!procedure->IsElemental()) {
             // takes priority over elemental match
@@ -2588,6 +2681,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
           } else {
             elemental = &specific;
           }
+          crtMatchingDistance =
+              ComputeCudaMatchingDistance(*procedure, localActuals);
         }
       }
     }
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 6db829002fae67..0662a51ff5f22a 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -1,13 +1,11 @@
-! RUN: %python %S/test_errors.py %s %flang_fc1
+! RUN: %flang -fc1 -x cuda -fdebug-unparse %s | FileCheck %s
 
 module matching
   interface sub
     module procedure sub_host
     module procedure sub_device
-  end interface
-
-  interface subman
-    module procedure sub_host
+    module procedure sub_managed
+    module procedure sub_unified
   end interface
 
 contains
@@ -19,6 +17,13 @@ contains
     integer, device :: a(:)
   end
 
+  subroutine sub_managed(a)
+    integer, managed :: a(:)
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:)
+  end
 end module
 
 program m
@@ -26,12 +31,21 @@ program m
 
   integer, pinned, allocatable :: a(:)
   integer, managed, allocatable :: b(:)
+  integer, unified, allocatable :: u(:)
+  integer, device :: d(10)
   logical :: plog
   allocate(a(100), pinned = plog)
   allocate(b(200))
+  allocate(u(100))
 
-  call sub(a)
-
-  call subman(b)
+  call sub(a) ! Should resolve to sub_host
+  call sub(b) ! Should resolve to sub_managed
+  call sub(u) ! Should resolve to sub_unified
+  call sub(d) ! Should resolve to sub_device
 
 end
+
+! CHECK: CALL sub_host
+! CHECK: CALL sub_managed
+! CHECK: CALL sub_unified
+! CHECK: CALL sub_device

``````````

</details>


https://github.com/llvm/llvm-project/pull/90774


More information about the flang-commits mailing list