[clang] Compiler builtin for deduping a list of types (PR #105817)

Utkarsh Saxena via cfe-commits cfe-commits at lists.llvm.org
Mon Aug 26 07:06:25 PDT 2024


https://github.com/usx95 updated https://github.com/llvm/llvm-project/pull/105817

>From 77003063912f691d246c4f94dd7a952ceace9268 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 23 Aug 2024 11:57:40 +0000
Subject: [PATCH 1/3] [clang] Compiler builtin for deduping a list of types

---
 .../clang/Basic/TransformTypeTraits.def       |  2 ++
 clang/include/clang/Sema/DeclSpec.h           |  2 +-
 clang/include/clang/Sema/Sema.h               |  1 +
 clang/lib/Sema/SemaTemplate.cpp               | 33 +++++++++++++++++++
 clang/lib/Sema/SemaType.cpp                   |  4 +++
 5 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/TransformTypeTraits.def b/clang/include/clang/Basic/TransformTypeTraits.def
index e27a2719a9680f..15313fb2db01eb 100644
--- a/clang/include/clang/Basic/TransformTypeTraits.def
+++ b/clang/include/clang/Basic/TransformTypeTraits.def
@@ -26,4 +26,5 @@ TRANSFORM_TYPE_TRAIT_DEF(RemoveReference, remove_reference_t)
 TRANSFORM_TYPE_TRAIT_DEF(RemoveRestrict, remove_restrict)
 TRANSFORM_TYPE_TRAIT_DEF(RemoveVolatile, remove_volatile)
 TRANSFORM_TYPE_TRAIT_DEF(EnumUnderlyingType, underlying_type)
+TRANSFORM_TYPE_TRAIT_DEF(DedupTemplateArgs, dedup_template_args)
 #undef TRANSFORM_TYPE_TRAIT_DEF
diff --git a/clang/include/clang/Sema/DeclSpec.h b/clang/include/clang/Sema/DeclSpec.h
index 425b6e2a0b30c9..fbfd68054cf002 100644
--- a/clang/include/clang/Sema/DeclSpec.h
+++ b/clang/include/clang/Sema/DeclSpec.h
@@ -469,7 +469,7 @@ class DeclSpec {
             T == TST_class);
   }
   static bool isTransformTypeTrait(TST T) {
-    constexpr std::array<TST, 16> Traits = {
+    constexpr std::array<TST, 17> Traits = {
 #define TRANSFORM_TYPE_TRAIT_DEF(_, Trait) TST_##Trait,
 #include "clang/Basic/TransformTypeTraits.def"
     };
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2ec6367eccea01..d1f6a5c11a84e3 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14870,6 +14870,7 @@ class Sema final : public SemaBase {
   QualType BuildUnaryTransformType(QualType BaseType, UTTKind UKind,
                                    SourceLocation Loc);
   QualType BuiltinEnumUnderlyingType(QualType BaseType, SourceLocation Loc);
+  QualType BuiltinDedupTemplateArgs(QualType BaseType, SourceLocation Loc);
   QualType BuiltinAddPointer(QualType BaseType, SourceLocation Loc);
   QualType BuiltinRemovePointer(QualType BaseType, SourceLocation Loc);
   QualType BuiltinDecay(QualType BaseType, SourceLocation Loc);
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 87b1f98bbe5ac9..28efc766401c7a 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -17,7 +17,9 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
+#include "clang/AST/Type.h"
 #include "clang/AST/TypeVisitor.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/DiagnosticSema.h"
@@ -38,6 +40,7 @@
 #include "clang/Sema/Template.h"
 #include "clang/Sema/TemplateDeduction.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
@@ -8037,6 +8040,36 @@ static bool CheckNonTypeTemplatePartialSpecializationArgs(
   return false;
 }
 
+QualType Sema::BuiltinDedupTemplateArgs(QualType BaseType, SourceLocation Loc) {
+  if (RequireCompleteType(Loc, BaseType,
+                          diag::err_incomplete_type_used_in_type_trait_expr))
+    return QualType();
+  const ElaboratedType *ET = cast<ElaboratedType>(BaseType);
+  auto *TST = ET->getNamedType()->castAs<TemplateSpecializationType>();
+  if (!TST) {
+    Diag(Loc, diag::err_underlying_type_of_incomplete_enum) << BaseType;
+    return QualType();
+  }
+  TemplateArgumentListInfo Args(Loc, Loc);
+  auto AddArg = [&](TemplateArgument T) {
+    Args.addArgument(TemplateArgumentLoc(
+        T, Context.getTrivialTypeSourceInfo(T.getAsType(), Loc)));
+  };
+  llvm::DenseSet<QualType> SeenArgTypes;
+  for (const auto &T : TST->template_arguments()) {
+    if (SeenArgTypes.contains(T.getAsType()))
+      continue;
+    AddArg(T);
+    SeenArgTypes.insert(T.getAsType());
+  }
+  QualType DedupType = CheckTemplateIdType(TST->getTemplateName(), Loc, Args);
+
+  if (RequireCompleteType(Loc, DedupType,
+                          diag::err_coroutine_type_missing_specialization))
+    return QualType();
+  return DedupType;
+}
+
 bool Sema::CheckTemplatePartialSpecializationArgs(
     SourceLocation TemplateNameLoc, TemplateDecl *PrimaryTemplate,
     unsigned NumExplicit, ArrayRef<TemplateArgument> TemplateArgs) {
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 6fa39cdccef2b9..a9a05c3f15e1fa 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9736,6 +9736,10 @@ QualType Sema::BuildUnaryTransformType(QualType BaseType, UTTKind UKind,
     Result = BuiltinEnumUnderlyingType(BaseType, Loc);
     break;
   }
+  case UnaryTransformType::DedupTemplateArgs: {
+    Result = BuiltinDedupTemplateArgs(BaseType, Loc);
+    break;
+  }
   case UnaryTransformType::AddPointer: {
     Result = BuiltinAddPointer(BaseType, Loc);
     break;

>From 67db58b0ee4bb23b6b52b200e4c56c7402f9f527 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 23 Aug 2024 13:38:13 +0000
Subject: [PATCH 2/3] todo: Revert back to DenseSet<QualType>

---
 clang/lib/Sema/SemaTemplate.cpp    | 28 +++++++++++++---------------
 clang/test/SemaCXX/type-traits.cpp |  7 +++++++
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 04f5a62c17ebb1..84f2692b0bdca3 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -44,6 +44,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/TimeProfiler.h"
 
 #include <iterator>
 #include <optional>
@@ -8045,33 +8046,30 @@ static bool CheckNonTypeTemplatePartialSpecializationArgs(
 }
 
 QualType Sema::BuiltinDedupTemplateArgs(QualType BaseType, SourceLocation Loc) {
+  llvm::TimeTraceScope TimeTrace("BuiltinDedupTemplateArgs");
   if (RequireCompleteType(Loc, BaseType,
                           diag::err_incomplete_type_used_in_type_trait_expr))
     return QualType();
   const ElaboratedType *ET = cast<ElaboratedType>(BaseType);
   auto *TST = ET->getNamedType()->castAs<TemplateSpecializationType>();
   if (!TST) {
-    Diag(Loc, diag::err_underlying_type_of_incomplete_enum) << BaseType;
+    Diag(Loc, diag::err_incomplete_type_used_in_type_trait_expr) << BaseType;
     return QualType();
   }
   TemplateArgumentListInfo Args(Loc, Loc);
-  auto AddArg = [&](TemplateArgument T) {
-    Args.addArgument(TemplateArgumentLoc(
-        T, Context.getTrivialTypeSourceInfo(T.getAsType(), Loc)));
-  };
-  llvm::DenseSet<QualType> SeenArgTypes;
-  for (const auto &T : TST->template_arguments()) {
-    if (SeenArgTypes.contains(T.getAsType()))
+  llvm::DenseSet<const Type *> SeenArgTypes;
+  for (const auto &Arg : TST->template_arguments()) {
+    if (!SeenArgTypes.insert(Arg.getAsType().getTypePtr()).second)
       continue;
-    AddArg(T);
-    SeenArgTypes.insert(T.getAsType());
+    Args.addArgument(TemplateArgumentLoc(
+        Arg, Context.getTrivialTypeSourceInfo(Arg.getAsType(), Loc)));
   }
-  QualType DedupType = CheckTemplateIdType(TST->getTemplateName(), Loc, Args);
-
-  if (RequireCompleteType(Loc, DedupType,
-                          diag::err_coroutine_type_missing_specialization))
+  QualType DedupedTypes =
+      CheckTemplateIdType(TST->getTemplateName(), Loc, Args);
+  if (RequireCompleteType(Loc, DedupedTypes,
+                          diag::err_incomplete_type_used_in_type_trait_expr))
     return QualType();
-  return DedupType;
+  return DedupedTypes;
 }
 
 bool Sema::CheckTemplatePartialSpecializationArgs(
diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index bf069d9bc082c3..b312ebc91fe8c9 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -5014,3 +5014,10 @@ void remove_all_extents() {
   using SomeArray = int[1][2];
   static_assert(__is_same(remove_all_extents_t<const SomeArray>, const int));
 }
+
+template <class T> using dedup_template_args_t = __dedup_template_args(T);
+template <typename... T> struct TypeList{};
+void dedup_types() {
+  static_assert(__is_same(dedup_template_args_t<TypeList<int, int, double, int>>, 
+                          TypeList<int,double>));
+}

>From aa71b0fa6dc6b6ab6d47b14b7104d1c6b841650b Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Mon, 26 Aug 2024 14:06:11 +0000
Subject: [PATCH 3/3] Revert to DenseSet<QualType>

---
 clang/lib/Sema/SemaTemplate.cpp    | 28 +---------------------------
 clang/lib/Sema/SemaType.cpp        | 28 ++++++++++++++++++++++++++++
 clang/test/SemaCXX/type-traits.cpp |  2 +-
 3 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 84f2692b0bdca3..4214f7bd46adbf 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -40,6 +40,7 @@
 #include "clang/Sema/Template.h"
 #include "clang/Sema/TemplateDeduction.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallString.h"
@@ -8045,33 +8046,6 @@ static bool CheckNonTypeTemplatePartialSpecializationArgs(
   return false;
 }
 
-QualType Sema::BuiltinDedupTemplateArgs(QualType BaseType, SourceLocation Loc) {
-  llvm::TimeTraceScope TimeTrace("BuiltinDedupTemplateArgs");
-  if (RequireCompleteType(Loc, BaseType,
-                          diag::err_incomplete_type_used_in_type_trait_expr))
-    return QualType();
-  const ElaboratedType *ET = cast<ElaboratedType>(BaseType);
-  auto *TST = ET->getNamedType()->castAs<TemplateSpecializationType>();
-  if (!TST) {
-    Diag(Loc, diag::err_incomplete_type_used_in_type_trait_expr) << BaseType;
-    return QualType();
-  }
-  TemplateArgumentListInfo Args(Loc, Loc);
-  llvm::DenseSet<const Type *> SeenArgTypes;
-  for (const auto &Arg : TST->template_arguments()) {
-    if (!SeenArgTypes.insert(Arg.getAsType().getTypePtr()).second)
-      continue;
-    Args.addArgument(TemplateArgumentLoc(
-        Arg, Context.getTrivialTypeSourceInfo(Arg.getAsType(), Loc)));
-  }
-  QualType DedupedTypes =
-      CheckTemplateIdType(TST->getTemplateName(), Loc, Args);
-  if (RequireCompleteType(Loc, DedupedTypes,
-                          diag::err_incomplete_type_used_in_type_trait_expr))
-    return QualType();
-  return DedupedTypes;
-}
-
 bool Sema::CheckTemplatePartialSpecializationArgs(
     SourceLocation TemplateNameLoc, TemplateDecl *PrimaryTemplate,
     unsigned NumExplicit, ArrayRef<TemplateArgument> TemplateArgs) {
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 3bcaad4c3af2d4..a7590dc27561c8 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -50,6 +50,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/TimeProfiler.h"
 #include <bitset>
 #include <optional>
 
@@ -9584,6 +9585,61 @@ QualType Sema::BuiltinEnumUnderlyingType(QualType BaseType,
   return GetEnumUnderlyingType(*this, BaseType, Loc);
 }
 
+/// Implements `__dedup_template_args(T)`.
+///
+/// \p BaseType must name a specialization of a template whose arguments are
+/// all types, e.g. `TypeList<int, int, double>`. The result is the
+/// specialization of the same template with every argument removed whose
+/// canonical type matches an earlier argument; the first spelling and the
+/// original order are kept (`TypeList<int, double>`).
+///
+/// Emits a diagnostic and returns a null QualType on failure.
+QualType Sema::BuiltinDedupTemplateArgs(QualType BaseType, SourceLocation Loc) {
+  llvm::TimeTraceScope TimeTrace("BuiltinDedupTemplateArgs");
+  if (RequireCompleteType(Loc, BaseType,
+                          diag::err_incomplete_type_used_in_type_trait_expr))
+    return QualType();
+
+  // getAs<> looks through type sugar (e.g. ElaboratedType) and returns null
+  // for anything that is not a template specialization, such as `int`.
+  // FIXME: Use dedicated diagnostics for "not a template specialization" and
+  // "non-type template argument" instead of the incomplete-type one.
+  const auto *TST = BaseType->getAs<TemplateSpecializationType>();
+  if (!TST) {
+    Diag(Loc, diag::err_incomplete_type_used_in_type_trait_expr) << BaseType;
+    return QualType();
+  }
+
+  TemplateArgumentListInfo Args(Loc, Loc);
+  llvm::DenseSet<QualType> SeenArgTypes;
+  for (const TemplateArgument &Arg : TST->template_arguments()) {
+    // Only type arguments can be deduplicated here; getAsType() asserts on
+    // non-type, template and pack arguments.
+    if (Arg.getKind() != TemplateArgument::Type) {
+      Diag(Loc, diag::err_incomplete_type_used_in_type_trait_expr) << BaseType;
+      return QualType();
+    }
+    QualType ArgType = Arg.getAsType();
+    // Key on the canonical type so that different spellings of the same type
+    // (typedefs, aliases) count as duplicates; the first spelling is kept.
+    if (!SeenArgTypes.insert(ArgType.getCanonicalType()).second)
+      continue;
+    Args.addArgument(TemplateArgumentLoc(
+        Arg, Context.getTrivialTypeSourceInfo(ArgType, Loc)));
+  }
+
+  QualType DedupedTypes =
+      CheckTemplateIdType(TST->getTemplateName(), Loc, Args);
+  // CheckTemplateIdType diagnoses and returns a null type when the shortened
+  // argument list is invalid for the template (e.g. too few arguments for a
+  // non-variadic template); RequireCompleteType must not see a null type.
+  if (DedupedTypes.isNull() ||
+      RequireCompleteType(Loc, DedupedTypes,
+                          diag::err_incomplete_type_used_in_type_trait_expr))
+    return QualType();
+  return DedupedTypes;
+}
+
 QualType Sema::BuiltinAddPointer(QualType BaseType, SourceLocation Loc) {
   QualType Pointer = BaseType.isReferenceable() || BaseType->isVoidType()
                          ? BuildPointerType(BaseType.getNonReferenceType(), Loc,
diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index b312ebc91fe8c9..7dbf7a1c4edb62 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -5018,6 +5018,15 @@ void remove_all_extents() {
 template <class T> using dedup_template_args_t = __dedup_template_args(T);
 template <typename... T> struct TypeList{};
 void dedup_types() {
-  static_assert(__is_same(dedup_template_args_t<TypeList<int, int, double, int>>, 
+  static_assert(__is_same(dedup_template_args_t<TypeList<int, int, double, int>>,
                           TypeList<int,double>));
+  // The first occurrence of each type is kept, in its original position.
+  static_assert(__is_same(dedup_template_args_t<TypeList<double, int, double>>,
+                          TypeList<double, int>));
+  // Duplicates are found by canonical type; cv-qualifiers keep types distinct.
+  using Int = int;
+  static_assert(__is_same(dedup_template_args_t<TypeList<int, Int, const int>>,
+                          TypeList<int, const int>));
+  // An empty list is left unchanged.
+  static_assert(__is_same(dedup_template_args_t<TypeList<>>, TypeList<>));
 }



More information about the cfe-commits mailing list