[clang] [clang][NFC] Move more functions to `SemaHLSL` (PR #88354)
Vlad Serebrennikov via cfe-commits
cfe-commits at lists.llvm.org
Wed Apr 10 22:02:15 PDT 2024
https://github.com/Endilll created https://github.com/llvm/llvm-project/pull/88354
A follow-up to #87912. I'm moving more HLSL-related functions from `Sema` to `SemaHLSL`. I'm also dropping the now-redundant `HLSL` from their names in the process (e.g. `mergeHLSLShaderAttr` becomes `SemaHLSL::mergeShaderAttr`).
From ecff8db824552872ba055fdc0bca42b1a0386c39 Mon Sep 17 00:00:00 2001
From: Vlad Serebrennikov <serebrennikov.vladislav at gmail.com>
Date: Thu, 11 Apr 2024 07:56:46 +0300
Subject: [PATCH] [clang][NFC] Move more functions to `SemaHLSL`
---
clang/include/clang/Sema/Sema.h | 15 ---
clang/include/clang/Sema/SemaHLSL.h | 27 +++-
clang/lib/Parse/ParseHLSL.cpp | 10 +-
clang/lib/Sema/SemaDecl.cpp | 130 +------------------
clang/lib/Sema/SemaDeclAttr.cpp | 54 +-------
clang/lib/Sema/SemaHLSL.cpp | 186 +++++++++++++++++++++++++++-
6 files changed, 218 insertions(+), 204 deletions(-)
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e3e255a0dd76f8..e904cd3ad13fb7 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2940,13 +2940,6 @@ class Sema final : public SemaBase {
QualType NewT, QualType OldT);
void CheckMain(FunctionDecl *FD, const DeclSpec &D);
void CheckMSVCRTEntryPoint(FunctionDecl *FD);
- void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
- void CheckHLSLEntryPoint(FunctionDecl *FD);
- void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
- const HLSLAnnotationAttr *AnnotationAttr);
- void DiagnoseHLSLAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
bool IsDefinition);
void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
@@ -3707,14 +3700,6 @@ class Sema final : public SemaBase {
StringRef UuidAsWritten, MSGuidDecl *GuidDecl);
BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
- HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
- const AttributeCommonInfo &AL,
- int X, int Y, int Z);
- HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType);
- HLSLParamModifierAttr *
- mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLParamModifierAttr::Spelling Spelling);
WebAssemblyImportNameAttr *
mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index acc675963c23a5..34acaf19517f2a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -13,12 +13,16 @@
#ifndef LLVM_CLANG_SEMA_SEMAHLSL_H
#define LLVM_CLANG_SEMA_SEMAHLSL_H
+#include "clang/AST/Attr.h"
+#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/Expr.h"
+#include "clang/Basic/AttributeCommonInfo.h"
#include "clang/Basic/IdentifierTable.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/SemaBase.h"
+#include <initializer_list>
namespace clang {
@@ -26,10 +30,48 @@ class SemaHLSL : public SemaBase {
public:
SemaHLSL(Sema &S);
- Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
- SourceLocation KwLoc, IdentifierInfo *Ident,
- SourceLocation IdentLoc, SourceLocation LBrace);
- void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace);
+  /// Creates the HLSLBufferDecl for a `cbuffer` (\p CBuffer is true) or
+  /// `tbuffer` declaration whose body opens at \p LBrace.
+  Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
+                         IdentifierInfo *Ident, SourceLocation IdentLoc,
+                         SourceLocation LBrace);
+  /// Sets the closing-brace location of buffer \p Dcl and pops the current
+  /// declaration context.
+  void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace);
+
+  /// Returns a new numthreads attribute for \p D, or null if one is already
+  /// present (diagnosing the redeclaration if its X/Y/Z values differ).
+  HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
+                                          const AttributeCommonInfo &AL, int X,
+                                          int Y, int Z);
+  /// Returns a new shader attribute for \p D, or null if one is already
+  /// present (diagnosing the redeclaration if its shader stage differs).
+  HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                  HLSLShaderAttr::ShaderType ShaderType);
+  /// Merges a parameter modifier into \p D. `in` combined with `out` (in
+  /// either order) drops the existing attribute and returns a merged `inout`;
+  /// any other duplicate is diagnosed and null is returned.
+  HLSLParamModifierAttr *
+  mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                         HLSLParamModifierAttr::Spelling Spelling);
+
+  /// If \p FD is the entry point named in the target options, checks its
+  /// shader attribute against the target triple's environment, or adds an
+  /// implicit one when it has none.
+  void ActOnTopLevelFunction(FunctionDecl *FD);
+  /// Validates an entry point that has an HLSLShaderAttr: numthreads is
+  /// required for compute, amplification and mesh shaders and rejected for
+  /// other stages, and every parameter needs a semantic annotation.
+  void CheckEntryPoint(FunctionDecl *FD);
+  /// Diagnoses \p AnnotationAttr on \p Param if it is not allowed in the
+  /// shader stage of \p EntryPoint.
+  void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
+                               const HLSLAnnotationAttr *AnnotationAttr);
+  /// Reports that attribute \p A is not supported in shader stage \p Stage,
+  /// listing \p AllowedStages in the diagnostic.
+  void DiagnoseAttrStageMismatch(
+      const Attr *A, HLSLShaderAttr::ShaderType Stage,
+      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
};
} // namespace clang
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 5afc958600fa55..d97985d42369ad 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
return nullptr;
}
- Decl *D = Actions.HLSL().ActOnStartHLSLBuffer(
- getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc,
- T.getOpenLocation());
+ Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc,
+ Identifier, IdentifierLoc,
+ T.getOpenLocation());
while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) {
// FIXME: support attribute on constants inside cbuffer/tbuffer.
@@ -88,7 +88,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
T.skipToEnd();
DeclEnd = T.getCloseLocation();
BufferScope.Exit();
- Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+ Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
return nullptr;
}
}
@@ -96,7 +96,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
T.consumeClose();
DeclEnd = T.getCloseLocation();
BufferScope.Exit();
- Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+ Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs);
return D;
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8472aaeb6bad97..3beb4fb1f8c733 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -45,6 +45,7 @@
#include "clang/Sema/ParsedTemplate.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"
#include "llvm/ADT/SmallString.h"
@@ -2972,10 +2973,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
- NewAttr =
- S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+ NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
+ NT->getZ());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
- NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
+ NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
// Do nothing. Each redeclaration should be suppressed separately.
NewAttr = nullptr;
@@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
if (getLangOpts().HLSL && D.isFunctionDefinition()) {
// Any top level function could potentially be specified as an entry.
if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
- ActOnHLSLTopLevelFunction(NewFD);
+ HLSL().ActOnTopLevelFunction(NewFD);
if (NewFD->hasAttr<HLSLShaderAttr>())
- CheckHLSLEntryPoint(NewFD);
+ HLSL().CheckEntryPoint(NewFD);
}
// If this is the first declaration of a library builtin function, add
@@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
}
}
-void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
- auto &TargetInfo = getASTContext().getTargetInfo();
-
- if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
- return;
-
- StringRef Env = TargetInfo.getTriple().getEnvironmentName();
- HLSLShaderAttr::ShaderType ShaderType;
- if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
- if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
- // The entry point is already annotated - check that it matches the
- // triple.
- if (Shader->getType() != ShaderType) {
- Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
- << Shader;
- FD->setInvalidDecl();
- }
- } else {
- // Implicitly add the shader attribute if the entry function isn't
- // explicitly annotated.
- FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
- FD->getBeginLoc()));
- }
- } else {
- switch (TargetInfo.getTriple().getEnvironment()) {
- case llvm::Triple::UnknownEnvironment:
- case llvm::Triple::Library:
- break;
- default:
- llvm_unreachable("Unhandled environment in triple");
- }
- }
-}
-
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
- const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
- assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
- switch (ST) {
- case HLSLShaderAttr::Pixel:
- case HLSLShaderAttr::Vertex:
- case HLSLShaderAttr::Geometry:
- case HLSLShaderAttr::Hull:
- case HLSLShaderAttr::Domain:
- case HLSLShaderAttr::RayGeneration:
- case HLSLShaderAttr::Intersection:
- case HLSLShaderAttr::AnyHit:
- case HLSLShaderAttr::ClosestHit:
- case HLSLShaderAttr::Miss:
- case HLSLShaderAttr::Callable:
- if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
- DiagnoseHLSLAttrStageMismatch(NT, ST,
- {HLSLShaderAttr::Compute,
- HLSLShaderAttr::Amplification,
- HLSLShaderAttr::Mesh});
- FD->setInvalidDecl();
- }
- break;
-
- case HLSLShaderAttr::Compute:
- case HLSLShaderAttr::Amplification:
- case HLSLShaderAttr::Mesh:
- if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
- Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
- << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
- FD->setInvalidDecl();
- }
- break;
- }
-
- for (ParmVarDecl *Param : FD->parameters()) {
- if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
- CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
- } else {
- // FIXME: Handle struct parameters where annotations are on struct fields.
- // See: https://github.com/llvm/llvm-project/issues/57875
- Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
- Diag(Param->getLocation(), diag::note_previous_decl) << Param;
- FD->setInvalidDecl();
- }
- }
- // FIXME: Verify return type semantic annotation.
-}
-
-void Sema::CheckHLSLSemanticAnnotation(
- FunctionDecl *EntryPoint, const Decl *Param,
- const HLSLAnnotationAttr *AnnotationAttr) {
- auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
- assert(ShaderAttr && "Entry point has no shader attribute");
- HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
- switch (AnnotationAttr->getKind()) {
- case attr::HLSLSV_DispatchThreadID:
- case attr::HLSLSV_GroupIndex:
- if (ST == HLSLShaderAttr::Compute)
- return;
- DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
- {HLSLShaderAttr::Compute});
- break;
- default:
- llvm_unreachable("Unknown HLSLAnnotationAttr");
- }
-}
-
-void Sema::DiagnoseHLSLAttrStageMismatch(
- const Attr *A, HLSLShaderAttr::ShaderType Stage,
- std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
- SmallVector<StringRef, 8> StageStrings;
- llvm::transform(AllowedStages, std::back_inserter(StageStrings),
- [](HLSLShaderAttr::ShaderType ST) {
- return StringRef(
- HLSLShaderAttr::ConvertShaderTypeToStr(ST));
- });
- Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
- << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
- << (AllowedStages.size() != 1) << join(StageStrings, ", ");
-}
-
bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
// FIXME: Need strict checking. In C89, we need to check for
// any assignment, increment, decrement, function-calls, or
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 8bce04640e748e..b91064e28e4153 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -39,6 +39,7 @@
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
@@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
return;
}
- HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
+ HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z);
if (NewAttr)
D->addAttr(NewAttr);
}
-HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
- const AttributeCommonInfo &AL,
- int X, int Y, int Z) {
- if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
- if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
- Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
- Diag(AL.getLoc(), diag::note_conflicting_attribute);
- }
- return nullptr;
- }
- return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
-}
-
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
if (!T->hasUnsignedIntegerRepresentation())
return false;
@@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
// FIXME: check function match the shader stage.
- HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
+ HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType);
if (NewAttr)
D->addAttr(NewAttr);
}
-HLSLShaderAttr *
-Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLShaderAttr::ShaderType ShaderType) {
- if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
- if (NT->getType() != ShaderType) {
- Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
- Diag(AL.getLoc(), diag::note_conflicting_attribute);
- }
- return nullptr;
- }
- return HLSLShaderAttr::Create(Context, ShaderType, AL);
-}
-
static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
const ParsedAttr &AL) {
StringRef Space = "space0";
@@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
const ParsedAttr &AL) {
- HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
+ HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr(
D, AL,
static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
if (NewAttr)
D->addAttr(NewAttr);
}
-HLSLParamModifierAttr *
-Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
- HLSLParamModifierAttr::Spelling Spelling) {
- // We can only merge an `in` attribute with an `out` attribute. All other
- // combinations of duplicated attributes are ill-formed.
- if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
- if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
- (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
- D->dropAttr<HLSLParamModifierAttr>();
- SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
- return HLSLParamModifierAttr::Create(
- Context, /*MergedSpelling=*/true, AdjustedRange,
- HLSLParamModifierAttr::Keyword_inout);
- }
- Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
- Diag(PA->getLocation(), diag::note_conflicting_attribute);
- return nullptr;
- }
- return HLSLParamModifierAttr::Create(Context, AL);
-}
-
static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
if (!S.LangOpts.CPlusPlus) {
S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 681849d6e6c8a2..bb9e37f18d370c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,17 +9,25 @@
//===----------------------------------------------------------------------===//
#include "clang/Sema/SemaHLSL.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Sema.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
+#include <iterator>
using namespace clang;
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
-Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
- SourceLocation KwLoc,
- IdentifierInfo *Ident,
- SourceLocation IdentLoc,
- SourceLocation LBrace) {
+Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
+ SourceLocation KwLoc, IdentifierInfo *Ident,
+ SourceLocation IdentLoc,
+ SourceLocation LBrace) {
// For anonymous namespace, take the location of the left brace.
DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
HLSLBufferDecl *Result = HLSLBufferDecl::Create(
@@ -31,8 +39,194 @@ Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
return Result;
}
-void SemaHLSL::ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace) {
+void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
BufDecl->setRBraceLoc(RBrace);
SemaRef.PopDeclContext();
}
+
+// Returns a fresh numthreads attribute for D, or null if D already has one.
+// The existing attribute is kept; differing X/Y/Z values are diagnosed.
+HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
+                                                  const AttributeCommonInfo &AL,
+                                                  int X, int Y, int Z) {
+  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
+    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLNumThreadsAttr::Create(getASTContext(), X, Y, Z, AL);
+}
+
+// Returns a fresh shader attribute for D, or null if D already has one.
+// The existing attribute is kept; a different shader stage is diagnosed.
+HLSLShaderAttr *
+SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                          HLSLShaderAttr::ShaderType ShaderType) {
+  if (HLSLShaderAttr *SA = D->getAttr<HLSLShaderAttr>()) {
+    if (SA->getType() != ShaderType) {
+      Diag(SA->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
+}
+
+// Returns the parameter modifier attribute to attach to D, or null on a
+// diagnosed duplicate. On an in/out merge the existing attribute is dropped
+// from D and a merged `inout` attribute is returned for the caller to add.
+HLSLParamModifierAttr *
+SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 HLSLParamModifierAttr::Spelling Spelling) {
+  // We can only merge an `in` attribute with an `out` attribute. All other
+  // combinations of duplicated attributes are ill-formed.
+  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
+    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
+        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
+      D->dropAttr<HLSLParamModifierAttr>();
+      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
+      return HLSLParamModifierAttr::Create(
+          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
+          HLSLParamModifierAttr::Keyword_inout);
+    }
+    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
+    Diag(PA->getLocation(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  return HLSLParamModifierAttr::Create(getASTContext(), AL);
+}
+
+// If FD is the entry point named in the target options, reconciles its shader
+// attribute with the shader stage encoded in the target triple's environment.
+void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
+  const TargetInfo &TI = getASTContext().getTargetInfo();
+
+  if (FD->getName() != TI.getTargetOpts().HLSLEntry)
+    return;
+
+  StringRef Env = TI.getTriple().getEnvironmentName();
+  HLSLShaderAttr::ShaderType ShaderType;
+  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
+      // The entry point is already annotated - check that it matches the
+      // triple.
+      if (Shader->getType() != ShaderType) {
+        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
+            << Shader;
+        FD->setInvalidDecl();
+      }
+    } else {
+      // Implicitly add the shader attribute if the entry function isn't
+      // explicitly annotated.
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
+                                                 FD->getBeginLoc()));
+    }
+  } else {
+    // An environment that names no shader stage is only expected when there
+    // is no environment at all or when compiling a library.
+    switch (TI.getTriple().getEnvironment()) {
+    case llvm::Triple::UnknownEnvironment:
+    case llvm::Triple::Library:
+      break;
+    default:
+      llvm_unreachable("Unhandled environment in triple");
+    }
+  }
+}
+
+// Validates an entry point carrying HLSLShaderAttr: numthreads is required
+// for compute, amplification and mesh shaders and rejected for all other
+// stages, and every parameter must have a semantic annotation.
+void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
+  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (ST) {
+  case HLSLShaderAttr::Pixel:
+  case HLSLShaderAttr::Vertex:
+  case HLSLShaderAttr::Geometry:
+  case HLSLShaderAttr::Hull:
+  case HLSLShaderAttr::Domain:
+  case HLSLShaderAttr::RayGeneration:
+  case HLSLShaderAttr::Intersection:
+  case HLSLShaderAttr::AnyHit:
+  case HLSLShaderAttr::ClosestHit:
+  case HLSLShaderAttr::Miss:
+  case HLSLShaderAttr::Callable:
+    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
+      DiagnoseAttrStageMismatch(NT, ST,
+                                {HLSLShaderAttr::Compute,
+                                 HLSLShaderAttr::Amplification,
+                                 HLSLShaderAttr::Mesh});
+      FD->setInvalidDecl();
+    }
+    break;
+
+  case HLSLShaderAttr::Compute:
+  case HLSLShaderAttr::Amplification:
+  case HLSLShaderAttr::Mesh:
+    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
+      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
+          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
+      FD->setInvalidDecl();
+    }
+    break;
+  }
+
+  for (ParmVarDecl *Param : FD->parameters()) {
+    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
+      CheckSemanticAnnotation(FD, Param, AnnotationAttr);
+    } else {
+      // FIXME: Handle struct parameters where annotations are on struct fields.
+      // See: https://github.com/llvm/llvm-project/issues/57875
+      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
+      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
+      FD->setInvalidDecl();
+    }
+  }
+  // FIXME: Verify return type semantic annotation.
+}
+
+// Diagnoses AnnotationAttr on Param if its semantic is not allowed in the
+// shader stage of EntryPoint. Every annotation kind handled below is only
+// valid in compute shaders.
+void SemaHLSL::CheckSemanticAnnotation(
+    FunctionDecl *EntryPoint, const Decl *Param,
+    const HLSLAnnotationAttr *AnnotationAttr) {
+  const auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (AnnotationAttr->getKind()) {
+  case attr::HLSLSV_DispatchThreadID:
+  case attr::HLSLSV_GroupIndex:
+    if (ST == HLSLShaderAttr::Compute)
+      return;
+    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
+    break;
+  default:
+    llvm_unreachable("Unknown HLSLAnnotationAttr");
+  }
+}
+
+// Reports that attribute A is not supported in shader stage Stage and lists
+// the stages in which it is allowed.
+void SemaHLSL::DiagnoseAttrStageMismatch(
+    const Attr *A, HLSLShaderAttr::ShaderType Stage,
+    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
+  SmallVector<StringRef, 8> StageStrings;
+  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
+                  [](HLSLShaderAttr::ShaderType ST) {
+                    return StringRef(
+                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
+                  });
+  // NOTE(review): the bool argument presumably selects singular vs. plural
+  // wording in the diagnostic text; confirm in DiagnosticSemaKinds.td.
+  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
+      << (AllowedStages.size() != 1) << llvm::join(StageStrings, ", ");
+}
More information about the cfe-commits mailing list