[clang] [CUDA][HIP] Fix host/device context in concept (PR #67721)

Yaxun Liu via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 28 11:49:41 PDT 2023


https://github.com/yxsamliu updated https://github.com/llvm/llvm-project/pull/67721

>From ba151e83af264074303ccc1d8f4ecf853a4a153f Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu" <yaxun.liu at amd.com>
Date: Thu, 28 Sep 2023 14:48:28 -0400
Subject: [PATCH] [CUDA][HIP] Fix host/device context in concept

Currently, constraints are checked in Sema::FinishTemplateArgumentDeduction,
where the current function in ASTContext is set to the instantiated template
function. When resolving functions for the constraints, clang assumes the
caller is the current function. This is incompatible with nvcc, and it also
causes problems for constexpr template functions in C++ code.

Clang caches the result of constraint checking for each concept/type pair,
assuming that the result does not depend on the instantiation context.

This patch gives constraint checking its own host/device context, which
defaults to host for compatibility with C++. This makes constraint checking
independent of the caller and keeps the cached results valid.

In the future, we may support device-side constraints by other means,
e.g. by adding a __device__ attribute to individual function calls in constraints.

Fixes: https://github.com/llvm/llvm-project/issues/67507
---
 clang/docs/HIPSupport.rst       | 31 +++++++++++++++++++++++++++++++
 clang/include/clang/Sema/Sema.h |  9 +++++++--
 clang/lib/Sema/SemaCUDA.cpp     | 33 ++++++++++++++++++++-------------
 clang/lib/Sema/SemaConcept.cpp  |  2 ++
 clang/test/SemaCUDA/concept.cu  | 23 +++++++++++++++++++++++
 5 files changed, 83 insertions(+), 15 deletions(-)
 create mode 100644 clang/test/SemaCUDA/concept.cu

diff --git a/clang/docs/HIPSupport.rst b/clang/docs/HIPSupport.rst
index 8b4649733a9c777..ea7eed0fe7ce1eb 100644
--- a/clang/docs/HIPSupport.rst
+++ b/clang/docs/HIPSupport.rst
@@ -176,3 +176,34 @@ Predefined Macros
    * - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
      - Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
 
+C++20 Concepts with HIP and CUDA
+--------------------------------
+
+In Clang, when compiling HIP or CUDA code, all constraints in C++20 concepts are checked as host-side requirements, even when the constrained template is a ``__device__`` or ``__global__`` function. This behavior is the same for both programming models, and developers should keep this assumption in mind when writing code that uses C++20 concepts. For example:
+
+
+.. code-block:: c++
+
+   template <class T>
+   concept MyConcept = requires(T& obj) {
+     my_function(obj);  // Assumed to be a host-side requirement
+   };
+
+   template <MyConcept T>
+   __global__ void kernel() {
+      // Kernel code
+   }
+
+   struct MyType {};
+
+   inline void my_function(MyType& obj) {}
+
+   int main() {
+      kernel<MyType><<<1,1>>>();
+      return 0;
+   }
+
+In the above example, ``MyConcept`` is checked against host-side requirements even though it constrains a ``__global__`` kernel; because ``my_function`` is a host function, the constraint is satisfied. Developers should write constraints so that they are satisfiable on the host side, which is the context Clang assumes.
+
+This assumption makes the result of checking a template constraint independent of where the template is used, and simplifies the compilation model by avoiding the need to distinguish between host-side and device-side requirements.
+
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 712db0a3dd895d5..9b1545b634177d4 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13312,6 +13312,7 @@ class Sema final {
     CTCK_Unknown,       /// Unknown context
     CTCK_InitGlobalVar, /// Function called during global variable
                         /// initialization
+    CTCK_Constraint,    /// Function called for constraint checking
   };
 
   /// Define the current global CUDA host/device context where a function may be
@@ -13319,13 +13320,17 @@ class Sema final {
   struct CUDATargetContext {
     CUDAFunctionTarget Target = CFT_HostDevice;
     CUDATargetContextKind Kind = CTCK_Unknown;
-    Decl *D = nullptr;
+    const Decl *D = nullptr;
+    const Expr *E = nullptr;
+    /// Whether should override the current function.
+    bool shouldOverride(const Decl *D) const;
   } CurCUDATargetCtx;
 
   struct CUDATargetContextRAII {
     Sema &S;
     CUDATargetContext SavedCtx;
-    CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
+    CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, const Decl *D,
+                          const Expr *E = nullptr);
     ~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
   };
 
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index 88f5484575db17a..285711116bb8d4e 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -114,27 +114,34 @@ static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
 
 Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
                                                    CUDATargetContextKind K,
-                                                   Decl *D)
+                                                   const Decl *D, const Expr *E)
     : S(S_) {
   SavedCtx = S.CurCUDATargetCtx;
-  assert(K == CTCK_InitGlobalVar);
-  auto *VD = dyn_cast_or_null<VarDecl>(D);
-  if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
-    auto Target = CFT_Host;
-    if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
-         !hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
-        hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
-        hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
-      Target = CFT_Device;
-    S.CurCUDATargetCtx = {Target, K, VD};
+  auto Target = CFT_Host;
+  if (K == CTCK_InitGlobalVar) {
+    auto *VD = dyn_cast_or_null<VarDecl>(D);
+    if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
+      if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
+           !hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
+          hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
+          hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
+        Target = CFT_Device;
+      S.CurCUDATargetCtx = {Target, K, D, E};
+    }
+    return;
   }
+  assert(K == CTCK_Constraint);
+  S.CurCUDATargetCtx = {Target, K, D, E};
+}
+
+bool Sema::CUDATargetContext::shouldOverride(const Decl *D) const {
+  return Kind == CTCK_Constraint || D == nullptr;
 }
 
 /// IdentifyCUDATarget - Determine the CUDA compilation target for this function
 Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
                                                   bool IgnoreImplicitHDAttr) {
-  // Code that lives outside a function gets the target from CurCUDATargetCtx.
-  if (D == nullptr)
+  if (CurCUDATargetCtx.shouldOverride(D))
     return CurCUDATargetCtx.Target;
 
   if (D->hasAttr<CUDAInvalidTargetAttr>())
diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp
index 036548b68247bfa..94e2b4e444deb26 100644
--- a/clang/lib/Sema/SemaConcept.cpp
+++ b/clang/lib/Sema/SemaConcept.cpp
@@ -336,6 +336,8 @@ static ExprResult calculateConstraintSatisfaction(
     Sema &S, const NamedDecl *Template, SourceLocation TemplateNameLoc,
     const MultiLevelTemplateArgumentList &MLTAL, const Expr *ConstraintExpr,
     ConstraintSatisfaction &Satisfaction) {
+  Sema::CUDATargetContextRAII X(S, Sema::CTCK_Constraint,
+                                /*Decl=*/nullptr, ConstraintExpr);
   return calculateConstraintSatisfaction(
       S, ConstraintExpr, Satisfaction, [&](const Expr *AtomicExpr) {
         EnterExpressionEvaluationContext ConstantEvaluated(
diff --git a/clang/test/SemaCUDA/concept.cu b/clang/test/SemaCUDA/concept.cu
new file mode 100644
index 000000000000000..e29381892c59019
--- /dev/null
+++ b/clang/test/SemaCUDA/concept.cu
@@ -0,0 +1,23 @@
+// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -fcuda-is-device -x hip %s \
+// RUN:   -std=c++20 -fsyntax-only -verify
+// RUN: %clang_cc1 -triple x86_64 -x hip %s \
+// RUN:   -std=c++20 -fsyntax-only -verify
+
+// expected-no-diagnostics
+
+#include "Inputs/cuda.h"
+
+template <class T>
+concept C = requires(T x) {
+  func(x);
+};
+
+struct A {};
+void func(A x) {}
+
+template <C T> __global__ void kernel(T x) { }
+
+int main() {
+  A a;
+  kernel<<<1,1>>>(a);
+}



More information about the cfe-commits mailing list