[flang-commits] [flang] c189df8 - [flang][cuda] Fix resolution of overloaded operator (#122402)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 10 08:44:22 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-01-10T08:44:19-08:00
New Revision: c189df842c67a2476a59363fa36a0c1b1137f533

URL: https://github.com/llvm/llvm-project/commit/c189df842c67a2476a59363fa36a0c1b1137f533
DIFF: https://github.com/llvm/llvm-project/commit/c189df842c67a2476a59363fa36a0c1b1137f533.diff

LOG: [flang][cuda] Fix resolution of overloaded operator (#122402)

Added: 
    

Modified: 
    flang/lib/Semantics/resolve-names.cpp
    flang/test/Semantics/cuf10.cuf

Removed: 
    


################################################################################
diff  --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 724f1b28078356..51e7c5960dc2ef 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8970,18 +8970,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
   misparsedStmtFuncFound_ = false;
   funcResultStack().CompleteFunctionResultType();
   CheckImports();
-  bool inDeviceSubprogram = false;
-  if (auto *subp{currScope().symbol()
-              ? currScope().symbol()->detailsIf<SubprogramDetails>()
-              : nullptr}) {
-    if (auto attrs{subp->cudaSubprogramAttrs()}) {
-      if (*attrs == common::CUDASubprogramAttrs::Device ||
-          *attrs == common::CUDASubprogramAttrs::Global ||
-          *attrs == common::CUDASubprogramAttrs::Grid_Global) {
-        inDeviceSubprogram = true;
-      }
-    }
-  }
   for (auto &pair : currScope()) {
     auto &symbol{*pair.second};
     if (inInterfaceBlock()) {
@@ -8990,14 +8978,6 @@ void ResolveNamesVisitor::FinishSpecificationPart(
     if (NeedsExplicitType(symbol)) {
       ApplyImplicitRules(symbol);
     }
-    if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
-      auto *object{symbol.detailsIf<ObjectEntityDetails>()};
-      if (!object->cudaDataAttr() && !IsValue(symbol) &&
-          (IsDummy(symbol) || object->IsArray())) {
-        // Implicitly set device attribute if none is set in device context.
-        object->set_cudaDataAttr(common::CUDADataAttr::Device);
-      }
-    }
     if (IsDummy(symbol) && isImplicitNoneType() &&
         symbol.test(Symbol::Flag::Implicit) && !context().HasError(symbol)) {
       Say(symbol.name(),
@@ -9522,6 +9502,7 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
       },
       node.stmt());
   Walk(node.spec());
+  bool inDeviceSubprogram = false;
   // If this is a function, convert result to an object. This is to prevent the
   // result from being converted later to a function symbol if it is called
   // inside the function.
@@ -9535,6 +9516,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
       if (details->isFunction()) {
         ConvertToObjectEntity(const_cast<Symbol &>(details->result()));
       }
+      // Check whether the current procedure is a device procedure so that the
+      // implicit CUDA data attribute can be applied at the end.
+      if (auto attrs{details->cudaSubprogramAttrs()}) {
+        if (*attrs == common::CUDASubprogramAttrs::Device ||
+            *attrs == common::CUDASubprogramAttrs::Global ||
+            *attrs == common::CUDASubprogramAttrs::Grid_Global) {
+          inDeviceSubprogram = true;
+        }
+      }
     }
   }
   if (node.IsModule()) {
@@ -9561,6 +9551,15 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) {
           symbol.GetType() ? Symbol::Flag::Function : Symbol::Flag::Subroutine);
     }
     ApplyImplicitRules(symbol);
+    // Apply CUDA implicit attributes if needed.
+    if (inDeviceSubprogram && symbol.has<ObjectEntityDetails>()) {
+      auto *object{symbol.detailsIf<ObjectEntityDetails>()};
+      if (!object->cudaDataAttr() && !IsValue(symbol) &&
+          (IsDummy(symbol) || object->IsArray())) {
+        // Implicitly set device attribute if none is set in device context.
+        object->set_cudaDataAttr(common::CUDADataAttr::Device);
+      }
+    }
   }
 }
 

diff  --git a/flang/test/Semantics/cuf10.cuf b/flang/test/Semantics/cuf10.cuf
index 24b596b1fa55db..f85471855ec57e 100644
--- a/flang/test/Semantics/cuf10.cuf
+++ b/flang/test/Semantics/cuf10.cuf
@@ -3,6 +3,13 @@ module m
   real, device :: a(4,8)
   real, managed, allocatable :: b(:,:)
   integer, constant :: x = 1
+  type :: int
+     real :: i, s 
+  end type int
+  interface operator (+)
+     module procedure addHost
+     module procedure addDevice
+  end interface operator (+)
  contains
   attributes(global) subroutine kernel(a,b,c,n,m)
     integer, value :: n
@@ -30,4 +37,16 @@ module m
   subroutine sub2()
     call sub1<<<1,1>>>(x) ! actual constant to device dummy
   end
+  function addHost(a, b) result(c)
+    type(int), intent(in) :: a, b
+    type(int) :: c
+  end function addHost
+  attributes(device) function addDevice(a, b) result(c)
+    type(int), device :: c
+    type(int), intent(in) :: a ,b
+  end function addDevice
+  attributes(global) subroutine overload(c, a, b)
+    type (int) :: c, a, b
+    c = a+b ! ok resolve to addDevice
+  end subroutine overload
 end


        


More information about the flang-commits mailing list