[flang-commits] [flang] [flang][cuda] Compute matching distance in generic resolution (PR #90774)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Wed May 1 14:25:48 PDT 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/90774
>From 07cd30a898e4177a291ee5e6b157c658bcdc51bc Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 May 2024 13:32:55 -0700
Subject: [PATCH 1/2] [flang][cuda] Compute matching distance in generic
resolution
---
flang/lib/Semantics/expression.cpp | 107 +++++++++++++++++++++++++++--
flang/test/Semantics/cuf13.cuf | 30 +++++---
2 files changed, 123 insertions(+), 14 deletions(-)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b8396209fc6854..0ba0871f530169 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,89 @@ static bool CheckCompatibleArguments(
return true;
}
+// Compute the matching distance as described in section 3.2.3 of the CUDA
+// Fortran references.
+static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+ const std::optional<ActualArgument> &actual) {
+ std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
+ if (actual) {
+ if (auto *expr{actual->UnwrapExpr()}) {
+ const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
+ if (actualLastSymbol) {
+ actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
+ if (const auto *actualObject{actualLastSymbol
+ ? actualLastSymbol
+ ->detailsIf<semantics::ObjectEntityDetails>()
+ : nullptr}) {
+ actualDataAttr = actualObject->cudaDataAttr();
+ }
+ }
+ }
+ }
+
+ common::visit(common::visitors{
+ [&](const characteristics::DummyDataObject &object) {
+ dummyDataAttr = object.cudaDataAttr;
+ },
+ [&](const auto &) {},
+ },
+ dummy.u);
+
+ if (!dummyDataAttr) {
+ if (!actualDataAttr) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 3;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
+ if (!actualDataAttr) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Device) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed ||
+ *actualDataAttr == common::CUDADataAttr::Unified) {
+ return 2;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 0;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 1;
+ }
+ } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
+ if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+ return std::numeric_limits<int>::max();
+ } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
+ return 1;
+ } else if (*actualDataAttr == common::CUDADataAttr::Unified) {
+ return 0;
+ }
+ }
+ return std::numeric_limits<int>::max();
+}
+
+static int ComputeCudaMatchingDistance(
+ const characteristics::Procedure &procedure,
+ const ActualArguments &actuals) {
+ const auto &dummies{procedure.dummyArguments};
+ CHECK(dummies.size() == actuals.size());
+ int distance = 0;
+ for (std::size_t i{0}; i < dummies.size(); ++i) {
+ const characteristics::DummyArgument &dummy{dummies[i]};
+ const std::optional<ActualArgument> &actual{actuals[i]};
+ int d = GetMatchingDistance(dummy, actual);
+ if (d == std::numeric_limits<int>::max())
+ return d;
+ distance += d;
+ }
+ return distance;
+}
+
// Handles a forward reference to a module function from what must
// be a specification expression. Return false if the symbol is
// an invalid forward reference.
@@ -2539,6 +2622,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
const Symbol *elemental{nullptr}; // matching elemental specific proc
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const Symbol &ultimate{symbol.GetUltimate()};
+ int crtMatchingDistance{std::numeric_limits<int>::max()};
// Check for a match with an explicit INTRINSIC
if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
parser::Messages buffer;
@@ -2575,12 +2659,21 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
CheckCompatibleArguments(*procedure, localActuals)) {
if ((procedure->IsElemental() && elemental) ||
(!procedure->IsElemental() && nonElemental)) {
- // 16.9.144(6): a bare NULL() is not allowed as an actual
- // argument to a generic procedure if the specific procedure
- // cannot be unambiguously distinguished
- // Underspecified external procedure actual arguments can
- // also lead to ambiguity.
- return {nullptr, true /* due to ambiguity */};
+ int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+ if (d != crtMatchingDistance) {
+ if (d > crtMatchingDistance) {
+ continue;
+ }
+ // Matching distance is smaller than the previously matched
+ // specific. Let it go through so the current procedure is picked.
+ } else {
+ // 16.9.144(6): a bare NULL() is not allowed as an actual
+ // argument to a generic procedure if the specific procedure
+ // cannot be unambiguously distinguished
+ // Underspecified external procedure actual arguments can
+ // also lead to ambiguity.
+ return {nullptr, true /* due to ambiguity */};
+ }
}
if (!procedure->IsElemental()) {
// takes priority over elemental match
@@ -2588,6 +2681,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
} else {
elemental = &specific;
}
+ crtMatchingDistance =
+ ComputeCudaMatchingDistance(*procedure, localActuals);
}
}
}
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 6db829002fae67..0662a51ff5f22a 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -1,13 +1,11 @@
-! RUN: %python %S/test_errors.py %s %flang_fc1
+! RUN: %flang -fc1 -x cuda -fdebug-unparse %s | FileCheck %s
module matching
interface sub
module procedure sub_host
module procedure sub_device
- end interface
-
- interface subman
- module procedure sub_host
+ module procedure sub_managed
+ module procedure sub_unified
end interface
contains
@@ -19,6 +17,13 @@ contains
integer, device :: a(:)
end
+ subroutine sub_managed(a)
+ integer, managed :: a(:)
+ end
+
+ subroutine sub_unified(a)
+ integer, unified :: a(:)
+ end
end module
program m
@@ -26,12 +31,21 @@ program m
integer, pinned, allocatable :: a(:)
integer, managed, allocatable :: b(:)
+ integer, unified, allocatable :: u(:)
+ integer, device :: d(10)
logical :: plog
allocate(a(100), pinned = plog)
allocate(b(200))
+ allocate(u(100))
- call sub(a)
-
- call subman(b)
+ call sub(a) ! Should resolve to sub_host
+ call sub(b) ! Should resolve to sub_managed
+ call sub(u) ! Should resolve to sub_unified
+ call sub(d) ! Should resolve to sub_device
end
+
+! CHECK: CALL sub_host
+! CHECK: CALL sub_managed
+! CHECK: CALL sub_unified
+! CHECK: CALL sub_device
>From a408ca83296a616865506e790d2772f5cd4be1f6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 May 2024 14:25:36 -0700
Subject: [PATCH 2/2] Use braces and make max int a named constant
---
flang/lib/Semantics/expression.cpp | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 0ba0871f530169..d109443c6d5c91 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2492,6 +2492,8 @@ static bool CheckCompatibleArguments(
return true;
}
+static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
+
// Compute the matching distance as described in section 3.2.3 of the CUDA
// Fortran references.
static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
@@ -2524,14 +2526,14 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
if (!actualDataAttr) {
return 0;
} else if (*actualDataAttr == common::CUDADataAttr::Device) {
- return std::numeric_limits<int>::max();
+ return cudaInfMatchingValue;
} else if (*actualDataAttr == common::CUDADataAttr::Managed ||
*actualDataAttr == common::CUDADataAttr::Unified) {
return 3;
}
} else if (*dummyDataAttr == common::CUDADataAttr::Device) {
if (!actualDataAttr) {
- return std::numeric_limits<int>::max();
+ return cudaInfMatchingValue;
} else if (*actualDataAttr == common::CUDADataAttr::Device) {
return 0;
} else if (*actualDataAttr == common::CUDADataAttr::Managed ||
@@ -2540,7 +2542,7 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
}
} else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
- return std::numeric_limits<int>::max();
+ return cudaInfMatchingValue;
} else if (*actualDataAttr == common::CUDADataAttr::Managed) {
return 0;
} else if (*actualDataAttr == common::CUDADataAttr::Unified) {
@@ -2548,14 +2550,14 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
}
} else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
- return std::numeric_limits<int>::max();
+ return cudaInfMatchingValue;
} else if (*actualDataAttr == common::CUDADataAttr::Managed) {
return 1;
} else if (*actualDataAttr == common::CUDADataAttr::Unified) {
return 0;
}
}
- return std::numeric_limits<int>::max();
+ return cudaInfMatchingValue;
}
static int ComputeCudaMatchingDistance(
@@ -2563,12 +2565,12 @@ static int ComputeCudaMatchingDistance(
const ActualArguments &actuals) {
const auto &dummies{procedure.dummyArguments};
CHECK(dummies.size() == actuals.size());
- int distance = 0;
+ int distance{0};
for (std::size_t i{0}; i < dummies.size(); ++i) {
const characteristics::DummyArgument &dummy{dummies[i]};
const std::optional<ActualArgument> &actual{actuals[i]};
- int d = GetMatchingDistance(dummy, actual);
- if (d == std::numeric_limits<int>::max())
+ int d{GetMatchingDistance(dummy, actual)};
+ if (d == cudaInfMatchingValue)
return d;
distance += d;
}
@@ -2622,7 +2624,7 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
const Symbol *elemental{nullptr}; // matching elemental specific proc
const Symbol *nonElemental{nullptr}; // matching non-elemental specific
const Symbol &ultimate{symbol.GetUltimate()};
- int crtMatchingDistance{std::numeric_limits<int>::max()};
+ int crtMatchingDistance{cudaInfMatchingValue};
// Check for a match with an explicit INTRINSIC
if (ultimate.attrs().test(semantics::Attr::INTRINSIC)) {
parser::Messages buffer;
More information about the flang-commits
mailing list