[clang] [llvm] [HLSL] AST support for WaveSize attribute. (PR #101240)
Xiang Li via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 30 14:02:23 PDT 2024
https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/101240
>From c7a476a4d8b06e399e9c076cc15208871e1b5a25 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 30 Jul 2024 16:34:40 -0400
Subject: [PATCH 1/2] [HLSL] AST support for WaveSize attribute.
First step toward supporting the WaveSize attribute described in
https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
This adds a new attribute, HLSLWaveSizeAttr, to the AST.
Both the single wave size and the wave size range are implemented together,
rather than separately, which would require more work.
For #70118
---
clang/include/clang/Basic/Attr.td | 16 +++
clang/include/clang/Basic/AttrDocs.td | 37 ++++++
clang/include/clang/Basic/DiagnosticGroups.td | 3 +
.../clang/Basic/DiagnosticSemaKinds.td | 15 +++
clang/include/clang/Sema/SemaHLSL.h | 4 +
clang/lib/Sema/SemaDecl.cpp | 4 +
clang/lib/Sema/SemaDeclAttr.cpp | 3 +
clang/lib/Sema/SemaHLSL.cpp | 118 +++++++++++++++++-
clang/test/AST/HLSL/WaveSize.hlsl | 25 ++++
.../test/SemaHLSL/WaveSize-invalid-param.hlsl | 101 +++++++++++++++
.../SemaHLSL/WaveSize-invalid-profiles.hlsl | 20 +++
clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl | 24 ++++
.../include/llvm/Frontend/HLSL/HLSLWaveSize.h | 94 ++++++++++++++
llvm/include/llvm/Support/DXILABI.h | 3 +
14 files changed, 465 insertions(+), 2 deletions(-)
create mode 100644 clang/test/AST/HLSL/WaveSize.hlsl
create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
create mode 100644 clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 46d0a66d59c37..8b2f8358aec28 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4625,6 +4625,22 @@ def HLSLParamModifier : TypeAttr {
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
}
+def HLSLWaveSize: InheritableAttr {
+  let Spellings = [Microsoft<"WaveSize">];
+  let LangOpts = [HLSL];
+  let Subjects = SubjectList<[HLSLEntry]>;
+  let Documentation = [WaveSizeDocs];
+  let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>];
+  let AdditionalMembers = [{
+    private:
+      int SpelledArgsCount = 0;
+
+    public:
+      void setSpelledArgsCount(int C) { SpelledArgsCount = C; }
+      int getSpelledArgsCount() const { return SpelledArgsCount; }
+  }];
+}
+
def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 4b8d520d73893..e3c98912c81f4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7322,6 +7322,43 @@ flag.
}];
}
+def WaveSizeDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``WaveSize`` attribute specifies a wave size on a shader entry point to
+indicate either that a shader depends on or strongly prefers a specific wave
+size. The attribute has two forms, ``WaveSize`` and ``RangedWaveSize``; both
+are spelled ``WaveSize`` and differ only in the number of arguments.
+The syntax for ``WaveSize`` is:
+
+.. code-block:: text
+
+  [WaveSize(<numLanes>)]
+
+The allowed wave sizes that an HLSL shader may specify are the powers of 2
+between 4 and 128, inclusive.
+In other words, the set: [4, 8, 16, 32, 64, 128].
+
+The syntax for ``RangedWaveSize`` is:
+
+.. code-block:: text
+
+  [WaveSize(<minWaveSize>, <maxWaveSize>, [prefWaveSize])]
+
+Here ``minWaveSize`` is the minimum wave size supported by the shader (the
+beginning of the allowed range), ``maxWaveSize`` is the maximum wave size
+supported by the shader (the end of the allowed range), and the optional
+``prefWaveSize`` is the preferred wave size, i.e. the size expected to perform
+best for this shader.
+
+``WaveSize`` is available for HLSL shader model 6.6 and later.
+``RangedWaveSize`` is available for HLSL shader model 6.8 and later.
+
+The full documentation is available here: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
+and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
+ }];
+}
+
def NumThreadsDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticGroups.td b/clang/include/clang/Basic/DiagnosticGroups.td
index 19c3f1e043349..122b95e9f9a2e 100644
--- a/clang/include/clang/Basic/DiagnosticGroups.td
+++ b/clang/include/clang/Basic/DiagnosticGroups.td
@@ -1547,6 +1547,9 @@ def DXILValidation : DiagGroup<"dxil-validation">;
// Warning for HLSL API availability
def HLSLAvailability : DiagGroup<"hlsl-availability">;
+// HLSL attribute warnings (currently the WaveSize min == max check).
+def HLSLAttributeStatement : DiagGroup<"attribute-statement">;
+
// Warnings and notes related to const_var_decl_type attribute checks
def ReadOnlyPlacementChecks : DiagGroup<"read-only-types">;
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..9010812837d42 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12361,6 +12361,21 @@ def warn_hlsl_availability_unavailable :
def err_hlsl_export_not_on_function : Error<
"export declaration can only be used on functions">;
+def err_hlsl_attribute_in_wrong_shader_model: Error<
+  "attribute %0 requires shader model %1 or greater">;
+
+def err_hlsl_wavesize_size: Error<
+  "wavesize arguments must be between %0 and %1 and a power of 2">;
+def err_hlsl_wavesize_min_geq_max: Error<
+  "minimum wavesize value %0 must be less than maximum wavesize value %1">;
+def warn_hlsl_wavesize_min_eq_max: Warning<
+  "wave size range minimum and maximum are equal">,
+  InGroup<HLSLAttributeStatement>, DefaultError;
+def err_hlsl_wavesize_pref_size_out_of_range: Error<
+  "preferred wavesize value %0 must be between %1 and %2">;
+def err_hlsl_wavesize_insufficient_shader_model: Error<
+  "wavesize only takes multiple arguments in shader model 6.8 or higher">;
+
// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 2ddbee67c414b..a4d76818d29d2 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,6 +38,9 @@ class SemaHLSL : public SemaBase {
HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
const AttributeCommonInfo &AL, int X,
int Y, int Z);
+ HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
+ int Min, int Max, int Preferred,
+ int SpelledArgsCount);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
@@ -53,6 +56,7 @@ class SemaHLSL : public SemaBase {
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
+ void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 694a754646f27..c9a7c9e54d13c 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2862,6 +2862,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
NT->getZ());
+ else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
+ NewAttr =
+ S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
+ NT->getPreferred(), NT->getSpelledArgsCount());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 98e3df9083516..57ae83be12881 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6887,6 +6887,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLNumThreads:
S.HLSL().handleNumThreadsAttr(D, AL);
break;
+ case ParsedAttr::AT_HLSLWaveSize:
+ S.HLSL().handleWaveSizeAttr(D, AL);
+ break;
case ParsedAttr::AT_HLSLSV_GroupIndex:
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9940bc5b4a606..d386897d8251e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -20,7 +20,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/HLSL/HLSLWaveSize.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/TargetParser/Triple.h"
#include <iterator>
@@ -144,6 +146,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
}
+HLSLWaveSizeAttr *
+SemaHLSL::mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, int Min,
+                            int Max, int Preferred, int SpelledArgsCount) {
+  if (const auto *Prev = D->getAttr<HLSLWaveSizeAttr>()) {
+    bool Same = Prev->getMin() == Min && Prev->getMax() == Max &&
+                Prev->getPreferred() == Preferred &&
+                Prev->getSpelledArgsCount() == SpelledArgsCount;
+    if (Same)
+      return nullptr; // Identical duplicate: keep the existing attribute.
+    Diag(Prev->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+    Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  auto *NewAttr = ::new (getASTContext())
+      HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
+  NewAttr->setSpelledArgsCount(SpelledArgsCount);
+  return NewAttr;
+}
+
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType) {
@@ -215,7 +236,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
-
+ auto &TargetInfo = getASTContext().getTargetInfo();
+ VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
switch (ST) {
case llvm::Triple::Pixel:
case llvm::Triple::Vertex:
@@ -235,6 +257,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
llvm::Triple::Mesh});
FD->setInvalidDecl();
}
+ if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+ DiagnoseAttrStageMismatch(NT, ST,
+ {llvm::Triple::Compute,
+ llvm::Triple::Amplification,
+ llvm::Triple::Mesh});
+ FD->setInvalidDecl();
+ }
break;
case llvm::Triple::Compute:
@@ -245,6 +274,20 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
<< llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
+ if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+ if (Ver.getMajor() < 6u ||
+ (Ver.getMajor() == 6u && Ver.getMinor() < 6u)) {
+ Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
+ << "wavesize"
+ << "6.6";
+ FD->setInvalidDecl();
+ } else if (NT->getSpelledArgsCount() > 1 &&
+ (Ver.getMajor() == 6u && Ver.getMinor() < 8u)) {
+ Diag(NT->getLocation(),
+ diag::err_hlsl_wavesize_insufficient_shader_model);
+ FD->setInvalidDecl();
+ }
+ }
break;
default:
llvm_unreachable("Unhandled environment in triple");
@@ -348,6 +391,77 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
+void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
+  // Each spelled argument (Min, optional Max, optional Preferred) must be a
+  // power of 2 in [4, 128]; ranges need Min < Max and Min <= Pref <= Max.
+  unsigned SpelledArgsCount = AL.getNumArgs();
+  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
+    return;
+
+  uint32_t Min;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
+    return;
+
+  uint32_t Max = 0;
+  if (SpelledArgsCount > 1 &&
+      !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
+    return;
+
+  uint32_t Preferred = 0;
+  if (SpelledArgsCount > 2 &&
+      !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
+    return;
+  llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred);
+  llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate();
+  // Min == 0 is always an invalid minimum for a spelled attribute. validate()
+  // never reports Success in that case; it may report NoRangeOrMin or, e.g.
+  // for WaveSize(0, 8), MaxOrPreferredWhenUndefined (unreachable below).
+  if (!WaveSize.isDefined())
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMin;
+
+  // It is invalid to explicitly specify degenerate cases.
+  if (SpelledArgsCount > 1 && WaveSize.Max == 0)
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMax;
+  else if (SpelledArgsCount > 2 && WaveSize.Preferred == 0)
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred;
+
+  switch (ValidationResult) {
+  case llvm::hlsl::WaveSize::ValidationResult::Success:
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidMin:
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidMax:
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred:
+  case llvm::hlsl::WaveSize::ValidationResult::NoRangeOrMin:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_size)
+        << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxEqualsMin:
+    // The min == max diagnostic has no placeholders; stream no values.
+    Diag(AL.getLoc(), diag::warn_hlsl_wavesize_min_eq_max);
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxLessThanMin:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_min_geq_max)
+        << WaveSize.Min << WaveSize.Max;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::PreferredOutOfRange:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range)
+        << WaveSize.Preferred << WaveSize.Min << WaveSize.Max;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxOrPreferredWhenUndefined:
+  case llvm::hlsl::WaveSize::ValidationResult::PreferredWhenNoRange:
+    llvm_unreachable("Should have hit InvalidMin or InvalidMax instead.");
+    break;
+  }
+
+  if (ValidationResult != llvm::hlsl::WaveSize::ValidationResult::Success)
+    return;
+
+  HLSLWaveSizeAttr *NewAttr =
+      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
if (!T->hasUnsignedIntegerRepresentation())
return false;
@@ -356,7 +470,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
return true;
}
-void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
diff --git a/clang/test/AST/HLSL/WaveSize.hlsl b/clang/test/AST/HLSL/WaveSize.hlsl
new file mode 100644
index 0000000000000..fd6dc7c94d6d0
--- /dev/null
+++ b/clang/test/AST/HLSL/WaveSize.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK-LABEL: FunctionDecl {{.*}} w0 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 128 0 0
+ [numthreads(8,8,1)]
+ [WaveSize(128)]
+ void w0() {
+ }
+
+
+
+// CHECK-LABEL: FunctionDecl {{.*}} w1 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 64 0
+ [numthreads(8,8,1)]
+ [WaveSize(8, 64)]
+ void w1() {
+ }
+
+
+// CHECK-LABEL: FunctionDecl {{.*}} w2 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
+ [numthreads(8,8,1)]
+ [WaveSize(8, 128, 64)]
+ void w2() {
+ }
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
new file mode 100644
index 0000000000000..10c562839eef6
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
@@ -0,0 +1,101 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(1)]
+void e0() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 2)]
+void e1() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 8, 7)]
+void e2() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{minimum wavesize value 16 must be less than maximum wavesize value 8}}
+[WaveSize(16, 8)]
+void e3() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{preferred wavesize value 8 must be between 16 and 128}}
+[WaveSize(16, 128, 8)]
+void e4() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{preferred wavesize value 32 must be between 8 and 16}}
+[WaveSize(8, 16, 32)]
+void e5() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 0)]
+void e6() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 4, 0)]
+void e7() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wave size range minimum and maximum are equal}}
+[WaveSize(16, 16)]
+void e8() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(0)]
+void e9() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(-4)]
+void e10() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes no more than 3 arguments}}
+[WaveSize(16, 128, 64, 64)]
+void e11() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize()]
+void e12() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize]
+void e13() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
new file mode 100644
index 0000000000000..13e27a5c4b685
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-pixel -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-vertex -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-geometry -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-hull -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-domain -x hlsl %s -verify
+
+#if __SHADER_TARGET_STAGE == __SHADER_STAGE_PIXEL
+// expected-error@+10 {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_VERTEX
+// expected-error@+8 {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_GEOMETRY
+// expected-error@+6 {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_HULL
+// expected-error@+4 {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN
+// expected-error@+2 {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}}
+#endif
+[WaveSize(16)]
+void main() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
new file mode 100644
index 0000000000000..fb9978c6ce3ce
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
@@ -0,0 +1,24 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.5-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
+// expected-error@+4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
+// expected-error@+2 {{attribute wavesize requires shader model 6.6 or greater}}
+#endif
+[WaveSize(4, 16)]
+void e0() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
+// Single-argument WaveSize is valid in SM 6.6, so no diagnostic is expected.
+#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
+// expected-error@+2 {{attribute wavesize requires shader model 6.6 or greater}}
+#endif
+[WaveSize(16)]
+void e1() {
+}
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
new file mode 100644
index 0000000000000..ec8f22f58e1ad
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
@@ -0,0 +1,94 @@
+//===- HLSLWaveSize.h - HLSL WaveSize helper objects ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with HLSL WaveSize.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
+#define LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
+
+namespace llvm {
+namespace hlsl {
+
+// SM 6.6 allows WaveSize specification for only a single required size.
+// SM 6.8+ allows specification of WaveSize as a min, max and preferred value.
+struct WaveSize {
+  unsigned Min = 0;
+  unsigned Max = 0;
+  unsigned Preferred = 0;
+
+  WaveSize() = default;
+  WaveSize(unsigned Min, unsigned Max = 0, unsigned Preferred = 0)
+      : Min(Min), Max(Max), Preferred(Preferred) {}
+  WaveSize(const WaveSize &) = default;
+  WaveSize &operator=(const WaveSize &) = default;
+  bool operator==(const WaveSize &Other) const {
+    return Min == Other.Min && Max == Other.Max && Preferred == Other.Preferred;
+  }
+
+  // Valid non-zero values are powers of 2 between 4 and 128, inclusive.
+  static bool isValidValue(unsigned Value) {
+    return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
+  }
+  // Valid representations:
+  // (not to be confused with encodings in metadata, PSV0, or RDAT)
+  // 0, 0, 0: Not defined
+  // Min, 0, 0: single WaveSize (SM 6.6/6.7)
+  // (single WaveSize is represented in metadata with the single Min value)
+  // Min, Max (> Min), 0 or Preferred (>= Min and <= Max): Range (SM 6.8+)
+  // (WaveSizeRange representation in metadata is the same)
+  enum class ValidationResult {
+    Success,
+    InvalidMin,
+    InvalidMax,
+    InvalidPreferred,
+    MaxOrPreferredWhenUndefined,
+    PreferredWhenNoRange,
+    MaxEqualsMin,
+    MaxLessThanMin,
+    PreferredOutOfRange,
+    NoRangeOrMin,
+  };
+  ValidationResult validate() const {
+    if (Min == 0) { // Not defined
+      if (Max != 0 || Preferred != 0)
+        return ValidationResult::MaxOrPreferredWhenUndefined;
+      else
+        // all 3 parameters are 0
+        return ValidationResult::NoRangeOrMin;
+    } else if (!isValidValue(Min)) {
+      return ValidationResult::InvalidMin;
+    } else if (Max == 0) { // single WaveSize (SM 6.6/6.7)
+      if (Preferred != 0)
+        return ValidationResult::PreferredWhenNoRange;
+    } else if (!isValidValue(Max)) {
+      return ValidationResult::InvalidMax;
+    } else if (Min == Max) {
+      return ValidationResult::MaxEqualsMin;
+    } else if (Max < Min) {
+      return ValidationResult::MaxLessThanMin;
+    } else if (Preferred != 0) {
+      if (!isValidValue(Preferred))
+        return ValidationResult::InvalidPreferred;
+      if (Preferred < Min || Preferred > Max)
+        return ValidationResult::PreferredOutOfRange;
+    }
+    return ValidationResult::Success;
+  }
+  bool isValid() const { return validate() == ValidationResult::Success; }
+
+  bool isDefined() const { return Min != 0; }
+  bool isRange() const { return Max != 0; }
+  bool hasPreferred() const { return Preferred != 0; }
+};
+
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index a2222eec09ba8..8e86ba119345e 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -113,6 +113,9 @@ enum class SamplerFeedbackType : uint32_t {
MipRegionUsed = 1,
};
+constexpr unsigned MinWaveSize = 4;
+constexpr unsigned MaxWaveSize = 128;
+
} // namespace dxil
} // namespace llvm
>From a584cfd2674de3b91f691cc8446c27db659d1d6a Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 30 Jul 2024 17:01:45 -0400
Subject: [PATCH 2/2] Fix clang format.
---
clang/lib/Sema/SemaDecl.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index c9a7c9e54d13c..6ec53a791acab 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2863,9 +2863,9 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
NT->getZ());
else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
- NewAttr =
- S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
- NT->getPreferred(), NT->getSpelledArgsCount());
+ NewAttr = S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
+ NT->getPreferred(),
+ NT->getSpelledArgsCount());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
More information about the llvm-commits
mailing list