[clang] [llvm] [HLSL] AST support for WaveSize attribute. (PR #101240)

Xiang Li via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 30 11:08:42 PDT 2024


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/101240

>From 65b4ab94bc533c8dee9733761947671a4d326e90 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 30 Jul 2024 16:34:40 -0400
Subject: [PATCH 1/7] [HLSL] AST support for WaveSize attribute.

First step toward supporting the WaveSize attribute described in
 https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html

A new attribute, HLSLWaveSizeAttr, is added to the AST.

Implement both the single wave size and the wave size range together, rather than separately, which might require more work.

For #70118
---
 clang/include/clang/Basic/Attr.td             |  16 +++
 clang/include/clang/Basic/AttrDocs.td         |  37 ++++++
 clang/include/clang/Basic/DiagnosticGroups.td |   3 +
 .../clang/Basic/DiagnosticSemaKinds.td        |  15 +++
 clang/include/clang/Sema/SemaHLSL.h           |   4 +
 clang/lib/Sema/SemaDecl.cpp                   |   4 +
 clang/lib/Sema/SemaDeclAttr.cpp               |   3 +
 clang/lib/Sema/SemaHLSL.cpp                   | 116 +++++++++++++++++-
 clang/test/AST/HLSL/WaveSize.hlsl             |  25 ++++
 .../test/SemaHLSL/WaveSize-invalid-param.hlsl | 101 +++++++++++++++
 .../SemaHLSL/WaveSize-invalid-profiles.hlsl   |  20 +++
 clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl   |  24 ++++
 .../include/llvm/Frontend/HLSL/HLSLWaveSize.h |  94 ++++++++++++++
 llvm/include/llvm/Support/DXILABI.h           |   3 +
 14 files changed, 464 insertions(+), 1 deletion(-)
 create mode 100644 clang/test/AST/HLSL/WaveSize.hlsl
 create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
 create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
 create mode 100644 clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
 create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index a83e908899c83b..0d4256433365c4 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4651,6 +4651,22 @@ def HLSLParamModifier : TypeAttr {
   let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
 }
 
+def HLSLWaveSize: InheritableAttr {
+  let Spellings = [Microsoft<"WaveSize">];
+  let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>];
+  let Subjects = SubjectList<[HLSLEntry]>;
+  let LangOpts = [HLSL];
+  let AdditionalMembers = [{
+    private:
+      int SpelledArgsCount = 0;
+
+    public:
+      void setSpelledArgsCount(int C) { SpelledArgsCount = C; }
+      int getSpelledArgsCount() const { return SpelledArgsCount; }
+  }];
+  let Documentation = [WaveSizeDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index c2b9d7cb93c309..ef077db298831f 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7421,6 +7421,43 @@ flag.
   }];
 }
 
+def WaveSizeDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``WaveSize`` attribute specifies a wave size on a shader entry point in order
+to indicate either that a shader depends on or strongly prefers a specific wave
+size.
+There are two versions of the attribute: ``WaveSize`` and ``RangedWaveSize``.
+The syntax for ``WaveSize`` is:
+
+.. code-block:: text
+
+  [WaveSize(<numLanes>)]
+
+The allowed wave sizes that an HLSL shader may specify are the powers of 2
+between 4 and 128, inclusive.
+In other words, the set: [4, 8, 16, 32, 64, 128].
+
+The syntax for ``RangedWaveSize`` is:
+
+.. code-block:: text
+
+  [WaveSize(<minWaveSize>, <maxWaveSize>, [prefWaveSize])]
+
+Where minWaveSize is the minimum wave size supported by the shader representing
+the beginning of the allowed range, maxWaveSize is the maximum wave size
+supported by the shader representing the end of the allowed range, and
+prefWaveSize is the optional preferred wave size representing the size expected
+to be the most optimal for this shader.
+
+``WaveSize`` is available for HLSL shader model 6.6 and later.
+``RangedWaveSize`` is available for HLSL shader model 6.8 and later.
+
+The full documentation is available here: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
+and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
+  }];
+}
+
 def NumThreadsDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticGroups.td b/clang/include/clang/Basic/DiagnosticGroups.td
index 28d315f63e5c47..c4c29942ee1cbd 100644
--- a/clang/include/clang/Basic/DiagnosticGroups.td
+++ b/clang/include/clang/Basic/DiagnosticGroups.td
@@ -1550,6 +1550,9 @@ def HLSLAvailability : DiagGroup<"hlsl-availability">;
 // Warnings for legacy binding behavior
 def LegacyConstantRegisterBinding : DiagGroup<"legacy-constant-register-binding">;
 
+// Warning for HLSL Attributes on Statement.
+def HLSLAttributeStatement : DiagGroup<"attribute-statement">;
+
 // Warnings and notes related to const_var_decl_type attribute checks
 def ReadOnlyPlacementChecks : DiagGroup<"read-only-types">;
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index edf22b909c4d57..c0600aa6d99646 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12384,6 +12384,21 @@ def warn_hlsl_availability_unavailable :
 def err_hlsl_export_not_on_function : Error<
   "export declaration can only be used on functions">;
 
+def err_hlsl_attribute_in_wrong_shader_model: Error<
+  "attribute %0 requires shader model %1 or greater">;
+
+def err_hlsl_wavesize_size: Error<
+  "wavesize arguments must be between 4 and 128 and a power of 2">;
+def err_hlsl_wavesize_min_geq_max: Error<
+  "minimum wavesize value %0 must be less than maximum wavesize value %1">;
+def warn_hlsl_wavesize_min_eq_max:  Warning<
+  "wave size range minimum and maximum are equal">,
+  InGroup<HLSLAttributeStatement>, DefaultError;
+def err_hlsl_wavesize_pref_size_out_of_range: Error<
+  "preferred wavesize value %0 must be between %1 and %2">;
+def err_hlsl_wavesize_insufficient_shader_model: Error<
+  "wavesize only takes multiple arguments in shader model 6.8 or higher">;
+
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
   "a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 363a3ee6b4c1f2..210eb1167aa6ef 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -37,6 +37,9 @@ class SemaHLSL : public SemaBase {
   HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
                                           const AttributeCommonInfo &AL, int X,
                                           int Y, int Z);
+  HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
+                                      int Min, int Max, int Preferred,
+                                      int SpelledArgsCount);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -52,6 +55,7 @@ class SemaHLSL : public SemaBase {
   void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
 
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
+  void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
   void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 6327ae9b99aa4c..13482285736247 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2863,6 +2863,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
     NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
                                            NT->getZ());
+  else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
+    NewAttr =
+        S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
+                                   NT->getPreferred(), NT->getSpelledArgsCount());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 1e074298ac5289..33547c2e6e1452 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6886,6 +6886,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLNumThreads:
     S.HLSL().handleNumThreadsAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLWaveSize:
+    S.HLSL().handleWaveSizeAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 65972987458d70..d67b43bc9cb0e5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -21,7 +21,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/HLSL/HLSLWaveSize.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TargetParser/Triple.h"
 #include <iterator>
@@ -153,6 +155,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
       HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
 }
 
+HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
+                                              const AttributeCommonInfo &AL,
+                                              int Min, int Max, int Preferred,
+                                              int SpelledArgsCount) {
+  if (HLSLWaveSizeAttr *NT = D->getAttr<HLSLWaveSizeAttr>()) {
+    if (NT->getMin() != Min || NT->getMax() != Max ||
+        NT->getPreferred() != Preferred ||
+        NT->getSpelledArgsCount() != SpelledArgsCount) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  HLSLWaveSizeAttr *Result = ::new (getASTContext())
+      HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
+  Result->setSpelledArgsCount(SpelledArgsCount);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -224,7 +245,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
   assert(ShaderAttr && "Entry point has no shader attribute");
   llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
-
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
   switch (ST) {
   case llvm::Triple::Pixel:
   case llvm::Triple::Vertex:
@@ -244,6 +266,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
                                  llvm::Triple::Mesh});
       FD->setInvalidDecl();
     }
+    if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+      DiagnoseAttrStageMismatch(NT, ST,
+                                {llvm::Triple::Compute,
+                                 llvm::Triple::Amplification,
+                                 llvm::Triple::Mesh});
+      FD->setInvalidDecl();
+    }
     break;
 
   case llvm::Triple::Compute:
@@ -254,6 +283,20 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
           << llvm::Triple::getEnvironmentTypeName(ST);
       FD->setInvalidDecl();
     }
+    if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+      if (Ver.getMajor() < 6u ||
+          (Ver.getMajor() == 6u && Ver.getMinor() < 6u)) {
+        Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
+            << "wavesize"
+            << "6.6";
+        FD->setInvalidDecl();
+      } else if (NT->getSpelledArgsCount() > 1 &&
+                 (Ver.getMajor() == 6u && Ver.getMinor() < 8u)) {
+        Diag(NT->getLocation(),
+             diag::err_hlsl_wavesize_insufficient_shader_model);
+        FD->setInvalidDecl();
+      }
+    }
     break;
   default:
     llvm_unreachable("Unhandled environment in triple");
@@ -357,6 +400,77 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
+  // validate that the wavesize argument is a power of 2 between 4 and 128
+  // inclusive
+  unsigned SpelledArgsCount = AL.getNumArgs();
+  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
+    return;
+
+  uint32_t Min;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
+    return;
+
+  uint32_t Max = 0;
+  if (SpelledArgsCount > 1 &&
+      !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
+    return;
+
+  uint32_t Preferred = 0;
+  if (SpelledArgsCount > 2 &&
+      !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
+    return;
+  llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred);
+  llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate();
+  // WaveSize validation succeeds when not defined, but since we have an
+  // attribute, this means min was zero, which is invalid for min.
+  if (ValidationResult == llvm::hlsl::WaveSize::ValidationResult::Success &&
+      !WaveSize.isDefined())
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMin;
+
+  // It is invalid to explicitly specify degenerate cases.
+  if (SpelledArgsCount > 1 && WaveSize.Max == 0)
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMax;
+  else if (SpelledArgsCount > 2 && WaveSize.Preferred == 0)
+    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred;
+
+  switch (ValidationResult) {
+  case llvm::hlsl::WaveSize::ValidationResult::Success:
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidMin:
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidMax:
+  case llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred:
+  case llvm::hlsl::WaveSize::ValidationResult::NoRangeOrMin:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_size)
+        << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxEqualsMin:
+    Diag(AL.getLoc(), diag::warn_hlsl_wavesize_min_eq_max)
+        << WaveSize.Min << WaveSize.Max;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxLessThanMin:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_min_geq_max)
+        << WaveSize.Min << WaveSize.Max;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::PreferredOutOfRange:
+    Diag(AL.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range)
+        << WaveSize.Preferred << WaveSize.Min << WaveSize.Max;
+    break;
+  case llvm::hlsl::WaveSize::ValidationResult::MaxOrPreferredWhenUndefined:
+  case llvm::hlsl::WaveSize::ValidationResult::PreferredWhenNoRange:
+    llvm_unreachable("Should have hit InvalidMax or InvalidPreferred instead.");
+    break;
+  }
+
+  if (ValidationResult != llvm::hlsl::WaveSize::ValidationResult::Success)
+    return;
+
+  HLSLWaveSizeAttr *NewAttr =
+      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
diff --git a/clang/test/AST/HLSL/WaveSize.hlsl b/clang/test/AST/HLSL/WaveSize.hlsl
new file mode 100644
index 00000000000000..fd6dc7c94d6d00
--- /dev/null
+++ b/clang/test/AST/HLSL/WaveSize.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w0 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 128 0 0
+ [numthreads(8,8,1)]
+ [WaveSize(128)]
+ void w0() {
+ }
+
+
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w1 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 64 0
+ [numthreads(8,8,1)]
+ [WaveSize(8, 64)]
+ void w1() {
+ }
+
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w2 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
+ [numthreads(8,8,1)]
+ [WaveSize(8, 128, 64)]
+ void w2() {
+ }
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
new file mode 100644
index 00000000000000..10c562839eef62
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
@@ -0,0 +1,101 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(1)]
+void e0() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 2)]
+void e1() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 8, 7)]
+void e2() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{minimum wavesize value 16 must be less than maximum wavesize value 8}}
+[WaveSize(16, 8)]
+void e3() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{preferred wavesize value 8 must be between 16 and 128}}
+[WaveSize(16, 128, 8)]
+void e4() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{preferred wavesize value 32 must be between 8 and 16}}
+[WaveSize(8, 16, 32)]
+void e5() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 0)]
+void e6() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 4, 0)]
+void e7() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wave size range minimum and maximum are equal}}
+[WaveSize(16, 16)]
+void e8() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(0)]
+void e9() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(-4)]
+void e10() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{'WaveSize' attribute takes no more than 3 arguments}}
+[WaveSize(16, 128, 64, 64)]
+void e11() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize()]
+void e12() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize]
+void e13() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
new file mode 100644
index 00000000000000..13e27a5c4b685a
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-pixel -x hlsl %s  -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-vertex -x hlsl %s  -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-geometry -x hlsl %s  -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-hull -x hlsl %s  -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-domain -x hlsl %s  -verify
+
+#if __SHADER_TARGET_STAGE == __SHADER_STAGE_PIXEL
+// expected-error at +10 {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_VERTEX
+// expected-error at +8 {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_GEOMETRY
+// expected-error at +6 {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_HULL
+// expected-error at +4 {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN
+// expected-error at +2 {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}}
+#endif
+[WaveSize(16)]
+void main() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
new file mode 100644
index 00000000000000..fb9978c6ce3ceb
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
@@ -0,0 +1,24 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.5-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
+// expected-error at +4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
+// expected-error at +2 {{attribute wavesize requires shader model 6.6 or greater}}
+#endif
+[WaveSize(4, 16)]
+void e0() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
+// expected-error at +4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
+// expected-error at +2 {{attribute wavesize requires shader model 6.6 or greater}}
+#endif
+[WaveSize(4, 16)]
+void e1() {
+}
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
new file mode 100644
index 00000000000000..ec8f22f58e1ad7
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
@@ -0,0 +1,94 @@
+//===- HLSLWaveSize.h - HLSL WaveSize helper objects ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with HLSL WaveSize.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
+#define LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
+
+namespace llvm {
+namespace hlsl {
+
+// SM 6.6 allows WaveSize specification for only a single required size.
+// SM 6.8+ allows specification of WaveSize as a min, max and preferred value.
+struct WaveSize {
+  unsigned Min = 0;
+  unsigned Max = 0;
+  unsigned Preferred = 0;
+
+  WaveSize() = default;
+  WaveSize(unsigned Min, unsigned Max = 0, unsigned Preferred = 0)
+      : Min(Min), Max(Max), Preferred(Preferred) {}
+  WaveSize(const WaveSize &) = default;
+  WaveSize &operator=(const WaveSize &) = default;
+  bool operator==(const WaveSize &Other) const {
+    return Min == Other.Min && Max == Other.Max && Preferred == Other.Preferred;
+  };
+
+  // Valid non-zero values are powers of 2 between 4 and 128, inclusive.
+  static bool isValidValue(unsigned Value) {
+    return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
+  }
+  // Valid representations:
+  //    (not to be confused with encodings in metadata, PSV0, or RDAT)
+  //  0, 0, 0: Not defined
+  //  Min, 0, 0: single WaveSize (SM 6.6/6.7)
+  //    (single WaveSize is represented in metadata with the single Min value)
+  //  Min, Max (> Min), 0 or Preferred (>= Min and <= Max): Range (SM 6.8+)
+  //    (WaveSizeRange representation in metadata is the same)
+  enum class ValidationResult {
+    Success,
+    InvalidMin,
+    InvalidMax,
+    InvalidPreferred,
+    MaxOrPreferredWhenUndefined,
+    PreferredWhenNoRange,
+    MaxEqualsMin,
+    MaxLessThanMin,
+    PreferredOutOfRange,
+    NoRangeOrMin,
+  };
+  ValidationResult validate() const {
+    if (Min == 0) { // Not defined
+      if (Max != 0 || Preferred != 0)
+        return ValidationResult::MaxOrPreferredWhenUndefined;
+      else
+        // all 3 parameters are 0
+        return ValidationResult::NoRangeOrMin;
+    } else if (!isValidValue(Min)) {
+      return ValidationResult::InvalidMin;
+    } else if (Max == 0) { // single WaveSize (SM 6.6/6.7)
+      if (Preferred != 0)
+        return ValidationResult::PreferredWhenNoRange;
+    } else if (!isValidValue(Max)) {
+      return ValidationResult::InvalidMax;
+    } else if (Min == Max) {
+      return ValidationResult::MaxEqualsMin;
+    } else if (Max < Min) {
+      return ValidationResult::MaxLessThanMin;
+    } else if (Preferred != 0) {
+      if (!isValidValue(Preferred))
+        return ValidationResult::InvalidPreferred;
+      if (Preferred < Min || Preferred > Max)
+        return ValidationResult::PreferredOutOfRange;
+    }
+    return ValidationResult::Success;
+  }
+  bool isValid() const { return validate() == ValidationResult::Success; }
+
+  bool isDefined() const { return Min != 0; }
+  bool isRange() const { return Max != 0; }
+  bool hasPreferred() const { return Preferred != 0; }
+};
+
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index cf2c42c689889d..b479f7c73eba36 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -96,6 +96,9 @@ enum class SamplerFeedbackType : uint32_t {
   MipRegionUsed = 1,
 };
 
+const unsigned MinWaveSize = 4;
+const unsigned MaxWaveSize = 128;
+
 } // namespace dxil
 } // namespace llvm
 

>From 0e12b2a87822e118105b81a9a7182c585e3e7d96 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 30 Jul 2024 17:01:45 -0400
Subject: [PATCH 2/7] Fix clang format.

---
 clang/lib/Sema/SemaDecl.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 13482285736247..49afa2e4b827ba 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2864,9 +2864,9 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
                                            NT->getZ());
   else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
-    NewAttr =
-        S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
-                                   NT->getPreferred(), NT->getSpelledArgsCount());
+    NewAttr = S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
+                                         NT->getPreferred(),
+                                         NT->getSpelledArgsCount());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))

>From c33aae582b507bbb3af25beef3e8dabb97ec79ed Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 30 Jul 2024 19:04:47 -0400
Subject: [PATCH 3/7] Compare VersionTuple directly.

---
 clang/lib/Sema/SemaHLSL.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d67b43bc9cb0e5..2c40ef7fefe86c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -284,14 +284,12 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
       FD->setInvalidDecl();
     }
     if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
-      if (Ver.getMajor() < 6u ||
-          (Ver.getMajor() == 6u && Ver.getMinor() < 6u)) {
+      if (Ver < VersionTuple(6, 6)) {
         Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
             << "wavesize"
             << "6.6";
         FD->setInvalidDecl();
-      } else if (NT->getSpelledArgsCount() > 1 &&
-                 (Ver.getMajor() == 6u && Ver.getMinor() < 8u)) {
+      } else if (NT->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
         Diag(NT->getLocation(),
              diag::err_hlsl_wavesize_insufficient_shader_model);
         FD->setInvalidDecl();

>From aad8ffcb7a18c43bc89a06d3349ecd236c2cdd6d Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Wed, 31 Jul 2024 10:39:08 -0400
Subject: [PATCH 4/7] Update diag message.

---
 .../clang/Basic/DiagnosticSemaKinds.td        |  15 +--
 clang/lib/Sema/SemaHLSL.cpp                   | 103 +++++++++---------
 .../test/SemaHLSL/WaveSize-invalid-param.hlsl |  22 ++--
 clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl   |  10 +-
 .../include/llvm/Frontend/HLSL/HLSLWaveSize.h |  94 ----------------
 5 files changed, 71 insertions(+), 173 deletions(-)
 delete mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c0600aa6d99646..2e759b5b67b68d 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12387,17 +12387,12 @@ def err_hlsl_export_not_on_function : Error<
 def err_hlsl_attribute_in_wrong_shader_model: Error<
   "attribute %0 requires shader model %1 or greater">;
 
-def err_hlsl_wavesize_size: Error<
-  "wavesize arguments must be between 4 and 128 and a power of 2">;
-def err_hlsl_wavesize_min_geq_max: Error<
-  "minimum wavesize value %0 must be less than maximum wavesize value %1">;
-def warn_hlsl_wavesize_min_eq_max:  Warning<
-  "wave size range minimum and maximum are equal">,
+def warn_attr_min_eq_max:  Warning<
+  "%0 attribute minimum and maximum arguments are equal">,
   InGroup<HLSLAttributeStatement>, DefaultError;
-def err_hlsl_wavesize_pref_size_out_of_range: Error<
-  "preferred wavesize value %0 must be between %1 and %2">;
-def err_hlsl_wavesize_insufficient_shader_model: Error<
-  "wavesize only takes multiple arguments in shader model 6.8 or higher">;
+
+def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
+  "attribute %0 with %1 arguments requires shader model %2 or greater">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 2c40ef7fefe86c..b2343c6ea6b451 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -21,7 +21,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Frontend/HLSL/HLSLWaveSize.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -266,8 +265,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
                                  llvm::Triple::Mesh});
       FD->setInvalidDecl();
     }
-    if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
-      DiagnoseAttrStageMismatch(NT, ST,
+    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
+      DiagnoseAttrStageMismatch(WS, ST,
                                 {llvm::Triple::Compute,
                                  llvm::Triple::Amplification,
                                  llvm::Triple::Mesh});
@@ -283,15 +282,16 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
           << llvm::Triple::getEnvironmentTypeName(ST);
       FD->setInvalidDecl();
     }
-    if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
       if (Ver < VersionTuple(6, 6)) {
-        Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
-            << "wavesize"
-            << "6.6";
+        Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
+            << WS << "6.6";
         FD->setInvalidDecl();
-      } else if (NT->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
-        Diag(NT->getLocation(),
-             diag::err_hlsl_wavesize_insufficient_shader_model);
+      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
+        Diag(
+            WS->getLocation(),
+            diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
+            << WS << WS->getSpelledArgsCount() << "6.8";
         FD->setInvalidDecl();
       }
     }
@@ -398,6 +398,10 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+static bool isValidWaveSizeValue(unsigned Value) {
+  return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
+}
+
 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
   // validate that the wavesize argument is a power of 2 between 4 and 128
   // inclusive
@@ -418,50 +422,43 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
   if (SpelledArgsCount > 2 &&
       !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
     return;
-  llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred);
-  llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate();
-  // WaveSize validation succeeds when not defined, but since we have an
-  // attribute, this means min was zero, which is invalid for min.
-  if (ValidationResult == llvm::hlsl::WaveSize::ValidationResult::Success &&
-      !WaveSize.isDefined())
-    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMin;
-
-  // It is invalid to explicitly specify degenerate cases.
-  if (SpelledArgsCount > 1 && WaveSize.Max == 0)
-    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMax;
-  else if (SpelledArgsCount > 2 && WaveSize.Preferred == 0)
-    ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred;
-
-  switch (ValidationResult) {
-  case llvm::hlsl::WaveSize::ValidationResult::Success:
-    break;
-  case llvm::hlsl::WaveSize::ValidationResult::InvalidMin:
-  case llvm::hlsl::WaveSize::ValidationResult::InvalidMax:
-  case llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred:
-  case llvm::hlsl::WaveSize::ValidationResult::NoRangeOrMin:
-    Diag(AL.getLoc(), diag::err_hlsl_wavesize_size)
-        << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize;
-    break;
-  case llvm::hlsl::WaveSize::ValidationResult::MaxEqualsMin:
-    Diag(AL.getLoc(), diag::warn_hlsl_wavesize_min_eq_max)
-        << WaveSize.Min << WaveSize.Max;
-    break;
-  case llvm::hlsl::WaveSize::ValidationResult::MaxLessThanMin:
-    Diag(AL.getLoc(), diag::err_hlsl_wavesize_min_geq_max)
-        << WaveSize.Min << WaveSize.Max;
-    break;
-  case llvm::hlsl::WaveSize::ValidationResult::PreferredOutOfRange:
-    Diag(AL.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range)
-        << WaveSize.Preferred << WaveSize.Min << WaveSize.Max;
-    break;
-  case llvm::hlsl::WaveSize::ValidationResult::MaxOrPreferredWhenUndefined:
-  case llvm::hlsl::WaveSize::ValidationResult::PreferredWhenNoRange:
-    llvm_unreachable("Should have hit InvalidMax or InvalidPreferred instead.");
-    break;
-  }
 
-  if (ValidationResult != llvm::hlsl::WaveSize::ValidationResult::Success)
-    return;
+  if (SpelledArgsCount > 2) {
+    if (!isValidWaveSizeValue(Preferred)) {
+      Diag(AL.getArgAsExpr(2)->getExprLoc(),
+           diag::err_attribute_power_of_two_in_range)
+          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
+          << Preferred;
+      return;
+    }
+    // Preferred not in range.
+    if (Preferred < Min || Preferred > Max) {
+      Diag(AL.getArgAsExpr(2)->getExprLoc(),
+           diag::err_attribute_power_of_two_in_range)
+          << AL << Min << Max << Preferred;
+      return;
+    }
+  } else if (SpelledArgsCount > 1) {
+    if (!isValidWaveSizeValue(Max)) {
+      Diag(AL.getArgAsExpr(1)->getExprLoc(),
+           diag::err_attribute_power_of_two_in_range)
+          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
+      return;
+    }
+    if (Max < Min) {
+      Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
+      return;
+    } else if (Max == Min) {
+      Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
+    }
+  } else {
+    if (!isValidWaveSizeValue(Min)) {
+      Diag(AL.getArgAsExpr(0)->getExprLoc(),
+           diag::err_attribute_power_of_two_in_range)
+          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
+      return;
+    }
+  }
 
   HLSLWaveSizeAttr *NewAttr =
       mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
index 10c562839eef62..4a15da6a22f6b9 100644
--- a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
+++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
@@ -2,49 +2,49 @@
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 1}}
 [WaveSize(1)]
 void e0() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 2}}
 [WaveSize(4, 2)]
 void e1() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 7}}
 [WaveSize(4, 8, 7)]
 void e2() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{minimum wavesize value 16 must be less than maximum wavesize value 8}}
+// expected-error at +1 {{'WaveSize' attribute argument is invalid: min must not be greater than max}}
 [WaveSize(16, 8)]
 void e3() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{preferred wavesize value 8 must be between 16 and 128}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 16 and 128 inclusive; provided argument was 8}}
 [WaveSize(16, 128, 8)]
 void e4() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{preferred wavesize value 32 must be between 8 and 16}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 8 and 16 inclusive; provided argument was 32}}
 [WaveSize(8, 16, 32)]
 void e5() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 0}}
 [WaveSize(4, 0)]
 void e6() {
 }
@@ -52,7 +52,7 @@ void e6() {
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 0}}
 [WaveSize(4, 4, 0)]
 void e7() {
 }
@@ -60,21 +60,21 @@ void e7() {
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wave size range minimum and maximum are equal}}
+// expected-error at +1 {{'WaveSize' attribute minimum and maximum arguments are equal}}
 [WaveSize(16, 16)]
 void e8() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 0}}
 [WaveSize(0)]
 void e9() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
-// expected-error at +1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+// expected-error at +1 {{'WaveSize' attribute requires an integer argument which is a constant power of two between 4 and 128 inclusive; provided argument was 4294967292}}
 [WaveSize(-4)]
 void e10() {
 }
diff --git a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
index fb9978c6ce3ceb..c6718cfec8ef4c 100644
--- a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
+++ b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
@@ -4,20 +4,20 @@
 [shader("compute")]
 [numthreads(1,1,1)]
 #if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
-// expected-error at +4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+// expected-error at +4 {{attribute 'WaveSize' with 3 arguments requires shader model 6.8 or greater}}
 #elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
-// expected-error at +2 {{attribute wavesize requires shader model 6.6 or greater}}
+// expected-error at +2 {{attribute 'WaveSize' requires shader model 6.6 or greater}}
 #endif
-[WaveSize(4, 16)]
+[WaveSize(4, 16, 8)]
 void e0() {
 }
 
 [shader("compute")]
 [numthreads(1,1,1)]
 #if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
-// expected-error at +4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+// expected-error at +4 {{attribute 'WaveSize' with 2 arguments requires shader model 6.8 or greater}}
 #elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
-// expected-error at +2 {{attribute wavesize requires shader model 6.6 or greater}}
+// expected-error at +2 {{attribute 'WaveSize' requires shader model 6.6 or greater}}
 #endif
 [WaveSize(4, 16)]
 void e1() {
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
deleted file mode 100644
index ec8f22f58e1ad7..00000000000000
--- a/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h
+++ /dev/null
@@ -1,94 +0,0 @@
-//===- HLSLResource.h - HLSL Resource helper objects ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file This file contains helper objects for working with HLSL WaveSize.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
-#define LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H
-
-namespace llvm {
-namespace hlsl {
-
-// SM 6.6 allows WaveSize specification for only a single required size.
-// SM 6.8+ allows specification of WaveSize as a min, max and preferred value.
-struct WaveSize {
-  unsigned Min = 0;
-  unsigned Max = 0;
-  unsigned Preferred = 0;
-
-  WaveSize() = default;
-  WaveSize(unsigned Min, unsigned Max = 0, unsigned Preferred = 0)
-      : Min(Min), Max(Max), Preferred(Preferred) {}
-  WaveSize(const WaveSize &) = default;
-  WaveSize &operator=(const WaveSize &) = default;
-  bool operator==(const WaveSize &Other) const {
-    return Min == Other.Min && Max == Other.Max && Preferred == Other.Preferred;
-  };
-
-  // Valid non-zero values are powers of 2 between 4 and 128, inclusive.
-  static bool isValidValue(unsigned Value) {
-    return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
-  }
-  // Valid representations:
-  //    (not to be confused with encodings in metadata, PSV0, or RDAT)
-  //  0, 0, 0: Not defined
-  //  Min, 0, 0: single WaveSize (SM 6.6/6.7)
-  //    (single WaveSize is represented in metadata with the single Min value)
-  //  Min, Max (> Min), 0 or Preferred (>= Min and <= Max): Range (SM 6.8+)
-  //    (WaveSizeRange represenation in metadata is the same)
-  enum class ValidationResult {
-    Success,
-    InvalidMin,
-    InvalidMax,
-    InvalidPreferred,
-    MaxOrPreferredWhenUndefined,
-    PreferredWhenNoRange,
-    MaxEqualsMin,
-    MaxLessThanMin,
-    PreferredOutOfRange,
-    NoRangeOrMin,
-  };
-  ValidationResult validate() const {
-    if (Min == 0) { // Not defined
-      if (Max != 0 || Preferred != 0)
-        return ValidationResult::MaxOrPreferredWhenUndefined;
-      else
-        // all 3 parameters are 0
-        return ValidationResult::NoRangeOrMin;
-    } else if (!isValidValue(Min)) {
-      return ValidationResult::InvalidMin;
-    } else if (Max == 0) { // single WaveSize (SM 6.6/6.7)
-      if (Preferred != 0)
-        return ValidationResult::PreferredWhenNoRange;
-    } else if (!isValidValue(Max)) {
-      return ValidationResult::InvalidMax;
-    } else if (Min == Max) {
-      return ValidationResult::MaxEqualsMin;
-    } else if (Max < Min) {
-      return ValidationResult::MaxLessThanMin;
-    } else if (Preferred != 0) {
-      if (!isValidValue(Preferred))
-        return ValidationResult::InvalidPreferred;
-      if (Preferred < Min || Preferred > Max)
-        return ValidationResult::PreferredOutOfRange;
-    }
-    return ValidationResult::Success;
-  }
-  bool isValid() const { return validate() == ValidationResult::Success; }
-
-  bool isDefined() const { return Min != 0; }
-  bool isRange() const { return Max != 0; }
-  bool hasPreferred() const { return Preferred != 0; }
-};
-
-} // namespace hlsl
-} // namespace llvm
-
-#endif // LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H

>From 2863b9845a52b5b2722d538b889fdd6b1fa940e6 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Wed, 7 Aug 2024 15:49:49 -0400
Subject: [PATCH 5/7] Add tests for duplicate WaveSize attributes; rename NT to WS.

---
 clang/lib/Sema/SemaDecl.cpp                     |  8 ++++----
 clang/test/AST/HLSL/WaveSize.hlsl               |  3 +++
 clang/test/SemaHLSL/WaveSize-invalid-param.hlsl | 10 ++++++++++
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 49afa2e4b827ba..69b793b987e42c 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2863,10 +2863,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
     NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
                                            NT->getZ());
-  else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
-    NewAttr = S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
-                                         NT->getPreferred(),
-                                         NT->getSpelledArgsCount());
+  else if (const auto *WS = dyn_cast<HLSLWaveSizeAttr>(Attr))
+    NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
+                                         WS->getPreferred(),
+                                         WS->getSpelledArgsCount());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
diff --git a/clang/test/AST/HLSL/WaveSize.hlsl b/clang/test/AST/HLSL/WaveSize.hlsl
index fd6dc7c94d6d00..44a7bfab1788b5 100644
--- a/clang/test/AST/HLSL/WaveSize.hlsl
+++ b/clang/test/AST/HLSL/WaveSize.hlsl
@@ -19,7 +19,10 @@
 
 // CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w2 'void ()'
 // CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
+// Duplicate WaveSize attribute will be ignored.
+// CHECK-NOT:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
  [numthreads(8,8,1)]
  [WaveSize(8, 128, 64)]
+ [WaveSize(8, 128, 64)]
  void w2() {
  }
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
index 4a15da6a22f6b9..e10be5a94df517 100644
--- a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
+++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
@@ -99,3 +99,13 @@ void e12() {
 [WaveSize]
 void e13() {
 }
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error at +1 {{'WaveSize' attribute parameters do not match the previous declaration}}
+[WaveSize(8)]
+// expected-note at +1 {{conflicting attribute is here}}
+[WaveSize(16)]
+void e14() {
+}

>From f249ddb045c081bb43bac873c572fa060584404c Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 30 Aug 2024 14:04:19 -0400
Subject: [PATCH 6/7] Address review: use llvm::isPowerOf2_32 and #WaveSize test markers.

---
 clang/lib/Sema/SemaHLSL.cpp                        |  2 +-
 clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b2343c6ea6b451..3259fcf3b9407a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -399,7 +399,7 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
 }
 
 static bool isValidWaveSizeValue(unsigned Value) {
-  return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
+  return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
 }
 
 void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
index 13e27a5c4b685a..f14c0141816fd9 100644
--- a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
+++ b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
@@ -5,16 +5,16 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-domain -x hlsl %s  -verify
 
 #if __SHADER_TARGET_STAGE == __SHADER_STAGE_PIXEL
-// expected-error at +10 {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}}
+// expected-error@#WaveSize {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}}
 #elif __SHADER_TARGET_STAGE == __SHADER_STAGE_VERTEX
-// expected-error at +8 {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}}
+// expected-error@#WaveSize {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}}
 #elif __SHADER_TARGET_STAGE == __SHADER_STAGE_GEOMETRY
-// expected-error at +6 {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}}
+// expected-error@#WaveSize {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}}
 #elif __SHADER_TARGET_STAGE == __SHADER_STAGE_HULL
-// expected-error at +4 {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}}
+// expected-error@#WaveSize {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}}
 #elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN
-// expected-error at +2 {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}}
+// expected-error@#WaveSize {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}}
 #endif
-[WaveSize(16)]
+[WaveSize(16)] // #WaveSize
 void main() {
 }

>From cef87f9da53d551823ee2e32e6ec2e9233dd6ac2 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 30 Aug 2024 14:08:21 -0400
Subject: [PATCH 7/7] Rename NT to WS in SemaHLSL::mergeWaveSizeAttr.

---
 clang/lib/Sema/SemaHLSL.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 3259fcf3b9407a..1373c2ea034bf5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -158,11 +158,11 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
                                               const AttributeCommonInfo &AL,
                                               int Min, int Max, int Preferred,
                                               int SpelledArgsCount) {
-  if (HLSLWaveSizeAttr *NT = D->getAttr<HLSLWaveSizeAttr>()) {
-    if (NT->getMin() != Min || NT->getMax() != Max ||
-        NT->getPreferred() != Preferred ||
-        NT->getSpelledArgsCount() != SpelledArgsCount) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+  if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
+    if (WS->getMin() != Min || WS->getMax() != Max ||
+        WS->getPreferred() != Preferred ||
+        WS->getSpelledArgsCount() != SpelledArgsCount) {
+      Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
       Diag(AL.getLoc(), diag::note_conflicting_attribute);
     }
     return nullptr;



More information about the cfe-commits mailing list