[flang-commits] [flang] a131525 - [flang][cuda] Compute matching distance in generic resolution (#90774)
via flang-commits
flang-commits at lists.llvm.org
Thu May 2 09:07:45 PDT 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-05-02T09:07:04-07:00
New Revision: a131525908a908baff4cd01140dae158a307dc9e
URL: https://github.com/llvm/llvm-project/commit/a131525908a908baff4cd01140dae158a307dc9e
DIFF: https://github.com/llvm/llvm-project/commit/a131525908a908baff4cd01140dae158a307dc9e.diff
LOG: [flang][cuda] Compute matching distance in generic resolution (#90774)
Implement the matching distance as described here:
https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data
The generic is resolved to the specific procedure with the smallest matching distance.
Added:
Modified:
flang/lib/Semantics/expression.cpp
flang/test/Semantics/cuf13.cuf
Removed:
################################################################################
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index f677973ca2753b..8445be581fddd5 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2494,6 +2494,91 @@ static bool CheckCompatibleArguments(
return true;
}
+static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
+
+// Compute the CUDA matching distance between a dummy argument and its actual
+// argument, as described in section 3.2.3 of the CUDA Fortran reference.
+static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+ const std::optional<ActualArgument> &actual) {
+ std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr; // both stay empty unless set below
+ if (actual) {
+ if (auto *expr{actual->UnwrapExpr()}) {
+ const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
+ if (actualLastSymbol) {
+ actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
+ if (const auto *actualObject{actualLastSymbol
+ ? actualLastSymbol
+ ->detailsIf<semantics::ObjectEntityDetails>()
+ : nullptr}) { // NOTE(review): the null re-check is redundant inside this if
+ actualDataAttr = actualObject->cudaDataAttr();
+ }
+ }
+ }
+ }
+
+ common::visit(common::visitors{
+ [&](const characteristics::DummyDataObject &object) {
+ dummyDataAttr = object.cudaDataAttr;
+ },
+ [&](const auto &) {}, // any other kind of dummy leaves dummyDataAttr empty
+ },
+ dummy.u);
+
+ if (!dummyDataAttr) { // no CUDA data attribute recorded for the dummy
+ if (!actualDataAttr) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return cudaInfMatchingValue;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 3;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
+ if (!actualDataAttr) {
+ return cudaInfMatchingValue;
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 2;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return cudaInfMatchingValue;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 1;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return cudaInfMatchingValue;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 1;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 0;
+ }
+ }
+ return cudaInfMatchingValue; // every attribute pairing not handled above
+}
+
+static int ComputeCudaMatchingDistance(
+ const characteristics::Procedure &procedure,
+ const ActualArguments &actuals) { // total distance over all dummy/actual pairs
+ const auto &dummies{procedure.dummyArguments};
+ CHECK(dummies.size() == actuals.size()); // dummies and actuals are paired by index
+ int distance{0};
+ for (std::size_t i{0}; i < dummies.size(); ++i) {
+ const characteristics::DummyArgument &dummy{dummies[i]};
+ const std::optional<ActualArgument> &actual{actuals[i]};
+ int d{GetMatchingDistance(dummy, actual)};
+ if (d == cudaInfMatchingValue) // one infinite pair makes the whole result infinite
+ return d;
+ distance += d;
+ }
+ return distance;
+}
+
// Handles a forward reference to a module function from what must
// be a specification expression. Return false if the symbol is
// an invalid forward reference.
@@ -2541,6 +2626,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
const Symbol *elemental{nullptr}; // matching elemental specific proc
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const Symbol &ultimate{symbol.GetUltimate()};
+ int crtMatchingDistance{cudaInfMatchingValue};
// Check for a match with an explicit INTRINSIC
if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
parser::Messages buffer;
@@ -2577,12 +2663,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
CheckCompatibleArguments(*procedure, localActuals)) {
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */};
+ int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+ if (d != crtMatchingDistance) {
+ if (d > crtMatchingDistance) {
+ continue;
+ }
+ // Matching distance is smaller than the previously matched
+ // specific. Let it go through so the current procedure is picked.
+ } else {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */};
+ }
}
if (!procedure->IsElemental()) {
// takes priority over elemental match
@@ -2590,6 +2685,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
} else {
elemental = &specific;
}
+ crtMatchingDistance =
+ ComputeCudaMatchingDistance(*procedure, localActuals);
}
}
}
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 6db829002fae67..0662a51ff5f22a 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -1,13 +1,11 @@
-! RUN: %python %S/test_errors.py %s %flang_fc1
+! RUN: %flang -fc1 -x cuda -fdebug-unparse %s | FileCheck %s
module matching
interface sub
module procedure sub_host
module procedure sub_device
- end interface
-
- interface subman
- module procedure sub_host
+ module procedure sub_managed
+ module procedure sub_unified
end interface
contains
@@ -19,6 +17,13 @@ contains
integer, device :: a(:)
end
+ subroutine sub_managed(a)
+ integer, managed :: a(:)
+ end
+
+ subroutine sub_unified(a)
+ integer, unified :: a(:)
+ end
end module
program m
@@ -26,12 +31,21 @@ program m
integer, pinned, allocatable :: a(:)
integer, managed, allocatable :: b(:)
+ integer, unified, allocatable :: u(:)
+ integer, device :: d(10)
logical :: plog
allocate(a(100), pinned = plog)
allocate(b(200))
+ allocate(u(100))
- call sub(a)
-
- call subman(b)
+ call sub(a) ! Should resolve to sub_host
+ call sub(b) ! Should resolve to sub_managed
+ call sub(u) ! Should resolve to sub_unified
+ call sub(d) ! Should resolve to sub_device
end
+
+! CHECK: CALL sub_host
+! CHECK: CALL sub_managed
+! CHECK: CALL sub_unified
+! CHECK: CALL sub_device
More information about the flang-commits
mailing list