[clang] ef2b170 - [Sema][HLSL] Consolidate handling of HLSL attributes

Justin Bogner via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 29 08:59:59 PDT 2023


Author: Justin Bogner
Date: 2023-08-29T08:55:38-07:00
New Revision: ef2b1700f4648816e6a6ce27cfee1c501421ee50

URL: https://github.com/llvm/llvm-project/commit/ef2b1700f4648816e6a6ce27cfee1c501421ee50
DIFF: https://github.com/llvm/llvm-project/commit/ef2b1700f4648816e6a6ce27cfee1c501421ee50.diff

LOG: [Sema][HLSL] Consolidate handling of HLSL attributes

This moves the Sema checking of the entry-point-sensitive HLSL
attributes into one place. The change ended up being fairly large for a
few reasons:

- I had to move the call to CheckHLSLEntryPoint later in
  ActOnFunctionDeclarator so that it runs after redeclarations have been
  merged and therefore has access to all of the attributes.

- We need to transfer the target's shader stage (taken from the triple)
  onto the specified entry point, as an implicit shader attribute, before
  doing the checking.

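- I removed "library" from the HLSLShader attribute value enum and now
  convert from the triple by going through the environment name string
  instead. The previous approach, which cast between the attribute enum
  and llvm::Triple::EnvironmentType by a fixed offset, was confusing and
  brittle.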

Differential Revision: https://reviews.llvm.org/D158803

Added: 
    clang/test/SemaHLSL/Semantics/groupindex.hlsl
    clang/test/SemaHLSL/entry_shader_redecl.hlsl

Modified: 
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/include/clang/Sema/Sema.h
    clang/lib/Sema/SemaDecl.cpp
    clang/lib/Sema/SemaDeclAttr.cpp
    clang/test/CodeGenHLSL/GlobalDestructors.hlsl
    clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
    clang/test/SemaHLSL/entry.hlsl
    clang/test/SemaHLSL/num_threads.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f7d19a3df25176..2623d1741b9d3b 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4133,24 +4133,14 @@ def HLSLShader : InheritableAttr {
   let Spellings = [Microsoft<"shader">];
   let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
-  // NOTE:
-  // order for the enum should match order in llvm::Triple::EnvironmentType.
-  // ShaderType will be converted to llvm::Triple::EnvironmentType like
-  //   (llvm::Triple::EnvironmentType)((uint32_t)ShaderType +
-  //      (uint32_t)llvm::Triple::EnvironmentType::Pixel).
-  // This will avoid update code for convert when new shader type is added.
   let Args = [
     EnumArgument<"Type", "ShaderType",
-                 [
-                   "pixel", "vertex", "geometry", "hull", "domain", "compute",
-                   "library", "raygeneration", "intersection", "anyhit",
-                   "closesthit", "miss", "callable", "mesh", "amplification"
-                 ],
-                 [
-                   "Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
-                   "Library", "RayGeneration", "Intersection", "AnyHit",
-                   "ClosestHit", "Miss", "Callable", "Mesh", "Amplification"
-                 ]>
+                 ["pixel", "vertex", "geometry", "hull", "domain", "compute",
+                  "raygeneration", "intersection", "anyhit", "closesthit",
+                  "miss", "callable", "mesh", "amplification"],
+                 ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
+                  "RayGeneration", "Intersection", "AnyHit", "ClosestHit",
+                  "Miss", "Callable", "Mesh", "Amplification"]>
   ];
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
 }

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 7f0cfb29cddeac..a065015cfe02eb 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11869,7 +11869,7 @@ def err_std_source_location_impl_malformed : Error<
   "'std::source_location::__impl' must be standard-layout and have only two 'const char *' fields '_M_file_name' and '_M_function_name', and two integral fields '_M_line' and '_M_column'">;
 
 // HLSL Diagnostics
-def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">;
+def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in '%1' shaders, requires %select{|one of the following: }2%3">;
 def err_hlsl_attr_invalid_type : Error<
    "attribute %0 only applies to a field or parameter of type '%1'">;
 def err_hlsl_attr_invalid_ast_node : Error<

diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index ce6731f99d4cbf..28e085ccebbb2f 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3013,7 +3013,13 @@ class Sema final {
                                       QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
+  void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
   void CheckHLSLEntryPoint(FunctionDecl *FD);
+  void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
+                                   const HLSLAnnotationAttr *AnnotationAttr);
+  void DiagnoseHLSLAttrStageMismatch(
+      const Attr *A, HLSLShaderAttr::ShaderType Stage,
+      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);

diff  --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 18355b484975af..2ee216fffaa275 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -10338,33 +10338,6 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
     }
   }
 
-  if (getLangOpts().HLSL) {
-    auto &TargetInfo = getASTContext().getTargetInfo();
-    // Skip operator overload which not identifier.
-    // Also make sure NewFD is in translation-unit scope.
-    if (!NewFD->isInvalidDecl() && Name.isIdentifier() &&
-        NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry &&
-        S->getDepth() == 0) {
-      CheckHLSLEntryPoint(NewFD);
-      if (!NewFD->isInvalidDecl()) {
-        auto Env = TargetInfo.getTriple().getEnvironment();
-        HLSLShaderAttr::ShaderType ShaderType =
-            static_cast<HLSLShaderAttr::ShaderType>(
-                hlsl::getStageFromEnvironment(Env));
-        // To share code with HLSLShaderAttr, add HLSLShaderAttr to entry
-        // function.
-        if (HLSLShaderAttr *NT = NewFD->getAttr<HLSLShaderAttr>()) {
-          if (NT->getType() != ShaderType)
-            Diag(NT->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
-                << NT;
-        } else {
-          NewFD->addAttr(HLSLShaderAttr::Create(Context, ShaderType,
-                                                NewFD->getBeginLoc()));
-        }
-      }
-    }
-  }
-
   if (!getLangOpts().CPlusPlus) {
     // Perform semantic checking on the function declaration.
     if (!NewFD->isInvalidDecl() && NewFD->isMain())
@@ -10654,6 +10627,15 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
     }
   }
 
+  if (getLangOpts().HLSL && D.isFunctionDefinition()) {
+    // Any top level function could potentially be specified as an entry.
+    if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
+      ActOnHLSLTopLevelFunction(NewFD);
+
+    if (NewFD->hasAttr<HLSLShaderAttr>())
+      CheckHLSLEntryPoint(NewFD);
+  }
+
   // If this is the first declaration of a library builtin function, add
   // attributes as appropriate.
   if (!D.isRedeclaration()) {
@@ -12381,24 +12363,84 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
   }
 }
 
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
+void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
   auto &TargetInfo = getASTContext().getTargetInfo();
-  auto const Triple = TargetInfo.getTriple();
-  switch (Triple.getEnvironment()) {
-  default:
-    // FIXME: check all shader profiles.
+
+  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
+    return;
+
+  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
+  HLSLShaderAttr::ShaderType ShaderType;
+  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
+      // The entry point is already annotated - check that it matches the
+      // triple.
+      if (Shader->getType() != ShaderType) {
+        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
+            << Shader;
+        FD->setInvalidDecl();
+      }
+    } else {
+      // Implicitly add the shader attribute if the entry function isn't
+      // explicitly annotated.
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
+                                                 FD->getBeginLoc()));
+    }
+  } else {
+    switch (TargetInfo.getTriple().getEnvironment()) {
+    case llvm::Triple::UnknownEnvironment:
+    case llvm::Triple::Library:
+      break;
+    default:
+      // TODO: This should probably just be llvm_unreachable and we should
+      // reject triples with random ABIs and such when we build the target.
+      // For now, crash.
+      llvm::report_fatal_error("Unhandled environment in triple");
+    }
+  }
+}
+
+void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
+  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (ST) {
+  case HLSLShaderAttr::Pixel:
+  case HLSLShaderAttr::Vertex:
+  case HLSLShaderAttr::Geometry:
+  case HLSLShaderAttr::Hull:
+  case HLSLShaderAttr::Domain:
+  case HLSLShaderAttr::RayGeneration:
+  case HLSLShaderAttr::Intersection:
+  case HLSLShaderAttr::AnyHit:
+  case HLSLShaderAttr::ClosestHit:
+  case HLSLShaderAttr::Miss:
+  case HLSLShaderAttr::Callable:
+    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
+      DiagnoseHLSLAttrStageMismatch(NT, ST,
+                                    {HLSLShaderAttr::Compute,
+                                     HLSLShaderAttr::Amplification,
+                                     HLSLShaderAttr::Mesh});
+      FD->setInvalidDecl();
+    }
     break;
-  case llvm::Triple::EnvironmentType::Compute:
+
+  case HLSLShaderAttr::Compute:
+  case HLSLShaderAttr::Amplification:
+  case HLSLShaderAttr::Mesh:
     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << Triple.getEnvironmentName();
+          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
       FD->setInvalidDecl();
     }
     break;
   }
 
-  for (const auto *Param : FD->parameters()) {
-    if (!Param->hasAttr<HLSLAnnotationAttr>()) {
+  for (ParmVarDecl *Param : FD->parameters()) {
+    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
+      CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
+    } else {
       // FIXME: Handle struct parameters where annotations are on struct fields.
       // See: https://github.com/llvm/llvm-project/issues/57875
       Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
@@ -12409,6 +12451,40 @@ void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
   // FIXME: Verify return type semantic annotation.
 }
 
+void Sema::CheckHLSLSemanticAnnotation(
+    FunctionDecl *EntryPoint, const Decl *Param,
+    const HLSLAnnotationAttr *AnnotationAttr) {
+  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (AnnotationAttr->getKind()) {
+  case attr::HLSLSV_DispatchThreadID:
+  case attr::HLSLSV_GroupIndex:
+    if (ST == HLSLShaderAttr::Compute)
+      return;
+    DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
+                                  {HLSLShaderAttr::Compute});
+    break;
+  default:
+    llvm_unreachable("Unknown HLSLAnnotationAttr");
+  }
+}
+
+void Sema::DiagnoseHLSLAttrStageMismatch(
+    const Attr *A, HLSLShaderAttr::ShaderType Stage,
+    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+  SmallVector<StringRef, 8> StageStrings;
+  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
+                  [](HLSLShaderAttr::ShaderType ST) {
+                    return StringRef(
+                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+                  });
+  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+      << (AllowedStages.size() != 1) << join(StageStrings, ", ");
+}
+
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or

diff  --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 3c5245db20637b..4c9807e90df070 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7065,20 +7065,8 @@ static void handleUuidAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 }
 
 static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  using llvm::Triple;
-  Triple Target = S.Context.getTargetInfo().getTriple();
-  auto Env = S.Context.getTargetInfo().getTriple().getEnvironment();
-  if (!llvm::is_contained({Triple::Compute, Triple::Mesh, Triple::Amplification,
-                           Triple::Library},
-                          Env)) {
-    uint32_t Pipeline =
-        static_cast<uint32_t>(hlsl::getStageFromEnvironment(Env));
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << Pipeline << "Compute, Amplification, Mesh or Library";
-    return;
-  }
-
-  llvm::VersionTuple SMVersion = Target.getOSVersion();
+  llvm::VersionTuple SMVersion =
+      S.Context.getTargetInfo().getTriple().getOSVersion();
   uint32_t ZMax = 1024;
   uint32_t ThreadMax = 1024;
   if (SMVersion.getMajor() <= 4) {
@@ -7137,21 +7125,6 @@ HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
   return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
 }
 
-static void handleHLSLSVGroupIndexAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  using llvm::Triple;
-  auto Env = S.Context.getTargetInfo().getTriple().getEnvironment();
-  if (Env != Triple::Compute && Env != Triple::Library) {
-    // FIXME: it is OK for a compute shader entry and pixel shader entry live in
-    // same HLSL file. Issue https://github.com/llvm/llvm-project/issues/57880.
-    ShaderStage Pipeline = hlsl::getStageFromEnvironment(Env);
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << (uint32_t)Pipeline << "Compute";
-    return;
-  }
-
-  D->addAttr(::new (S.Context) HLSLSV_GroupIndexAttr(S.Context, AL));
-}
-
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
@@ -7162,23 +7135,6 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
 
 static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
                                               const ParsedAttr &AL) {
-  using llvm::Triple;
-  Triple Target = S.Context.getTargetInfo().getTriple();
-  // FIXME: it is OK for a compute shader entry and pixel shader entry live in
-  // same HLSL file.Issue https://github.com/llvm/llvm-project/issues/57880.
-  if (Target.getEnvironment() != Triple::Compute &&
-      Target.getEnvironment() != Triple::Library) {
-    uint32_t Pipeline =
-        (uint32_t)S.Context.getTargetInfo().getTriple().getEnvironment() -
-        (uint32_t)llvm::Triple::Pixel;
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << Pipeline << "Compute";
-    return;
-  }
-
-  // FIXME: report warning and ignore semantic when cannot apply on the Decl.
-  // See https://github.com/llvm/llvm-project/issues/57916.
-
   // FIXME: support semantic on field.
   // See https://github.com/llvm/llvm-project/issues/57889.
   if (isa<FieldDecl>(D)) {
@@ -7204,11 +7160,7 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     return;
 
   HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType) ||
-      // Library is added to help convert HLSLShaderAttr::ShaderType to
-      // llvm::Triple::EnviromentType. It is not a legal
-      // HLSLShaderAttr::ShaderType.
-      ShaderType == HLSLShaderAttr::Library) {
+  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
     S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;
@@ -9347,7 +9299,7 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     handleHLSLNumThreadsAttr(S, D, AL);
     break;
   case ParsedAttr::AT_HLSLSV_GroupIndex:
-    handleHLSLSVGroupIndexAttr(S, D, AL);
+    handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
     break;
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
     handleHLSLSV_DispatchThreadIDAttr(S, D, AL);

diff  --git a/clang/test/CodeGenHLSL/GlobalDestructors.hlsl b/clang/test/CodeGenHLSL/GlobalDestructors.hlsl
index 03505e3e46c4b4..b245af7c0f7b11 100644
--- a/clang/test/CodeGenHLSL/GlobalDestructors.hlsl
+++ b/clang/test/CodeGenHLSL/GlobalDestructors.hlsl
@@ -41,6 +41,7 @@ void Wag() {
 int Pupper::Count = 0;
 
 [numthreads(1,1,1)]
+[shader("compute")]
 void main(unsigned GI : SV_GroupIndex) {
   Wag();
 }

diff  --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 9b8f1ce7c36c43..8484259f84692b 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -1,9 +1,9 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl  -finclude-default-header  -ast-dump -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump  -finclude-default-header  -verify -o - %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl  -finclude-default-header  -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in Mesh shaders, requires Compute}}
-// expected-error at +1 {{attribute 'SV_DispatchThreadID' is unsupported in Mesh shaders, requires Compute}}
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error at +1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
 void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
 // CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'

diff  --git a/clang/test/SemaHLSL/Semantics/groupindex.hlsl b/clang/test/SemaHLSL/Semantics/groupindex.hlsl
new file mode 100644
index 00000000000000..a33e060c829064
--- /dev/null
+++ b/clang/test/SemaHLSL/Semantics/groupindex.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -o - %s -verify
+
+// expected-no-error
+[shader("compute")][numthreads(32,1,1)]
+void compute(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'pixel' shaders}}
+[shader("pixel")]
+void pixel(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'vertex' shaders}}
+[shader("vertex")]
+void vertex(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'geometry' shaders}}
+[shader("geometry")]
+void geometry(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'domain' shaders}}
+[shader("domain")]
+void domain(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'amplification' shaders}}
+[shader("amplification")][numthreads(32,1,1)]
+void amplification(int GI : SV_GroupIndex) {}
+
+// expected-error at +2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders}}
+[shader("mesh")][numthreads(32,1,1)]
+void mesh(int GI : SV_GroupIndex) {}

diff  --git a/clang/test/SemaHLSL/entry.hlsl b/clang/test/SemaHLSL/entry.hlsl
index 3d2a2030cc6f38..684535c9fb6435 100644
--- a/clang/test/SemaHLSL/entry.hlsl
+++ b/clang/test/SemaHLSL/entry.hlsl
@@ -4,7 +4,7 @@
 
 // Make sure add HLSLShaderAttr along with HLSLNumThreadsAttr.
 // CHECK:HLSLNumThreadsAttr 0x{{.*}} <line:10:2, col:18> 1 1 1
-// CHECK:HLSLShaderAttr 0x{{.*}} <line:13:1> Compute
+// CHECK:HLSLShaderAttr 0x{{.*}} <line:13:1> Implicit Compute
 
 #ifdef WITH_NUM_THREADS
 [numthreads(1,1,1)]

diff  --git a/clang/test/SemaHLSL/entry_shader_redecl.hlsl b/clang/test/SemaHLSL/entry_shader_redecl.hlsl
new file mode 100644
index 00000000000000..8dd1a541b820a5
--- /dev/null
+++ b/clang/test/SemaHLSL/entry_shader_redecl.hlsl
@@ -0,0 +1,74 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs1 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs1 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs2 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs2 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs3 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs3 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -ast-dump -verify | FileCheck -check-prefix=CHECK-LIB %s
+
+// expected-no-diagnostics
+
+// CHECK-ENV: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} [[SHADERFN]] 'void ()'
+// CHECK-ENV: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} [[SHADERFN]] 'void ()'
+// CHECK-ENV-NEXT: CompoundStmt 0x
+// CHECK-ENV-NEXT: HLSLNumThreadsAttr 0x
+// CHECK-ENV-NEXT: HLSLShaderAttr 0x{{.*}} Implicit Compute
+void cs1();
+[numthreads(1,1,1)] void cs1() {}
+[numthreads(1,1,1)] void cs2();
+void cs2() {}
+[numthreads(1,1,1)] void cs3();
+[numthreads(1,1,1)] void cs3() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s1 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s1 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+void s1();
+[shader("compute"), numthreads(1,1,1)] void s1() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s2 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s2 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute")] void s2();
+[shader("compute"), numthreads(1,1,1)] void s2() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s3 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s3 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[numthreads(1,1,1)] void s3();
+[shader("compute"), numthreads(1,1,1)] void s3() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s4 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s4 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s4();
+[shader("compute")][numthreads(1,1,1)] void s4() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s5 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s5 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Inherited Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s5();
+void s5() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s6 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s6 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+[shader("compute"), numthreads(1,1,1)] void s6();
+[shader("compute")] void s6() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s7 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s7 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Inherited Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s7();
+[numthreads(1,1,1)] void s7() {}

diff  --git a/clang/test/SemaHLSL/num_threads.hlsl b/clang/test/SemaHLSL/num_threads.hlsl
index f93e67d54257c8..b5f9ad6c33cd66 100644
--- a/clang/test/SemaHLSL/num_threads.hlsl
+++ b/clang/test/SemaHLSL/num_threads.hlsl
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s 
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-pixel -x hlsl -ast-dump -o - %s -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-vertex -x hlsl -ast-dump -o - %s -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-hull -x hlsl -ast-dump -o - %s -verify
@@ -97,14 +97,20 @@ int secondFn() {
   return 1;
 }
 
+[numthreads(4,2,1)]
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> 4 2 1
+int onlyOnForwardDecl();
+
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> Inherited 4 2 1
+int onlyOnForwardDecl() {
+  return 1;
+}
 
 #else // Vertex and Pixel only beyond here
-// expected-error-re at +1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}}
+// expected-error-re at +1 {{attribute 'numthreads' is unsupported in '{{[A-Za-z]+}}' shaders, requires one of the following: compute, amplification, mesh}}
 [numthreads(1,1,1)]
 int main() {
  return 1;
 }
 
 #endif
-
-


        


More information about the cfe-commits mailing list