[clang] [CUDA][HIP] Fix deduction guide (PR #69366)

Yaxun Liu via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 17 17:51:27 PDT 2023


https://github.com/yxsamliu updated https://github.com/llvm/llvm-project/pull/69366

>From b28384d33f858a6d4139da931b436cbf1a0a426a Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu" <yaxun.liu at amd.com>
Date: Sat, 14 Oct 2023 17:28:13 -0400
Subject: [PATCH] [CUDA][HIP] Fix deduction guide

Currently, clang treats implicit deduction guides as host device.
When a class has a device constructor and a host constructor with the
same parameters, this generates two identical implicit deduction guides,
which causes an ambiguity.

Since an implicit deduction guide is derived from a constructor,
it should take the same host/device attributes as the originating
constructor. This matches nvcc behavior as seen in
https://godbolt.org/z/sY1vdYWKe and https://godbolt.org/z/vTer7xa3j
---
 clang/docs/HIPSupport.rst              | 55 +++++++++++++++++
 clang/include/clang/AST/DeclCXX.h      |  9 +--
 clang/lib/AST/DeclCXX.cpp              | 21 +++++++
 clang/lib/Sema/SemaCUDA.cpp            |  5 +-
 clang/lib/Sema/SemaTemplate.cpp        | 18 ++++--
 clang/test/SemaCUDA/deduction-guide.cu | 85 ++++++++++++++++++++++++++
 6 files changed, 179 insertions(+), 14 deletions(-)
 create mode 100644 clang/test/SemaCUDA/deduction-guide.cu

diff --git a/clang/docs/HIPSupport.rst b/clang/docs/HIPSupport.rst
index 8b4649733a9c777..8a9802e19e6367f 100644
--- a/clang/docs/HIPSupport.rst
+++ b/clang/docs/HIPSupport.rst
@@ -176,3 +176,58 @@ Predefined Macros
    * - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
      - Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
 
+Support for Deduction Guides
+============================
+
+Explicit Deduction Guides
+-------------------------
+
+Explicit deduction guides in HIP can be annotated with the ``__host__`` or
+``__device__`` attribute. A deduction guide without either attribute is
+treated as ``__host__``.
+
+.. code-block:: cpp
+
+   template <typename T>
+   class MyArray {
+       //...
+   };
+
+   template <typename T>
+   MyArray(T)->MyArray<T>;
+
+   __device__ MyArray(float)->MyArray<int>;
+
+   // Uses of the deduction guides
+   MyArray arr1 = 10;      // Uses the default host guide
+   __device__ void foo() {
+       MyArray arr2 = 3.14f; // Uses the device guide
+   }
+
+Implicit Deduction Guides
+-------------------------
+Implicit deduction guides derived from constructors inherit the same host or
+device attributes as the originating constructor.
+
+.. code-block:: cpp
+
+   template <typename T>
+   class MyVector {
+   public:
+       __device__ MyVector(T) { /* ... */ }
+       //...
+   };
+
+   // The implicit deduction guide for MyVector will be `__device__` due to the device constructor
+
+   __device__ void foo() {
+       MyVector vec(42);  // Uses the implicit device guide derived from the constructor
+   }
+
+Availability Checks
+--------------------
+When a deduction guide (explicit or implicit) is used, the compiler checks its
+availability against the calling context based on its host/device attributes,
+in the same way as it checks a function call. Using a deduction guide in an
+incompatible context results in a compile-time error.
+
diff --git a/clang/include/clang/AST/DeclCXX.h b/clang/include/clang/AST/DeclCXX.h
index 5eaae6bdd2bc63e..863ced731d42b2f 100644
--- a/clang/include/clang/AST/DeclCXX.h
+++ b/clang/include/clang/AST/DeclCXX.h
@@ -1948,14 +1948,7 @@ class CXXDeductionGuideDecl : public FunctionDecl {
                         ExplicitSpecifier ES,
                         const DeclarationNameInfo &NameInfo, QualType T,
                         TypeSourceInfo *TInfo, SourceLocation EndLocation,
-                        CXXConstructorDecl *Ctor, DeductionCandidate Kind)
-      : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
-                     SC_None, false, false, ConstexprSpecKind::Unspecified),
-        Ctor(Ctor), ExplicitSpec(ES) {
-    if (EndLocation.isValid())
-      setRangeEnd(EndLocation);
-    setDeductionCandidateKind(Kind);
-  }
+                        CXXConstructorDecl *Ctor, DeductionCandidate Kind);
 
   CXXConstructorDecl *Ctor;
   ExplicitSpecifier ExplicitSpec;
diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp
index 9107525a44f22c2..e0683173e24f440 100644
--- a/clang/lib/AST/DeclCXX.cpp
+++ b/clang/lib/AST/DeclCXX.cpp
@@ -2113,6 +2113,27 @@ ExplicitSpecifier ExplicitSpecifier::getFromDecl(FunctionDecl *Function) {
   }
 }
 
+CXXDeductionGuideDecl::CXXDeductionGuideDecl(
+    ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
+    ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
+    TypeSourceInfo *TInfo, SourceLocation EndLocation, CXXConstructorDecl *Ctor,
+    DeductionCandidate Kind)
+    : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
+                   SC_None, false, false, ConstexprSpecKind::Unspecified),
+      Ctor(Ctor), ExplicitSpec(ES) {
+  setDeductionCandidateKind(Kind);
+  if (EndLocation.isValid())
+    setRangeEnd(EndLocation);
+  // In CUDA/HIP mode, a guide derived from a constructor (non-null Ctor)
+  // mirrors that constructor's __host__/__device__ attributes as implicit ones.
+  if (!Ctor || !C.getLangOpts().CUDA)
+    return;
+  if (Ctor->hasAttr<CUDAHostAttr>())
+    addAttr(CUDAHostAttr::CreateImplicit(C));
+  if (Ctor->hasAttr<CUDADeviceAttr>())
+    addAttr(CUDADeviceAttr::CreateImplicit(C));
+}
+
 CXXDeductionGuideDecl *CXXDeductionGuideDecl::Create(
     ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
     ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index d993499cf4a6e6e..d1d59ad1b9fc4b1 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -149,10 +149,13 @@ Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
     return CFT_Device;
   } else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
     return CFT_Host;
-  } else if ((D->isImplicit() || !D->isUserProvided()) &&
+  } else if (!isa<CXXDeductionGuideDecl>(D) &&
+             (D->isImplicit() || !D->isUserProvided()) &&
              !IgnoreImplicitHDAttr) {
     // Some implicit declarations (like intrinsic functions) are not marked.
     // Set the most lenient target on them for maximal flexibility.
+    // Deduction guides are excluded from this fallback: implicit guides get
+    // their host/device attributes from their originating constructors.
     return CFT_HostDevice;
   }
 
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 6389ec708bf34ae..0b854f06a95743b 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2685,19 +2685,27 @@ void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
     AddedAny = true;
   }
 
+  // Build simple deduction guide and set CUDA host/device attributes.
+  auto BuildSimpleDeductionGuide = [&](auto T) {
+    auto *DG = cast<CXXDeductionGuideDecl>(
+        cast<FunctionTemplateDecl>(Transform.buildSimpleDeductionGuide(T))
+            ->getTemplatedDecl());
+    if (LangOpts.CUDA) {
+      DG->addAttr(CUDAHostAttr::CreateImplicit(getASTContext()));
+      DG->addAttr(CUDADeviceAttr::CreateImplicit(getASTContext()));
+    }
+    return DG;
+  };
   // C++17 [over.match.class.deduct]
   //    --  If C is not defined or does not declare any constructors, an
   //    additional function template derived as above from a hypothetical
   //    constructor C().
   if (!AddedAny)
-    Transform.buildSimpleDeductionGuide(std::nullopt);
+    BuildSimpleDeductionGuide(std::nullopt);
 
   //    -- An additional function template derived as above from a hypothetical
   //    constructor C(C), called the copy deduction candidate.
-  cast<CXXDeductionGuideDecl>(
-      cast<FunctionTemplateDecl>(
-          Transform.buildSimpleDeductionGuide(Transform.DeducedType))
-          ->getTemplatedDecl())
+  BuildSimpleDeductionGuide(Transform.DeducedType)
       ->setDeductionCandidateKind(DeductionCandidate::Copy);
 }
 
diff --git a/clang/test/SemaCUDA/deduction-guide.cu b/clang/test/SemaCUDA/deduction-guide.cu
new file mode 100644
index 000000000000000..505700c4b58daed
--- /dev/null
+++ b/clang/test/SemaCUDA/deduction-guide.cu
@@ -0,0 +1,85 @@
+// RUN: %clang_cc1 -fsyntax-only -verify=expected,host %s
+// RUN: %clang_cc1 -fcuda-is-device -fsyntax-only -verify=expected,dev %s
+
+#include "Inputs/cuda.h"
+
+// Implicit deduction guide for host.
+template <typename T>
+struct HGuideImp {       // expected-note {{candidate template ignored: could not match 'HGuideImp<T>' against 'int'}}
+   HGuideImp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+                         // dev-note@-1 {{'<deduction guide for HGuideImp><int>' declared here}}
+                         // dev-note@-2 {{'HGuideImp' declared here}}
+};
+
+// Explicit deduction guide for host.
+template <typename T>
+struct HGuideExp {       // expected-note {{candidate template ignored: could not match 'HGuideExp<T>' against 'int'}}
+   HGuideExp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+                         // dev-note@-1 {{'HGuideExp' declared here}}
+};
+template<typename T>
+HGuideExp(T) -> HGuideExp<T>; // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+                              // dev-note@-1 {{'<deduction guide for HGuideExp><int>' declared here}}
+
+// Implicit deduction guide for device.
+template <typename T>
+struct DGuideImp {                  // expected-note {{candidate template ignored: could not match 'DGuideImp<T>' against 'int'}}
+   __device__ DGuideImp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+                                    // host-note@-1 {{'<deduction guide for DGuideImp><int>' declared here}}
+                                    // host-note@-2 {{'DGuideImp' declared here}}
+};
+
+// Explicit deduction guide for device.
+template <typename T>
+struct DGuideExp {                   // expected-note {{candidate template ignored: could not match 'DGuideExp<T>' against 'int'}}
+   __device__ DGuideExp(T value) {}  // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+                                     // host-note@-1 {{'DGuideExp' declared here}}
+};
+template<typename T>
+__device__ DGuideExp(T) -> DGuideExp<T>; // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+                                         // host-note@-1 {{'<deduction guide for DGuideExp><int>' declared here}}
+
+// Same-signature host and device constructors must not yield ambiguous guides.
+template <typename T>
+struct HDGuide {
+   __device__ HDGuide(T value) {}
+   HDGuide(T value) {}
+};
+
+template<typename T>
+HDGuide(T) -> HDGuide<T>;
+
+template<typename T>
+__device__ HDGuide(T) -> HDGuide<T>;
+
+void hfun() {
+    HGuideImp hgi = 10;
+    HGuideExp hge = 10;
+    DGuideImp dgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideImp'}}
+    DGuideExp dge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideExp'}}
+    HDGuide hdg = 10;
+}
+
+__device__ void dfun() {
+    HGuideImp hgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideImp'}}
+    HGuideExp hge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideExp'}}
+    DGuideImp dgi = 10;
+    DGuideExp dge = 10;
+    HDGuide hdg = 10;
+}
+
+__host__ __device__ void hdfun() {
+    HGuideImp hgi = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideImp><int>' in __host__ __device__ function}}
+                        // dev-error@-1 {{reference to __host__ function 'HGuideImp' in __host__ __device__ function}}
+    HGuideExp hge = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideExp><int>' in __host__ __device__ function}}
+                        // dev-error@-1 {{reference to __host__ function 'HGuideExp' in __host__ __device__ function}}
+    DGuideImp dgi = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideImp><int>' in __host__ __device__ function}}
+                        // host-error@-1 {{reference to __device__ function 'DGuideImp' in __host__ __device__ function}}
+    DGuideExp dge = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideExp><int>' in __host__ __device__ function}}
+                        // host-error@-1 {{reference to __device__ function 'DGuideExp' in __host__ __device__ function}}
+    HDGuide hdg = 10;
+}
+
+HGuideImp hgi = 10;
+HGuideExp hge = 10;
+HDGuide hdg = 10;



More information about the cfe-commits mailing list