[flang-commits] [flang] [flang][cuda] Update attribute compatibility check for unified matching rule (PR #90679)
via flang-commits
flang-commits at lists.llvm.org
Tue Apr 30 15:47:02 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This patch updates the compatibility checks for CUDA attributes in preparation for implementing the matching rules described in section 3.2.3. With this patch, the compiler will still emit an error when multiple specific procedures match, since matching distances are not yet implemented. This will be done in a separate patch.
https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data
---
Full diff: https://github.com/llvm/llvm-project/pull/90679.diff
4 Files Affected:
- (modified) flang/include/flang/Common/Fortran.h (+3-2)
- (modified) flang/lib/Common/Fortran.cpp (+23-1)
- (modified) flang/lib/Semantics/check-call.cpp (+3-2)
- (modified) flang/test/Semantics/cuf13.cuf (+9)
``````````diff
diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h
index 2a53452a2774ff..c63362a966a6b1 100644
--- a/flang/include/flang/Common/Fortran.h
+++ b/flang/include/flang/Common/Fortran.h
@@ -114,8 +114,9 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind,
IgnoreTKR::Rank, IgnoreTKR::Device, IgnoreTKR::Managed};
std::string AsFortran(IgnoreTKRSet);
-bool AreCompatibleCUDADataAttrs(
- std::optional<CUDADataAttr>, std::optional<CUDADataAttr>, IgnoreTKRSet);
+bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
+ std::optional<CUDADataAttr>, IgnoreTKRSet,
+ bool allowUnifiedMatchingRule = false);
static constexpr char blankCommonObjectName[] = "__BLNK__";
diff --git a/flang/lib/Common/Fortran.cpp b/flang/lib/Common/Fortran.cpp
index 8ada8fe210a30f..c8efe0bb234328 100644
--- a/flang/lib/Common/Fortran.cpp
+++ b/flang/lib/Common/Fortran.cpp
@@ -97,8 +97,12 @@ std::string AsFortran(IgnoreTKRSet tkr) {
return result;
}
+/// Check compatibility of CUDA data attributes.
+/// When `allowUnifiedMatchingRule` is enabled, argument `x` represents the
+/// dummy argument attribute while `y` represents the actual argument attribute.
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
- std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR) {
+ std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
+ bool allowUnifiedMatchingRule) {
if (!x && !y) {
return true;
} else if (x && y && *x == *y) {
@@ -114,6 +118,24 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
x.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed &&
y.value_or(CUDADataAttr::Managed) == CUDADataAttr::Managed) {
return true;
+ } else if (allowUnifiedMatchingRule) {
+ if (!x) { // Dummy argument has no attribute -> host
+ if (y && *y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified) {
+ return true;
+ }
+ } else {
+ if (*x == CUDADataAttr::Device && y &&
+ (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
+ return true;
+ } else if (*x == CUDADataAttr::Managed && y &&
+ *y == CUDADataAttr::Unified) {
+ return true;
+ } else if (*x == CUDADataAttr::Unified && y &&
+ *y == CUDADataAttr::Managed) {
+ return true;
+ }
+ }
+ return false;
} else {
return false;
}
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index db0949e905a658..f0da779785142a 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -897,8 +897,9 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
actualDataAttr = common::CUDADataAttr::Device;
}
}
- if (!common::AreCompatibleCUDADataAttrs(
- dummyDataAttr, actualDataAttr, dummy.ignoreTKR)) {
+ if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
+ dummy.ignoreTKR,
+ /*allowUnifiedMatchingRule=*/true)) {
auto toStr{[](std::optional<common::CUDADataAttr> x) {
return x ? "ATTRIBUTES("s +
parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s
diff --git a/flang/test/Semantics/cuf13.cuf b/flang/test/Semantics/cuf13.cuf
index 7c6673e21bf11b..6db829002fae67 100644
--- a/flang/test/Semantics/cuf13.cuf
+++ b/flang/test/Semantics/cuf13.cuf
@@ -6,6 +6,10 @@ module matching
module procedure sub_device
end interface
+ interface subman
+ module procedure sub_host
+ end interface
+
contains
subroutine sub_host(a)
integer :: a(:)
@@ -21,8 +25,13 @@ program m
use matching
integer, pinned, allocatable :: a(:)
+ integer, managed, allocatable :: b(:)
logical :: plog
allocate(a(100), pinned = plog)
+ allocate(b(200))
call sub(a)
+
+ call subman(b)
+
end
``````````
</details>
https://github.com/llvm/llvm-project/pull/90679
More information about the flang-commits
mailing list