[clang] [CUDA][HIP] Fix deduction guide (PR #69366)
Yaxun Liu via cfe-commits
cfe-commits at lists.llvm.org
Tue Oct 17 17:51:27 PDT 2023
https://github.com/yxsamliu updated https://github.com/llvm/llvm-project/pull/69366
>From b28384d33f858a6d4139da931b436cbf1a0a426a Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu" <yaxun.liu at amd.com>
Date: Sat, 14 Oct 2023 17:28:13 -0400
Subject: [PATCH] [CUDA][HIP] Fix deduction guide
Currently clang assumes implicit deduction guides to be host
device. This generates two identical implicit deduction
guides when a class has a device and a host constructor
which have the same input parameter, which causes ambiguity.
Since an implicit deduction guide is derived from a constructor,
it should take the same host/device attribute as the originating
constructor. This matches nvcc behavior as seen in
https://godbolt.org/z/sY1vdYWKe and https://godbolt.org/z/vTer7xa3j
---
clang/docs/HIPSupport.rst | 55 +++++++++++++++++
clang/include/clang/AST/DeclCXX.h | 9 +--
clang/lib/AST/DeclCXX.cpp | 21 +++++++
clang/lib/Sema/SemaCUDA.cpp | 5 +-
clang/lib/Sema/SemaTemplate.cpp | 18 ++++--
clang/test/SemaCUDA/deduction-guide.cu | 85 ++++++++++++++++++++++++++
6 files changed, 179 insertions(+), 14 deletions(-)
create mode 100644 clang/test/SemaCUDA/deduction-guide.cu
diff --git a/clang/docs/HIPSupport.rst b/clang/docs/HIPSupport.rst
index 8b4649733a9c777..8a9802e19e6367f 100644
--- a/clang/docs/HIPSupport.rst
+++ b/clang/docs/HIPSupport.rst
@@ -176,3 +176,58 @@ Predefined Macros
* - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
- Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
+Support for Deduction Guides
+============================
+
+Explicit Deduction Guides
+-------------------------
+
+Explicit deduction guides in HIP can be annotated with either the
+``__host__`` or the ``__device__`` attribute. If no attribute is provided,
+the guide defaults to ``__host__``.
+
+.. code-block:: cpp
+
+ template <typename T>
+ class MyArray {
+ //...
+ };
+
+ template <typename T>
+ MyArray(T)->MyArray<T>;
+
+ __device__ MyArray(float)->MyArray<int>;
+
+ // Uses of the deduction guides
+ MyArray arr1 = 10; // Uses the default host guide
+ __device__ void foo() {
+ MyArray arr2 = 3.14f; // Uses the device guide
+ }
+
+Implicit Deduction Guides
+-------------------------
+Implicit deduction guides derived from constructors inherit the same host or
+device attributes as the originating constructor.
+
+.. code-block:: cpp
+
+ template <typename T>
+ class MyVector {
+ public:
+ __device__ MyVector(T) { /* ... */ }
+ //...
+ };
+
+ // The implicit deduction guide for MyVector will be `__device__` due to the device constructor
+
+ __device__ void foo() {
+ MyVector vec(42); // Uses the implicit device guide derived from the constructor
+ }
+
+Availability Checks
+--------------------
+When a deduction guide (either explicit or implicit) is used, HIP checks its
+availability based on its host/device attributes and the calling context, in
+the same way it checks a function call. Using a deduction guide in an
+incompatible context results in a compile-time error.
+
diff --git a/clang/include/clang/AST/DeclCXX.h b/clang/include/clang/AST/DeclCXX.h
index 5eaae6bdd2bc63e..863ced731d42b2f 100644
--- a/clang/include/clang/AST/DeclCXX.h
+++ b/clang/include/clang/AST/DeclCXX.h
@@ -1948,14 +1948,7 @@ class CXXDeductionGuideDecl : public FunctionDecl {
ExplicitSpecifier ES,
const DeclarationNameInfo &NameInfo, QualType T,
TypeSourceInfo *TInfo, SourceLocation EndLocation,
- CXXConstructorDecl *Ctor, DeductionCandidate Kind)
- : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
- SC_None, false, false, ConstexprSpecKind::Unspecified),
- Ctor(Ctor), ExplicitSpec(ES) {
- if (EndLocation.isValid())
- setRangeEnd(EndLocation);
- setDeductionCandidateKind(Kind);
- }
+ CXXConstructorDecl *Ctor, DeductionCandidate Kind);
CXXConstructorDecl *Ctor;
ExplicitSpecifier ExplicitSpec;
diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp
index 9107525a44f22c2..e0683173e24f440 100644
--- a/clang/lib/AST/DeclCXX.cpp
+++ b/clang/lib/AST/DeclCXX.cpp
@@ -2113,6 +2113,27 @@ ExplicitSpecifier ExplicitSpecifier::getFromDecl(FunctionDecl *Function) {
}
}
+CXXDeductionGuideDecl::CXXDeductionGuideDecl(
+ ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
+ ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
+ TypeSourceInfo *TInfo, SourceLocation EndLocation, CXXConstructorDecl *Ctor,
+ DeductionCandidate Kind)
+ : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
+ SC_None, false, false, ConstexprSpecKind::Unspecified),
+ Ctor(Ctor), ExplicitSpec(ES) {
+ if (EndLocation.isValid())
+ setRangeEnd(EndLocation);
+ setDeductionCandidateKind(Kind);
+ // A non-null Ctor means this is an implicit guide derived from it; under CUDA/HIP
+ // copy its __host__/__device__ attrs as implicit attrs (neither on Ctor => none here).
+ if (Ctor && C.getLangOpts().CUDA) {
+ if (Ctor->hasAttr<CUDAHostAttr>())
+ this->addAttr(CUDAHostAttr::CreateImplicit(C));
+ if (Ctor->hasAttr<CUDADeviceAttr>())
+ this->addAttr(CUDADeviceAttr::CreateImplicit(C));
+ }
+}
+
CXXDeductionGuideDecl *CXXDeductionGuideDecl::Create(
ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index d993499cf4a6e6e..d1d59ad1b9fc4b1 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -149,10 +149,13 @@ Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
return CFT_Device;
} else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
return CFT_Host;
- } else if ((D->isImplicit() || !D->isUserProvided()) &&
+ } else if (!isa<CXXDeductionGuideDecl>(D) &&
+ (D->isImplicit() || !D->isUserProvided()) &&
!IgnoreImplicitHDAttr) {
// Some implicit declarations (like intrinsic functions) are not marked.
// Set the most lenient target on them for maximal flexibility.
+ // Implicit deduction guides are derived from constructors and their
+ // host/device attributes are determined by their originating constructors.
return CFT_HostDevice;
}
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 6389ec708bf34ae..0b854f06a95743b 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2685,19 +2685,27 @@ void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
AddedAny = true;
}
+ // Build simple deduction guide and set CUDA host/device attributes.
+ auto BuildSimpleDeductionGuide = [&](auto T) {
+ auto *DG = cast<CXXDeductionGuideDecl>(
+ cast<FunctionTemplateDecl>(Transform.buildSimpleDeductionGuide(T))
+ ->getTemplatedDecl());
+ if (LangOpts.CUDA) {
+ DG->addAttr(CUDAHostAttr::CreateImplicit(getASTContext()));
+ DG->addAttr(CUDADeviceAttr::CreateImplicit(getASTContext()));
+ }
+ return DG;
+ };
// C++17 [over.match.class.deduct]
// -- If C is not defined or does not declare any constructors, an
// additional function template derived as above from a hypothetical
// constructor C().
if (!AddedAny)
- Transform.buildSimpleDeductionGuide(std::nullopt);
+ BuildSimpleDeductionGuide(std::nullopt);
// -- An additional function template derived as above from a hypothetical
// constructor C(C), called the copy deduction candidate.
- cast<CXXDeductionGuideDecl>(
- cast<FunctionTemplateDecl>(
- Transform.buildSimpleDeductionGuide(Transform.DeducedType))
- ->getTemplatedDecl())
+ BuildSimpleDeductionGuide(Transform.DeducedType)
->setDeductionCandidateKind(DeductionCandidate::Copy);
}
diff --git a/clang/test/SemaCUDA/deduction-guide.cu b/clang/test/SemaCUDA/deduction-guide.cu
new file mode 100644
index 000000000000000..505700c4b58daed
--- /dev/null
+++ b/clang/test/SemaCUDA/deduction-guide.cu
@@ -0,0 +1,85 @@
+// RUN: %clang_cc1 -fsyntax-only -verify=expected,host %s
+// RUN: %clang_cc1 -fcuda-is-device -fsyntax-only -verify=expected,dev %s
+
+#include "Inputs/cuda.h"
+
+// Unannotated (host) constructor: its implicit deduction guide is host as well.
+template <typename T>
+struct HGuideImp { // expected-note {{candidate template ignored: could not match 'HGuideImp<T>' against 'int'}}
+ HGuideImp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+ // dev-note at -1 {{'<deduction guide for HGuideImp><int>' declared here}}
+ // dev-note at -2 {{'HGuideImp' declared here}}
+};
+
+// Unannotated explicit deduction guide, which is treated as __host__.
+template <typename T>
+struct HGuideExp { // expected-note {{candidate template ignored: could not match 'HGuideExp<T>' against 'int'}}
+ HGuideExp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+ // dev-note at -1 {{'HGuideExp' declared here}}
+};
+template<typename T>
+HGuideExp(T) -> HGuideExp<T>; // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+ // dev-note at -1 {{'<deduction guide for HGuideExp><int>' declared here}}
+
+// __device__ constructor: its implicit deduction guide is __device__ as well.
+template <typename T>
+struct DGuideImp { // expected-note {{candidate template ignored: could not match 'DGuideImp<T>' against 'int'}}
+ __device__ DGuideImp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+ // host-note at -1 {{'<deduction guide for DGuideImp><int>' declared here}}
+ // host-note at -2 {{'DGuideImp' declared here}}
+};
+
+// Explicit __device__ deduction guide (the constructor is __device__ too).
+template <typename T>
+struct DGuideExp { // expected-note {{candidate template ignored: could not match 'DGuideExp<T>' against 'int'}}
+ __device__ DGuideExp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+ // host-note at -1 {{'DGuideExp' declared here}}
+};
+
+template<typename T>
+__device__ DGuideExp(T) -> DGuideExp<T>; // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+ // host-note at -1 {{'<deduction guide for DGuideExp><int>' declared here}}
+// Same-signature host and device ctors and guides: no ambiguity on either side.
+template <typename T>
+struct HDGuide {
+ __device__ HDGuide(T value) {}
+ HDGuide(T value) {}
+};
+
+template<typename T>
+HDGuide(T) -> HDGuide<T>;
+
+template<typename T>
+__device__ HDGuide(T) -> HDGuide<T>;
+
+void hfun() {
+ HGuideImp hgi = 10;
+ HGuideExp hge = 10;
+ DGuideImp dgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideImp'}}
+ DGuideExp dge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideExp'}}
+ HDGuide hdg = 10;
+}
+
+__device__ void dfun() {
+ HGuideImp hgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideImp'}}
+ HGuideExp hge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideExp'}}
+ DGuideImp dgi = 10;
+ DGuideExp dge = 10;
+ HDGuide hdg = 10;
+}
+// In __host__ __device__ code, a guide unusable on one side errors only when compiling for that side.
+__host__ __device__ void hdfun() {
+ HGuideImp hgi = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideImp><int>' in __host__ __device__ function}}
+ // dev-error at -1 {{reference to __host__ function 'HGuideImp' in __host__ __device__ function}}
+ HGuideExp hge = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideExp><int>' in __host__ __device__ function}}
+ // dev-error at -1 {{reference to __host__ function 'HGuideExp' in __host__ __device__ function}}
+ DGuideImp dgi = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideImp><int>' in __host__ __device__ function}}
+ // host-error at -1 {{reference to __device__ function 'DGuideImp' in __host__ __device__ function}}
+ DGuideExp dge = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideExp><int>' in __host__ __device__ function}}
+ // host-error at -1 {{reference to __device__ function 'DGuideExp' in __host__ __device__ function}}
+ HDGuide hdg = 10;
+}
+
+HGuideImp hgi = 10;
+HGuideExp hge = 10;
+HDGuide hdg = 10;
More information about the cfe-commits
mailing list