[clang] [HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (PR #93847)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Fri Jun 7 15:29:05 PDT 2024


https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/93847

>From dd175a247480396b9d35cb995333fcd14152e347 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 29 May 2024 18:38:45 -0700
Subject: [PATCH 1/4] [HLSL] Use llvm::Triple::EnvironmentType instead of
 ShaderType

The HLSLShaderAttr::ShaderType enum is a subset of llvm::Triple::EnvironmentType, so it is redundant and can be removed.
---
 clang/include/clang/Basic/Attr.td   |  29 +++-----
 clang/include/clang/Sema/SemaHLSL.h |   6 +-
 clang/lib/CodeGen/CGHLSLRuntime.cpp |   2 +-
 clang/lib/Sema/SemaDeclAttr.cpp     |   4 +-
 clang/lib/Sema/SemaHLSL.cpp         | 105 +++++++++++++++-------------
 5 files changed, 72 insertions(+), 74 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 2665b7353ca4a..e373c073ec906 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4469,36 +4469,23 @@ def HLSLShader : InheritableAttr {
   let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
   let Args = [
-    EnumArgument<"Type", "ShaderType", /*is_string=*/true,
+    EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
                  ["pixel", "vertex", "geometry", "hull", "domain", "compute",
                   "raygeneration", "intersection", "anyhit", "closesthit",
                   "miss", "callable", "mesh", "amplification"],
                  ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
                   "RayGeneration", "Intersection", "AnyHit", "ClosestHit",
-                  "Miss", "Callable", "Mesh", "Amplification"]>
+                  "Miss", "Callable", "Mesh", "Amplification"],
+                  /*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
   ];
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
   let AdditionalMembers =
 [{
-  static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
-
-  static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
-    switch (ShaderType) {
-      case HLSLShaderAttr::Pixel:         return llvm::Triple::Pixel;
-      case HLSLShaderAttr::Vertex:        return llvm::Triple::Vertex;
-      case HLSLShaderAttr::Geometry:      return llvm::Triple::Geometry;
-      case HLSLShaderAttr::Hull:          return llvm::Triple::Hull;
-      case HLSLShaderAttr::Domain:        return llvm::Triple::Domain;
-      case HLSLShaderAttr::Compute:       return llvm::Triple::Compute;
-      case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
-      case HLSLShaderAttr::Intersection:  return llvm::Triple::Intersection;
-      case HLSLShaderAttr::AnyHit:        return llvm::Triple::AnyHit;
-      case HLSLShaderAttr::ClosestHit:    return llvm::Triple::ClosestHit;
-      case HLSLShaderAttr::Miss:          return llvm::Triple::Miss;
-      case HLSLShaderAttr::Callable:      return llvm::Triple::Callable;
-      case HLSLShaderAttr::Mesh:          return llvm::Triple::Mesh;
-      case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
-    }
+  static const llvm::Triple::EnvironmentType MinShaderTypeValue = llvm::Triple::Pixel;
+  static const llvm::Triple::EnvironmentType MaxShaderTypeValue = llvm::Triple::Amplification;
+
+  static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
+    return ShaderType >= MinShaderTypeValue && ShaderType <= MaxShaderTypeValue;
   }
 }];
 }
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index eac1f7c07c85d..00df6c2bd15e4 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,7 +38,7 @@ class SemaHLSL : public SemaBase {
                                           const AttributeCommonInfo &AL, int X,
                                           int Y, int Z);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                                  HLSLShaderAttr::ShaderType ShaderType);
+                                  llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
   mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                          HLSLParamModifierAttr::Spelling Spelling);
@@ -47,8 +47,8 @@ class SemaHLSL : public SemaBase {
   void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
                                const HLSLAnnotationAttr *AnnotationAttr);
   void DiagnoseAttrStageMismatch(
-      const Attr *A, HLSLShaderAttr::ShaderType Stage,
-      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
+      const Attr *A, llvm::Triple::EnvironmentType Stage,
+      std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
 };
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 5e6a3dd4878f4..55ba21ae2ba69 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
   assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
   const StringRef ShaderAttrKindStr = "hlsl.shader";
   Fn->addFnAttr(ShaderAttrKindStr,
-                ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
+                llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
   if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
     const StringRef NumThreadsKindStr = "hlsl.numthreads";
     std::string NumThreadsStr =
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 7c1fb23b90728..49c9de73aafb5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7341,8 +7341,8 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
     return;
 
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+  llvm::Triple::EnvironmentType ShaderType;
+  if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
     S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9e614ae99f37d..2795b0af6f1c9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -145,7 +145,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
 
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                          HLSLShaderAttr::ShaderType ShaderType) {
+                          llvm::Triple::EnvironmentType ShaderType) {
   if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
     if (NT->getType() != ShaderType) {
       Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -183,13 +183,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
   if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
     return;
 
-  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
+  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
     if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
       // The entry point is already annotated - check that it matches the
       // triple.
-      if (Shader->getType() != ShaderType) {
+      if (Shader->getType() != Env) {
         Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
             << Shader;
         FD->setInvalidDecl();
@@ -197,11 +196,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
     } else {
       // Implicitly add the shader attribute if the entry function isn't
       // explicitly annotated.
-      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                  FD->getBeginLoc()));
     }
   } else {
-    switch (TargetInfo.getTriple().getEnvironment()) {
+    switch (Env) {
     case llvm::Triple::UnknownEnvironment:
     case llvm::Triple::Library:
       break;
@@ -214,38 +213,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
 void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
   assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
 
   switch (ST) {
-  case HLSLShaderAttr::Pixel:
-  case HLSLShaderAttr::Vertex:
-  case HLSLShaderAttr::Geometry:
-  case HLSLShaderAttr::Hull:
-  case HLSLShaderAttr::Domain:
-  case HLSLShaderAttr::RayGeneration:
-  case HLSLShaderAttr::Intersection:
-  case HLSLShaderAttr::AnyHit:
-  case HLSLShaderAttr::ClosestHit:
-  case HLSLShaderAttr::Miss:
-  case HLSLShaderAttr::Callable:
+  case llvm::Triple::Pixel:
+  case llvm::Triple::Vertex:
+  case llvm::Triple::Geometry:
+  case llvm::Triple::Hull:
+  case llvm::Triple::Domain:
+  case llvm::Triple::RayGeneration:
+  case llvm::Triple::Intersection:
+  case llvm::Triple::AnyHit:
+  case llvm::Triple::ClosestHit:
+  case llvm::Triple::Miss:
+  case llvm::Triple::Callable:
     if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
       DiagnoseAttrStageMismatch(NT, ST,
-                                {HLSLShaderAttr::Compute,
-                                 HLSLShaderAttr::Amplification,
-                                 HLSLShaderAttr::Mesh});
+                                {llvm::Triple::Compute,
+                                 llvm::Triple::Amplification,
+                                 llvm::Triple::Mesh});
       FD->setInvalidDecl();
     }
     break;
 
-  case HLSLShaderAttr::Compute:
-  case HLSLShaderAttr::Amplification:
-  case HLSLShaderAttr::Mesh:
+  case llvm::Triple::Compute:
+  case llvm::Triple::Amplification:
+  case llvm::Triple::Mesh:
     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+          << llvm::Triple::getEnvironmentTypeName(ST);
       FD->setInvalidDecl();
     }
     break;
+  default:
+    llvm_unreachable("Unhandled environment in triple");
   }
 
   for (ParmVarDecl *Param : FD->parameters()) {
@@ -267,14 +268,14 @@ void SemaHLSL::CheckSemanticAnnotation(
     const HLSLAnnotationAttr *AnnotationAttr) {
   auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
   assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
 
   switch (AnnotationAttr->getKind()) {
   case attr::HLSLSV_DispatchThreadID:
   case attr::HLSLSV_GroupIndex:
-    if (ST == HLSLShaderAttr::Compute)
+    if (ST == llvm::Triple::Compute)
       return;
-    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
     break;
   default:
     llvm_unreachable("Unknown HLSLAnnotationAttr");
@@ -282,16 +283,16 @@ void SemaHLSL::CheckSemanticAnnotation(
 }
 
 void SemaHLSL::DiagnoseAttrStageMismatch(
-    const Attr *A, HLSLShaderAttr::ShaderType Stage,
-    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+    const Attr *A, llvm::Triple::EnvironmentType Stage,
+    std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
   SmallVector<StringRef, 8> StageStrings;
   llvm::transform(AllowedStages, std::back_inserter(StageStrings),
-                  [](HLSLShaderAttr::ShaderType ST) {
+                  [](llvm::Triple::EnvironmentType ST) {
                     return StringRef(
-                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+                        HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
                   });
   Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+      << A << llvm::Triple::getEnvironmentTypeName(Stage)
       << (AllowedStages.size() != 1) << join(StageStrings, ", ");
 }
 
@@ -321,16 +322,22 @@ class DiagnoseHLSLAvailability
   //
   // Maps FunctionDecl to an unsigned number that represents the set of shader
   // environments the function has been scanned for.
-  // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
-  // defined without any assigned values, it is guaranteed to be numbered
-  // sequentially from 0 up and we can use it to 'index' individual bits
-  // in the set.
+  // The llvm::Triple::EnvironmentType enum values for shader stages are
+  // guaranteed to be numbered from llvm::Triple::Pixel to
+  // llvm::Triple::Amplification (verified by static_asserts in Triple.cpp),
+  // so we can use them to index individual bits in the set, as long as we
+  // shift the values to start with 0 by subtracting llvm::Triple::Pixel first.
+  //
   // The N'th bit in the set will be set if the function has been scanned
-  // in shader environment whose ShaderType integer value equals N.
+  // in a shader environment whose llvm::Triple::EnvironmentType integer value
+  // equals (llvm::Triple::Pixel + N).
+  //
   // For example, if a function has been scanned in compute and pixel stage
-  // environment, the value will be 0x21 (100001 binary) because
-  // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
-  // (int)HLSLShaderAttr::ShaderType::Compute == 5.
+  // environment, the value will be 0x21 (100001 binary) because:
+  //
+  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
+  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
+  //
   // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
   // been scanned in any environment.
   llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -346,12 +353,16 @@ class DiagnoseHLSLAvailability
   bool ReportOnlyShaderStageIssues;
 
   // Helper methods for dealing with current stage context / environment
-  void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
+  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
     static_assert(sizeof(unsigned) >= 4);
-    assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
-
-    CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
-    CurrentShaderStageBit = (1 << ShaderType);
+    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
+    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && "ShaderType is too big for this bitmap"); // 31 is reserved for "unknown"
+
+    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
+    assert(((unsigned)1) << bitmapIndex != 0 && bitmapIndex != 31 &&
+           "ShaderType is too big for this bitmap");
+    CurrentShaderEnvironment = ShaderType;
+    CurrentShaderStageBit = (1 << bitmapIndex);
   }
 
   void SetUnknownShaderStageContext() {

>From 4df63e630e49a25d753c3636c1ca86d421421dcf Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Thu, 30 May 2024 10:18:55 -0700
Subject: [PATCH 2/4] Do not define min and max shader type values

---
 clang/include/clang/Basic/Attr.td | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index e373c073ec906..a337509d3e2b5 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4481,11 +4481,8 @@ def HLSLShader : InheritableAttr {
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
   let AdditionalMembers =
 [{
-  static const llvm::Triple::EnvironmentType MinShaderTypeValue = llvm::Triple::Pixel;
-  static const llvm::Triple::EnvironmentType MaxShaderTypeValue = llvm::Triple::Amplification;
-
   static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
-    return ShaderType >= MinShaderTypeValue && ShaderType <= MaxShaderTypeValue;
+    return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
   }
 }];
 }

>From a9cefc2dfb9f7073439d77bc1ad59eefaeeff8c9 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Thu, 30 May 2024 10:24:19 -0700
Subject: [PATCH 3/4] Cleanup

---
 clang/lib/Sema/SemaHLSL.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 2795b0af6f1c9..da9bda3eaf3d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -356,11 +356,11 @@ class DiagnoseHLSLAvailability
   void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
     static_assert(sizeof(unsigned) >= 4);
     assert(HLSLShaderAttr::isValidShaderType(ShaderType));
-    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && "ShaderType is too big for this bitmap"); // 31 is reserved for "unknown"
+    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
+           "ShaderType is too big for this bitmap"); // 31 is reserved for
+                                                     // "unknown"
 
     unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
-    assert(((unsigned)1) << bitmapIndex != 0 && bitmapIndex != 31 &&
-           "ShaderType is too big for this bitmap");
     CurrentShaderEnvironment = ShaderType;
     CurrentShaderStageBit = (1 << bitmapIndex);
   }

>From e66ee67da934025ea1b4e93c24fbf11360c7f8cf Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 7 Jun 2024 15:28:35 -0700
Subject: [PATCH 4/4] Fix merge issue

---
 clang/lib/Sema/SemaHLSL.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a084cb5c46968..144cdcc0d98ef 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -431,8 +431,8 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
   if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
     return;
 
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
+  llvm::Triple::EnvironmentType ShaderType;
+  if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
     Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;



More information about the cfe-commits mailing list