[clang] d414451 - [CUDA][HIP] Fix hostness check with -fopenmp

Yaxun Liu via cfe-commits cfe-commits at lists.llvm.org
Thu Mar 24 12:20:26 PDT 2022


Author: Yaxun (Sam) Liu
Date: 2022-03-24T15:19:47-04:00
New Revision: d41445113bccaa037e5876659b4fd98d96af03e4

URL: https://github.com/llvm/llvm-project/commit/d41445113bccaa037e5876659b4fd98d96af03e4
DIFF: https://github.com/llvm/llvm-project/commit/d41445113bccaa037e5876659b4fd98d96af03e4.diff

LOG: [CUDA][HIP] Fix hostness check with -fopenmp

CUDA/HIP determines whether a function can be called based on
the device/host attributes of the callee and the caller. Clang assumes
the caller is CurContext. This is correct in most cases; however, it is
not correct in an OpenMP parallel region when a CUDA/HIP program
is compiled with -fopenmp. This causes incorrect overload
resolution and missed diagnostics.

To get the correct caller, clang needs to walk up the parent chain
of DeclContexts, starting from CurContext, until a function decl
or a lambda decl is reached. The Sema API is adapted to do that
and is used to determine the caller in the hostness check.

Reviewed by: Artem Belevich, Richard Smith

Differential Revision: https://reviews.llvm.org/D121765

Added: 
    clang/test/CodeGenCUDA/openmp-parallel.cu
    clang/test/SemaCUDA/openmp-parallel.cu

Modified: 
    clang/include/clang/Sema/Sema.h
    clang/lib/Sema/Sema.cpp
    clang/lib/Sema/SemaCUDA.cpp
    clang/lib/Sema/SemaOverload.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 8e3f9221763f3..f95308275688e 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3318,12 +3318,14 @@ class Sema final {
   void ActOnReenterFunctionContext(Scope* S, Decl* D);
   void ActOnExitFunctionContext();
 
-  DeclContext *getFunctionLevelDeclContext();
-
-  /// getCurFunctionDecl - If inside of a function body, this returns a pointer
-  /// to the function decl for the function being parsed.  If we're currently
-  /// in a 'block', this returns the containing context.
-  FunctionDecl *getCurFunctionDecl();
+  /// If \p AllowLambda is true, treat lambda as function.
+  DeclContext *getFunctionLevelDeclContext(bool AllowLambda = false);
+
+  /// Returns a pointer to the innermost enclosing function, or nullptr if the
+  /// current context is not inside a function. If \p AllowLambda is true,
+  /// this can return the call operator of an enclosing lambda, otherwise
+  /// lambdas are skipped when looking for an enclosing function.
+  FunctionDecl *getCurFunctionDecl(bool AllowLambda = false);
 
   /// getCurMethodDecl - If inside of a method body, this returns a pointer to
   /// the method decl for the method being parsed.  If we're currently

diff  --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index d625ffedbe539..fa09281f2f0e5 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -1421,19 +1421,18 @@ void Sema::ActOnEndOfTranslationUnit() {
 // Helper functions.
 //===----------------------------------------------------------------------===//
 
-DeclContext *Sema::getFunctionLevelDeclContext() {
+DeclContext *Sema::getFunctionLevelDeclContext(bool AllowLambda) {
   DeclContext *DC = CurContext;
 
   while (true) {
     if (isa<BlockDecl>(DC) || isa<EnumDecl>(DC) || isa<CapturedDecl>(DC) ||
         isa<RequiresExprBodyDecl>(DC)) {
       DC = DC->getParent();
-    } else if (isa<CXXMethodDecl>(DC) &&
+    } else if (!AllowLambda && isa<CXXMethodDecl>(DC) &&
                cast<CXXMethodDecl>(DC)->getOverloadedOperator() == OO_Call &&
                cast<CXXRecordDecl>(DC->getParent())->isLambda()) {
       DC = DC->getParent()->getParent();
-    }
-    else break;
+    } else break;
   }
 
   return DC;
@@ -1442,8 +1441,8 @@ DeclContext *Sema::getFunctionLevelDeclContext() {
 /// getCurFunctionDecl - If inside of a function body, this returns a pointer
 /// to the function decl for the function being parsed.  If we're currently
 /// in a 'block', this returns the containing context.
-FunctionDecl *Sema::getCurFunctionDecl() {
-  DeclContext *DC = getFunctionLevelDeclContext();
+FunctionDecl *Sema::getCurFunctionDecl(bool AllowLambda) {
+  DeclContext *DC = getFunctionLevelDeclContext(AllowLambda);
   return dyn_cast<FunctionDecl>(DC);
 }
 

diff  --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index 92785514e1048..b0af13044fc29 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -728,8 +728,9 @@ void Sema::MaybeAddCUDAConstantAttr(VarDecl *VD) {
 Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
                                                        unsigned DiagID) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
+  FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
   SemaDiagnosticBuilder::Kind DiagKind = [&] {
-    if (!isa<FunctionDecl>(CurContext))
+    if (!CurFunContext)
       return SemaDiagnosticBuilder::K_Nop;
     switch (CurrentCUDATarget()) {
     case CFT_Global:
@@ -743,7 +744,7 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
         return SemaDiagnosticBuilder::K_Nop;
       if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
-      return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
+      return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
                  ? SemaDiagnosticBuilder::K_ImmediateWithCallStack
                  : SemaDiagnosticBuilder::K_Deferred;
@@ -751,15 +752,15 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
       return SemaDiagnosticBuilder::K_Nop;
     }
   }();
-  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
-                               dyn_cast<FunctionDecl>(CurContext), *this);
+  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
 }
 
 Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
                                                      unsigned DiagID) {
   assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
+  FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
   SemaDiagnosticBuilder::Kind DiagKind = [&] {
-    if (!isa<FunctionDecl>(CurContext))
+    if (!CurFunContext)
       return SemaDiagnosticBuilder::K_Nop;
     switch (CurrentCUDATarget()) {
     case CFT_Host:
@@ -772,7 +773,7 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
         return SemaDiagnosticBuilder::K_Nop;
       if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
-      return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
+      return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
                  ? SemaDiagnosticBuilder::K_ImmediateWithCallStack
                  : SemaDiagnosticBuilder::K_Deferred;
@@ -780,8 +781,7 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
       return SemaDiagnosticBuilder::K_Nop;
     }
   }();
-  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
-                               dyn_cast<FunctionDecl>(CurContext), *this);
+  return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
 }
 
 bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
@@ -794,7 +794,7 @@ bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
 
   // FIXME: Is bailing out early correct here?  Should we instead assume that
   // the caller is a global initializer?
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
   if (!Caller)
     return true;
 
@@ -860,7 +860,7 @@ void Sema::CUDACheckLambdaCapture(CXXMethodDecl *Callee,
 
   // File-scope lambda can only do init captures for global variables, which
   // results in passing by value for these global variables.
-  FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
+  FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
   if (!Caller)
     return;
 

diff  --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index c802bc7fd85c8..33271609a0081 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -6473,7 +6473,7 @@ void Sema::AddOverloadCandidate(
 
   // (CUDA B.1): Check for invalid calls between targets.
   if (getLangOpts().CUDA)
-    if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
+    if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
       // Skip the check for callers that are implicit members, because in this
       // case we may not yet know what the member's target is; the target is
       // inferred for the member automatically, based on the bases and fields of
@@ -6983,7 +6983,7 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,
 
   // (CUDA B.1): Check for invalid calls between targets.
   if (getLangOpts().CUDA)
-    if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
+    if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
       if (!IsAllowedCUDACall(Caller, Method)) {
         Candidate.Viable = false;
         Candidate.FailureKind = ovl_fail_bad_target;
@@ -9639,7 +9639,7 @@ bool clang::isBetterOverloadCandidate(
   // overloading resolution diagnostics.
   if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function &&
       S.getLangOpts().GPUExcludeWrongSideOverloads) {
-    if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) {
+    if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
       bool IsCallerImplicitHD = Sema::isCUDAImplicitHostDeviceFunction(Caller);
       bool IsCand1ImplicitHD =
           Sema::isCUDAImplicitHostDeviceFunction(Cand1.Function);
@@ -9922,7 +9922,7 @@ bool clang::isBetterOverloadCandidate(
   // If other rules cannot determine which is better, CUDA preference is used
   // to determine which is better.
   if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
-    FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
+    FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
     return S.IdentifyCUDAPreference(Caller, Cand1.Function) >
            S.IdentifyCUDAPreference(Caller, Cand2.Function);
   }
@@ -10043,7 +10043,7 @@ OverloadCandidateSet::BestViableFunction(Sema &S, SourceLocation Loc,
   // -fgpu-exclude-wrong-side-overloads is on, all candidates are compared
   // uniformly in isBetterOverloadCandidate.
   if (S.getLangOpts().CUDA && !S.getLangOpts().GPUExcludeWrongSideOverloads) {
-    const FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
+    const FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
     bool ContainsSameSideCandidate =
         llvm::any_of(Candidates, [&](OverloadCandidate *Cand) {
           // Check viable function only.
@@ -11077,7 +11077,7 @@ static void DiagnoseBadDeduction(Sema &S, OverloadCandidate *Cand,
 
 /// CUDA: diagnose an invalid call across targets.
 static void DiagnoseBadTarget(Sema &S, OverloadCandidate *Cand) {
-  FunctionDecl *Caller = cast<FunctionDecl>(S.CurContext);
+  FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
   FunctionDecl *Callee = Cand->Function;
 
   Sema::CUDAFunctionTarget CallerTarget = S.IdentifyCUDATarget(Caller),
@@ -12136,7 +12136,7 @@ class AddressOfFunctionResolver {
 
     if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
       if (S.getLangOpts().CUDA)
-        if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext))
+        if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
           if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
             return false;
       if (FunDecl->isMultiVersion()) {
@@ -12253,7 +12253,8 @@ class AddressOfFunctionResolver {
   }
 
   void EliminateSuboptimalCudaMatches() {
-    S.EraseUnwantedCUDAMatches(dyn_cast<FunctionDecl>(S.CurContext), Matches);
+    S.EraseUnwantedCUDAMatches(S.getCurFunctionDecl(/*AllowLambda=*/true),
+                               Matches);
   }
 
 public:

diff  --git a/clang/test/CodeGenCUDA/openmp-parallel.cu b/clang/test/CodeGenCUDA/openmp-parallel.cu
new file mode 100644
index 0000000000000..f9c32f3991d53
--- /dev/null
+++ b/clang/test/CodeGenCUDA/openmp-parallel.cu
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
+// RUN:   -fopenmp -emit-llvm -o -  -x hip %s | FileCheck %s
+
+#include "Inputs/cuda.h"
+
+void foo(double) {}
+__device__ void foo(int) {}
+
+// Check foo resolves to the host function.
+// CHECK-LABEL: define {{.*}}@_Z5test1v
+// CHECK: call void @_Z3food(double noundef 1.000000e+00)
+void test1() {
+  #pragma omp parallel
+  for (int i = 0; i < 100; i++)
+    foo(1);
+}
+
+// Check foo resolves to the host function.
+// CHECK-LABEL: define {{.*}}@_Z5test2v
+// CHECK: call void @_Z3food(double noundef 1.000000e+00)
+void test2() {
+  auto Lambda = []() {
+    #pragma omp parallel
+    for (int i = 0; i < 100; i++)
+      foo(1);
+  };
+  Lambda();
+}

diff  --git a/clang/test/SemaCUDA/openmp-parallel.cu b/clang/test/SemaCUDA/openmp-parallel.cu
new file mode 100644
index 0000000000000..2e519e903b25c
--- /dev/null
+++ b/clang/test/SemaCUDA/openmp-parallel.cu
@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -fopenmp -fsyntax-only -verify %s
+
+#include "Inputs/cuda.h"
+
+__device__ void foo(int) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+// expected-note@-1 {{'foo' declared here}}
+
+int main() {
+  #pragma omp parallel
+  for (int i = 0; i < 100; i++)
+    foo(1); // expected-error {{no matching function for call to 'foo'}}
+  
+  auto Lambda = []() {
+    #pragma omp parallel
+    for (int i = 0; i < 100; i++)
+      foo(1); // expected-error {{reference to __device__ function 'foo' in __host__ __device__ function}}
+    };
+  Lambda(); // expected-note {{called by 'main'}}
+}


        


More information about the cfe-commits mailing list