[clang] [clang][NFC] Move more functions to `SemaHLSL` (PR #88354)

Vlad Serebrennikov via cfe-commits cfe-commits at lists.llvm.org
Wed Apr 10 22:02:15 PDT 2024


https://github.com/Endilll created https://github.com/llvm/llvm-project/pull/88354

This is a follow-up to #87912. It moves more HLSL-related functions from `Sema` to `SemaHLSL` and, in the process, drops the now-redundant `HLSL` from their names (for example, `ActOnStartHLSLBuffer` becomes `ActOnStartBuffer`).

>From ecff8db824552872ba055fdc0bca42b1a0386c39 Mon Sep 17 00:00:00 2001
From: Vlad Serebrennikov <serebrennikov.vladislav at gmail.com>
Date: Thu, 11 Apr 2024 07:56:46 +0300
Subject: [PATCH] [clang][NFC] Move more functions to `SemaHLSL`

---
 clang/include/clang/Sema/Sema.h     |  15 ---
 clang/include/clang/Sema/SemaHLSL.h |  27 +++-
 clang/lib/Parse/ParseHLSL.cpp       |  10 +-
 clang/lib/Sema/SemaDecl.cpp         | 130 +------------------
 clang/lib/Sema/SemaDeclAttr.cpp     |  54 +-------
 clang/lib/Sema/SemaHLSL.cpp         | 186 +++++++++++++++++++++++++++-
 6 files changed, 218 insertions(+), 204 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e3e255a0dd76f8..e904cd3ad13fb7 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2940,13 +2940,6 @@ class Sema final : public SemaBase {
                                       QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
-  void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
-  void CheckHLSLEntryPoint(FunctionDecl *FD);
-  void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
-                                   const HLSLAnnotationAttr *AnnotationAttr);
-  void DiagnoseHLSLAttrStageMismatch(
-      const Attr *A, HLSLShaderAttr::ShaderType Stage,
-      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
@@ -3707,14 +3700,6 @@ class Sema final : public SemaBase {
                           StringRef UuidAsWritten, MSGuidDecl *GuidDecl);
 
   BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
-  HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
-                                              const AttributeCommonInfo &AL,
-                                              int X, int Y, int Z);
-  HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                                      HLSLShaderAttr::ShaderType ShaderType);
-  HLSLParamModifierAttr *
-  mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                             HLSLParamModifierAttr::Spelling Spelling);
 
   WebAssemblyImportNameAttr *
   mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index acc675963c23a5..34acaf19517f2a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -13,12 +13,16 @@
 #ifndef LLVM_CLANG_SEMA_SEMAHLSL_H
 #define LLVM_CLANG_SEMA_SEMAHLSL_H
 
+#include "clang/AST/Attr.h"
+#include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/AST/Expr.h"
+#include "clang/Basic/AttributeCommonInfo.h"
 #include "clang/Basic/IdentifierTable.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/SemaBase.h"
+#include <initializer_list>
 
 namespace clang {
 
@@ -26,10 +30,25 @@ class SemaHLSL : public SemaBase {
 public:
   SemaHLSL(Sema &S);
 
-  Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                             SourceLocation KwLoc, IdentifierInfo *Ident,
-                             SourceLocation IdentLoc, SourceLocation LBrace);
-  void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace);
+  Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
+                         IdentifierInfo *Ident, SourceLocation IdentLoc,
+                         SourceLocation LBrace);
+  void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace);
+  HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
+                                          const AttributeCommonInfo &AL, int X,
+                                          int Y, int Z);
+  HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                  HLSLShaderAttr::ShaderType ShaderType);
+  HLSLParamModifierAttr *
+  mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                         HLSLParamModifierAttr::Spelling Spelling);
+  void ActOnTopLevelFunction(FunctionDecl *FD);
+  void CheckEntryPoint(FunctionDecl *FD);
+  void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
+                               const HLSLAnnotationAttr *AnnotationAttr);
+  void DiagnoseAttrStageMismatch(
+      const Attr *A, HLSLShaderAttr::ShaderType Stage,
+      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
 };
 
 } // namespace clang
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 5afc958600fa55..d97985d42369ad 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
     return nullptr;
   }
 
-  Decl *D = Actions.HLSL().ActOnStartHLSLBuffer(
-      getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc,
-      T.getOpenLocation());
+  Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc,
+                                            Identifier, IdentifierLoc,
+                                            T.getOpenLocation());
 
   while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) {
     // FIXME: support attribute on constants inside cbuffer/tbuffer.
@@ -88,7 +88,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
       T.skipToEnd();
       DeclEnd = T.getCloseLocation();
       BufferScope.Exit();
-      Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+      Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
       return nullptr;
     }
   }
@@ -96,7 +96,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
   T.consumeClose();
   DeclEnd = T.getCloseLocation();
   BufferScope.Exit();
-  Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+  Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
 
   Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs);
   return D;
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8472aaeb6bad97..3beb4fb1f8c733 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -45,6 +45,7 @@
 #include "clang/Sema/ParsedTemplate.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "clang/Sema/Template.h"
 #include "llvm/ADT/SmallString.h"
@@ -2972,10 +2973,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
     NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
-    NewAttr =
-        S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+    NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
+                                           NT->getZ());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
-    NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
+    NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
     // Do nothing. Each redeclaration should be suppressed separately.
     NewAttr = nullptr;
@@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
   if (getLangOpts().HLSL && D.isFunctionDefinition()) {
     // Any top level function could potentially be specified as an entry.
     if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
-      ActOnHLSLTopLevelFunction(NewFD);
+      HLSL().ActOnTopLevelFunction(NewFD);
 
     if (NewFD->hasAttr<HLSLShaderAttr>())
-      CheckHLSLEntryPoint(NewFD);
+      HLSL().CheckEntryPoint(NewFD);
   }
 
   // If this is the first declaration of a library builtin function, add
@@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
   }
 }
 
-void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
-  auto &TargetInfo = getASTContext().getTargetInfo();
-
-  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
-    return;
-
-  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
-    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
-      // The entry point is already annotated - check that it matches the
-      // triple.
-      if (Shader->getType() != ShaderType) {
-        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
-            << Shader;
-        FD->setInvalidDecl();
-      }
-    } else {
-      // Implicitly add the shader attribute if the entry function isn't
-      // explicitly annotated.
-      FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
-                                                 FD->getBeginLoc()));
-    }
-  } else {
-    switch (TargetInfo.getTriple().getEnvironment()) {
-    case llvm::Triple::UnknownEnvironment:
-    case llvm::Triple::Library:
-      break;
-    default:
-      llvm_unreachable("Unhandled environment in triple");
-    }
-  }
-}
-
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
-  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (ST) {
-  case HLSLShaderAttr::Pixel:
-  case HLSLShaderAttr::Vertex:
-  case HLSLShaderAttr::Geometry:
-  case HLSLShaderAttr::Hull:
-  case HLSLShaderAttr::Domain:
-  case HLSLShaderAttr::RayGeneration:
-  case HLSLShaderAttr::Intersection:
-  case HLSLShaderAttr::AnyHit:
-  case HLSLShaderAttr::ClosestHit:
-  case HLSLShaderAttr::Miss:
-  case HLSLShaderAttr::Callable:
-    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
-      DiagnoseHLSLAttrStageMismatch(NT, ST,
-                                    {HLSLShaderAttr::Compute,
-                                     HLSLShaderAttr::Amplification,
-                                     HLSLShaderAttr::Mesh});
-      FD->setInvalidDecl();
-    }
-    break;
-
-  case HLSLShaderAttr::Compute:
-  case HLSLShaderAttr::Amplification:
-  case HLSLShaderAttr::Mesh:
-    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
-      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
-      FD->setInvalidDecl();
-    }
-    break;
-  }
-
-  for (ParmVarDecl *Param : FD->parameters()) {
-    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
-      CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
-    } else {
-      // FIXME: Handle struct parameters where annotations are on struct fields.
-      // See: https://github.com/llvm/llvm-project/issues/57875
-      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
-      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
-      FD->setInvalidDecl();
-    }
-  }
-  // FIXME: Verify return type semantic annotation.
-}
-
-void Sema::CheckHLSLSemanticAnnotation(
-    FunctionDecl *EntryPoint, const Decl *Param,
-    const HLSLAnnotationAttr *AnnotationAttr) {
-  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (AnnotationAttr->getKind()) {
-  case attr::HLSLSV_DispatchThreadID:
-  case attr::HLSLSV_GroupIndex:
-    if (ST == HLSLShaderAttr::Compute)
-      return;
-    DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
-                                  {HLSLShaderAttr::Compute});
-    break;
-  default:
-    llvm_unreachable("Unknown HLSLAnnotationAttr");
-  }
-}
-
-void Sema::DiagnoseHLSLAttrStageMismatch(
-    const Attr *A, HLSLShaderAttr::ShaderType Stage,
-    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
-  SmallVector<StringRef, 8> StageStrings;
-  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
-                  [](HLSLShaderAttr::ShaderType ST) {
-                    return StringRef(
-                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
-                  });
-  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
-      << (AllowedStages.size() != 1) << join(StageStrings, ", ");
-}
-
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 8bce04640e748e..b91064e28e4153 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -39,6 +39,7 @@
 #include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
@@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     return;
   }
 
-  HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
+  HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
-                                                  const AttributeCommonInfo &AL,
-                                                  int X, int Y, int Z) {
-  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
-    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
-}
-
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
@@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 
   // FIXME: check function match the shader stage.
 
-  HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
+  HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLShaderAttr *
-Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                          HLSLShaderAttr::ShaderType ShaderType) {
-  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
-    if (NT->getType() != ShaderType) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return HLSLShaderAttr::Create(Context, ShaderType, AL);
-}
-
 static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
                                           const ParsedAttr &AL) {
   StringRef Space = "space0";
@@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
 
 static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
                                         const ParsedAttr &AL) {
-  HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
+  HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr(
       D, AL,
       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLParamModifierAttr *
-Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                                 HLSLParamModifierAttr::Spelling Spelling) {
-  // We can only merge an `in` attribute with an `out` attribute. All other
-  // combinations of duplicated attributes are ill-formed.
-  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
-    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
-        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
-      D->dropAttr<HLSLParamModifierAttr>();
-      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
-      return HLSLParamModifierAttr::Create(
-          Context, /*MergedSpelling=*/true, AdjustedRange,
-          HLSLParamModifierAttr::Keyword_inout);
-    }
-    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
-    Diag(PA->getLocation(), diag::note_conflicting_attribute);
-    return nullptr;
-  }
-  return HLSLParamModifierAttr::Create(Context, AL);
-}
-
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 681849d6e6c8a2..bb9e37f18d370c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,17 +9,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Sema/SemaHLSL.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/Sema.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
+#include <iterator>
 
 using namespace clang;
 
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
-Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                                     SourceLocation KwLoc,
-                                     IdentifierInfo *Ident,
-                                     SourceLocation IdentLoc,
-                                     SourceLocation LBrace) {
+Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
+                                 SourceLocation KwLoc, IdentifierInfo *Ident,
+                                 SourceLocation IdentLoc,
+                                 SourceLocation LBrace) {
   // For anonymous namespace, take the location of the left brace.
   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
@@ -31,8 +39,174 @@ Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
   return Result;
 }
 
-void SemaHLSL::ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace) {
+void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
   BufDecl->setRBraceLoc(RBrace);
   SemaRef.PopDeclContext();
 }
+
+HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
+                                                  const AttributeCommonInfo &AL,
+                                                  int X, int Y, int Z) {
+  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
+    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return ::new (getASTContext())
+      HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
+}
+
+HLSLShaderAttr *
+SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                          HLSLShaderAttr::ShaderType ShaderType) {
+  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
+    if (NT->getType() != ShaderType) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
+}
+
+HLSLParamModifierAttr *
+SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 HLSLParamModifierAttr::Spelling Spelling) {
+  // We can only merge an `in` attribute with an `out` attribute. All other
+  // combinations of duplicated attributes are ill-formed.
+  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
+    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
+        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
+      D->dropAttr<HLSLParamModifierAttr>();
+      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
+      return HLSLParamModifierAttr::Create(
+          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
+          HLSLParamModifierAttr::Keyword_inout);
+    }
+    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
+    Diag(PA->getLocation(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  return HLSLParamModifierAttr::Create(getASTContext(), AL);
+}
+
+void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
+  auto &TargetInfo = getASTContext().getTargetInfo();
+
+  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
+    return;
+
+  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
+  HLSLShaderAttr::ShaderType ShaderType;
+  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
+      // The entry point is already annotated - check that it matches the
+      // triple.
+      if (Shader->getType() != ShaderType) {
+        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
+            << Shader;
+        FD->setInvalidDecl();
+      }
+    } else {
+      // Implicitly add the shader attribute if the entry function isn't
+      // explicitly annotated.
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+                                                 FD->getBeginLoc()));
+    }
+  } else {
+    switch (TargetInfo.getTriple().getEnvironment()) {
+    case llvm::Triple::UnknownEnvironment:
+    case llvm::Triple::Library:
+      break;
+    default:
+      llvm_unreachable("Unhandled environment in triple");
+    }
+  }
+}
+
+void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
+  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (ST) {
+  case HLSLShaderAttr::Pixel:
+  case HLSLShaderAttr::Vertex:
+  case HLSLShaderAttr::Geometry:
+  case HLSLShaderAttr::Hull:
+  case HLSLShaderAttr::Domain:
+  case HLSLShaderAttr::RayGeneration:
+  case HLSLShaderAttr::Intersection:
+  case HLSLShaderAttr::AnyHit:
+  case HLSLShaderAttr::ClosestHit:
+  case HLSLShaderAttr::Miss:
+  case HLSLShaderAttr::Callable:
+    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
+      DiagnoseAttrStageMismatch(NT, ST,
+                                {HLSLShaderAttr::Compute,
+                                 HLSLShaderAttr::Amplification,
+                                 HLSLShaderAttr::Mesh});
+      FD->setInvalidDecl();
+    }
+    break;
+
+  case HLSLShaderAttr::Compute:
+  case HLSLShaderAttr::Amplification:
+  case HLSLShaderAttr::Mesh:
+    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
+      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
+          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+      FD->setInvalidDecl();
+    }
+    break;
+  }
+
+  for (ParmVarDecl *Param : FD->parameters()) {
+    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
+      CheckSemanticAnnotation(FD, Param, AnnotationAttr);
+    } else {
+      // FIXME: Handle struct parameters where annotations are on struct fields.
+      // See: https://github.com/llvm/llvm-project/issues/57875
+      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
+      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
+      FD->setInvalidDecl();
+    }
+  }
+  // FIXME: Verify return type semantic annotation.
+}
+
+void SemaHLSL::CheckSemanticAnnotation(
+    FunctionDecl *EntryPoint, const Decl *Param,
+    const HLSLAnnotationAttr *AnnotationAttr) {
+  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (AnnotationAttr->getKind()) {
+  case attr::HLSLSV_DispatchThreadID:
+  case attr::HLSLSV_GroupIndex:
+    if (ST == HLSLShaderAttr::Compute)
+      return;
+    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+    break;
+  default:
+    llvm_unreachable("Unknown HLSLAnnotationAttr");
+  }
+}
+
+void SemaHLSL::DiagnoseAttrStageMismatch(
+    const Attr *A, HLSLShaderAttr::ShaderType Stage,
+    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+  SmallVector<StringRef, 8> StageStrings;
+  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
+                  [](HLSLShaderAttr::ShaderType ST) {
+                    return StringRef(
+                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+                  });
+  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+      << (AllowedStages.size() != 1) << join(StageStrings, ", ");
+}



More information about the cfe-commits mailing list