[flang-commits] [flang] [flang][cuda] Fix resolution of overloaded operator (PR #122402)
via flang-commits
flang-commits at lists.llvm.org
Thu Jan 9 17:52:45 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Move the code that adds the implicit CUDA `device` attribute so that it runs after the EntityDetails have been converted to ObjectEntityDetails, because ObjectEntityDetails is what holds the CUDA data attributes.
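This fixes an issue in the resolution of overloaded operators: a call in a device procedure can now resolve to the device-specific overload (e.g. `addDevice` rather than `addHost`).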
---
Full diff: https://github.com/llvm/llvm-project/pull/122402.diff
2 Files Affected:
- (modified) flang/lib/Semantics/resolve-names.cpp (+19-20)
- (modified) flang/test/Semantics/cuf10.cuf (+19)
``````````diff
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 724f1b28078356..51e7c5960dc2ef 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8970,18 +8970,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
misparsedStmtFuncFound_ = false;
funcResultStack().CompleteFunctionResultType();
CheckImports();
- bool inDeviceSubprogram = false;
- if (auto *subp{currScope().symbol()
- ? currScope().symbol()->detailsIf<SubprogramDetails>()
- : nullptr}) {
- if (auto attrs{subp->cudaSubprogramAttrs()}) {
- if (*attrs == common::CUDASubprogramAttrs::Device ||
- *attrs == common::CUDASubprogramAttrs::Global ||
- *attrs == common::CUDASubprogramAttrs::Grid_Global) {
- inDeviceSubprogram = true;
- }
- }
- }
for (auto &pair : currScope()) {
auto &symbol{*pair.second};
if (inInterfaceBlock()) {
@@ -8990,14 +8978,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
if (NeedsExplicitType(symbol)) {
ApplyImplicitRules(symbol);
}
- if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
- auto *object{symbol.detailsIf<ObjectEntityDetails>()};
- if (!object->cudaDataAttr() && !IsValue(symbol) &&
- (IsDummy(symbol) || object->IsArray())) {
- // Implicitly set device attribute if none is set in device context.
- object->set_cudaDataAttr(common::CUDADataAttr::Device);
- }
- }
if (IsDummy(symbol) && isImplicitNoneType() &&
symbol.test(Symbol::Flag::Implicit) && !context().HasError(symbol)) {
Say(symbol.name(),
@@ -9522,6 +9502,7 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
},
node.stmt());
Walk(node.spec());
+ bool inDeviceSubprogram = false;
// If this is a function, convert result to an object. This is to prevent the
// result from being converted later to a function symbol if it is called
// inside the function.
@@ -9535,6 +9516,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
if (details->isFunction()) {
ConvertToObjectEntity(const_cast<Symbol &>(details->result()));
}
+ // Check the current procedure is a device procedure to apply implicit
+ // attribute at the end.
+ if (auto attrs{details->cudaSubprogramAttrs()}) {
+ if (*attrs == common::CUDASubprogramAttrs::Device ||
+ *attrs == common::CUDASubprogramAttrs::Global ||
+ *attrs == common::CUDASubprogramAttrs::Grid_Global) {
+ inDeviceSubprogram = true;
+ }
+ }
}
}
if (node.IsModule()) {
@@ -9561,6 +9551,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
symbol.GetType() ? Symbol::Flag::Function : Symbol::Flag::Subroutine);
}
ApplyImplicitRules(symbol);
+ // Apply CUDA implicit attributes if needed.
+ if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
+ auto *object{symbol.detailsIf<ObjectEntityDetails>()};
+ if (!object->cudaDataAttr() && !IsValue(symbol) &&
+ (IsDummy(symbol) || object->IsArray())) {
+ // Implicitly set device attribute if none is set in device context.
+ object->set_cudaDataAttr(common::CUDADataAttr::Device);
+ }
+ }
}
}
diff --git a/flang/test/Semantics/cuf10.cuf b/flang/test/Semantics/cuf10.cuf
index 24b596b1fa55db..f85471855ec57e 100644
--- a/flang/test/Semantics/cuf10.cuf
+++ b/flang/test/Semantics/cuf10.cuf
@@ -3,6 +3,13 @@ module m
real, device :: a(4,8)
real, managed, allocatable :: b(:,:)
integer, constant :: x = 1
+ type :: int
+ real :: i, s
+ end type int
+ interface operator (+)
+ module procedure addHost
+ module procedure addDevice
+ end interface operator (+)
contains
attributes(global) subroutine kernel(a,b,c,n,m)
integer, value :: n
@@ -30,4 +37,16 @@ module m
subroutine sub2()
call sub1<<<1,1>>>(x) ! actual constant to device dummy
end
+ function addHost(a, b) result(c)
+ type(int), intent(in) :: a, b
+ type(int) :: c
+ end function addHost
+ attributes(device) function addDevice(a, b) result(c)
+ type(int), device :: c
+ type(int), intent(in) :: a ,b
+ end function addDevice
+ attributes(global) subroutine overload(c, a, b)
+ type (int) :: c, a, b
+ c = a+b ! ok resolve to addDevice
+ end subroutine overload
end
``````````
</details>
https://github.com/llvm/llvm-project/pull/122402
More information about the flang-commits
mailing list