[flang-commits] [flang] [flang][cuda] Compute matching distance in generic resolution (PR #90774)
via flang-commits
flang-commits at lists.llvm.org
Wed May 1 13:36:15 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Implement the matching distance as described here: https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data
A generic reference is resolved to the specific procedure with the smallest matching distance.
---
Full diff: https://github.com/llvm/llvm-project/pull/90774.diff
2 Files Affected:
- (modified) flang/lib/Semantics/expression.cpp (+101-6)
- (modified) flang/test/Semantics/cuf13.cuf (+22-8)
``````````diff
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b8396209fc6854..0ba0871f530169 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,89 @@ static bool CheckCompatibleArguments(
return true;
}
+// Compute the matching distance as described in section 3.2.3 of the CUDA
+// Fortran Programming Guide.
+static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+ const std::optional<ActualArgument> &actual) {
+ std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
+ if (actual) {
+ if (auto *expr{actual->UnwrapExpr()}) {
+ const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
+ if (actualLastSymbol) {
+ actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
+ if (const auto *actualObject{actualLastSymbol
+ ? actualLastSymbol
+ ->detailsIf<semantics::ObjectEntityDetails>()
+ : nullptr}) {
+ actualDataAttr = actualObject->cudaDataAttr();
+ }
+ }
+ }
+ }
+
+ common::visit(common::visitors{
+ [&](const characteristics::DummyDataObject &object) {
+ dummyDataAttr = object.cudaDataAttr;
+ },
+ [&](const auto &) {},
+ },
+ dummy.u);
+
+ if (!dummyDataAttr) {
+ if (!actualDataAttr) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 3;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
+ if (!actualDataAttr) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 2;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 1;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 1;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 0;
+ }
+ }
+ return std::numeric_limits<int>::max();
+}
+
+static int ComputeCudaMatchingDistance(
+ const characteristics::Procedure &procedure,
+ const ActualArguments &actuals) {
+ const auto &dummies{procedure.dummyArguments};
+ CHECK(dummies.size() == actuals.size());
+ int distance = 0;
+ for (std::size_t i{0}; i < dummies.size(); ++i) {
+ const characteristics::DummyArgument &dummy{dummies[i]};
+ const std::optional<ActualArgument> &actual{actuals[i]};
+ int d = GetMatchingDistance(dummy, actual);
+ if (d == std::numeric_limits<int>::max())
+ return d;
+ distance += d;
+ }
+ return distance;
+}
+
// Handles a forward reference to a module function from what must
// be a specification expression. Return false if the symbol is
// an invalid forward reference.
@@ -2539,6 +2622,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
const Symbol *elemental{nullptr}; // matching elemental specific proc
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const Symbol &ultimate{symbol.GetUltimate()};
+ int crtMatchingDistance{std::numeric_limits<int>::max()};
// Check for a match with an explicit INTRINSIC
if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
parser::Messages buffer;
@@ -2575,12 +2659,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
CheckCompatibleArguments(*procedure, localActuals)) {
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */};
+ int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+ if (d != crtMatchingDistance) {
+ if (d > crtMatchingDistance) {
+ continue;
+ }
+ // Matching distance is smaller than the previously matched
+ // specific. Let it go through so the current procedure is picked.
+ } else {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */};
+ }
}
if (!procedure->IsElemental()) {
// takes priority over elemental match
@@ -2588,6 +2681,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
} else {
elemental = &specific;
}
+ crtMatchingDistance =
+ ComputeCudaMatchingDistance(*procedure, localActuals);
}
}
}
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 6db829002fae67..0662a51ff5f22a 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -1,13 +1,11 @@
-! RUN: %python %S/test_errors.py %s %flang_fc1
+! RUN: %flang -fc1 -x cuda -fdebug-unparse %s | FileCheck %s
module matching
interface sub
module procedure sub_host
module procedure sub_device
- end interface
-
- interface subman
- module procedure sub_host
+ module procedure sub_managed
+ module procedure sub_unified
end interface
contains
@@ -19,6 +17,13 @@ contains
integer, device :: a(:)
end
+ subroutine sub_managed(a)
+ integer, managed :: a(:)
+ end
+
+ subroutine sub_unified(a)
+ integer, unified :: a(:)
+ end
end module
program m
@@ -26,12 +31,21 @@ program m
integer, pinned, allocatable :: a(:)
integer, managed, allocatable :: b(:)
+ integer, unified, allocatable :: u(:)
+ integer, device :: d(10)
logical :: plog
allocate(a(100), pinned = plog)
allocate(b(200))
+ allocate(u(100))
- call sub(a)
-
- call subman(b)
+ call sub(a) ! Should resolve to sub_host
+ call sub(b) ! Should resolve to sub_managed
+ call sub(u) ! Should resolve to sub_unified
+ call sub(d) ! Should resolve to sub_device
end
+
+! CHECK: CALL sub_host
+! CHECK: CALL sub_managed
+! CHECK: CALL sub_unified
+! CHECK: CALL sub_device
``````````
</details>
https://github.com/llvm/llvm-project/pull/90774
More information about the flang-commits
mailing list