[clang] [HLSL] Support packoffset attribute in AST (PR #89836)

Xiang Li via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 25 10:58:37 PDT 2024


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/89836

>From 4d8c72688656fe3b2ce8817087d8cf7352b5876b Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 23 Apr 2024 17:49:02 -0400
Subject: [PATCH 1/3] [HLSL] Support packoffset attribute in AST

Add HLSLPackOffsetAttr to save packoffset in AST.

Since we have to parse the attribute manually in ParseHLSLAnnotations,
we can create the ParsedAttribute with an integer offset parameter instead of a string.
This avoids having to parse the string again later, which would be required if the offset were saved as a string in HLSLPackOffsetAttr.
---
 clang/include/clang/Basic/Attr.td             |  7 ++
 clang/include/clang/Basic/AttrDocs.td         | 18 +++++
 .../clang/Basic/DiagnosticParseKinds.td       |  2 +
 .../clang/Basic/DiagnosticSemaKinds.td        |  3 +
 clang/lib/Parse/ParseHLSL.cpp                 | 80 +++++++++++++++++++
 clang/lib/Sema/SemaDeclAttr.cpp               | 44 ++++++++++
 clang/lib/Sema/SemaHLSL.cpp                   | 48 +++++++++++
 clang/test/AST/HLSL/packoffset.hlsl           | 16 ++++
 clang/test/SemaHLSL/packoffset-invalid.hlsl   | 55 +++++++++++++
 9 files changed, 273 insertions(+)
 create mode 100644 clang/test/AST/HLSL/packoffset.hlsl
 create mode 100644 clang/test/SemaHLSL/packoffset-invalid.hlsl

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 4408d517e70e58..d3d006ed9633f4 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4372,6 +4372,13 @@ def HLSLResourceBinding: InheritableAttr {
   let Documentation = [HLSLResourceBindingDocs];
 }
 
+def HLSLPackOffset: HLSLAnnotationAttr {
+  let Spellings = [HLSLAnnotation<"packoffset">];
+  let LangOpts = [HLSL];
+  let Args = [IntArgument<"Offset">];
+  let Documentation = [HLSLPackOffsetDocs];
+}
+
 def HLSLSV_DispatchThreadID: HLSLAnnotationAttr {
   let Spellings = [HLSLAnnotation<"SV_DispatchThreadID">];
   let Subjects = SubjectList<[ParmVar, Field]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index a0bbe5861c5722..6bc7813bd43cb4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7398,6 +7398,24 @@ The full documentation is available here: https://docs.microsoft.com/en-us/windo
   }];
 }
 
+def HLSLPackOffsetDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The packoffset attribute is used to change the layout of a cbuffer.
+Attribute spelling in HLSL is: ``packoffset(c[Subcomponent[.component]])``.
+A subcomponent is a register number, which is an integer. A component is in the form of [.xyzw].
+
+Here're packoffset examples with and without component:
+.. code-block:: c++
+  cbuffer A {
+    float3 a : packoffset(c0.y);
+    float4 b : packoffset(c4);
+  }
+
+The full documentation is available here: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-packoffset
+  }];
+}
+
 def HLSLSV_DispatchThreadIDDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 38174cf3549f14..81433ee79d48b2 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1745,5 +1745,7 @@ def err_hlsl_separate_attr_arg_and_number : Error<"wrong argument format for hls
 def ext_hlsl_access_specifiers : ExtWarn<
   "access specifiers are a clang HLSL extension">,
   InGroup<HLSLExtension>;
+def err_hlsl_unsupported_component : Error<"invalid component '%0' used; expected 'x', 'y', 'z', or 'w'">;
+def err_hlsl_packoffset_invalid_reg : Error<"invalid resource class specifier '%0' for packoffset, expected 'c'">;
 
 } // end of Parser diagnostics
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 63e951daec7477..bde9617c9820a8 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12167,6 +12167,9 @@ def err_hlsl_init_priority_unsupported : Error<
 def err_hlsl_unsupported_register_type : Error<"invalid resource class specifier '%0' used; expected 'b', 's', 't', or 'u'">;
 def err_hlsl_unsupported_register_number : Error<"register number should be an integer">;
 def err_hlsl_expected_space : Error<"invalid space specifier '%0' used; expected 'space' followed by an integer, like space1">;
+def err_hlsl_packoffset_mix : Error<"cannot mix packoffset elements with nonpackoffset elements in a cbuffer">;
+def err_hlsl_packoffset_overlap : Error<"packoffset overlap between %0, %1">;
+def err_hlsl_packoffset_cross_reg_boundary : Error<"packoffset cannot cross register boundary">;
 def err_hlsl_pointers_unsupported : Error<
   "%select{pointers|references}0 are unsupported in HLSL">;
 
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index f4cbece31f1810..eac4876ccab49a 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -183,6 +183,86 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
       return;
     }
   } break;
+  case ParsedAttr::AT_HLSLPackOffset: {
+    // Parse 'packoffset( c[Subcomponent][.component] )'.
+    // Check '('.
+    if (ExpectAndConsume(tok::l_paren, diag::err_expected_lparen_after)) {
+      SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+      return;
+    }
+    // Check c[Subcomponent] as an identifier.
+    if (!Tok.is(tok::identifier)) {
+      Diag(Tok.getLocation(), diag::err_expected) << tok::identifier;
+      SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+      return;
+    }
+    StringRef OffsetStr = Tok.getIdentifierInfo()->getName();
+    SourceLocation OffsetLoc = Tok.getLocation();
+    if (OffsetStr[0] != 'c') {
+      Diag(Tok.getLocation(), diag::err_hlsl_packoffset_invalid_reg)
+          << OffsetStr;
+      SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+      return;
+    }
+    OffsetStr = OffsetStr.substr(1);
+    unsigned SubComponent = 0;
+    if (!OffsetStr.empty()) {
+      // Make sure SubComponent is a number.
+      if (OffsetStr.getAsInteger(10, SubComponent)) {
+        Diag(OffsetLoc.getLocWithOffset(1),
+             diag::err_hlsl_unsupported_register_number);
+        return;
+      }
+    }
+    unsigned Component = 0;
+    ConsumeToken(); // consume identifier.
+    if (Tok.is(tok::period)) {
+      ConsumeToken(); // consume period.
+      if (!Tok.is(tok::identifier)) {
+        Diag(Tok.getLocation(), diag::err_expected) << tok::identifier;
+        SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+        return;
+      }
+      StringRef ComponentStr = Tok.getIdentifierInfo()->getName();
+      SourceLocation SpaceLoc = Tok.getLocation();
+      ConsumeToken(); // consume identifier.
+      // Make sure Component is a single character.
+      if (ComponentStr.size() != 1) {
+        Diag(SpaceLoc, diag::err_hlsl_unsupported_component) << ComponentStr;
+        SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+        return;
+      } else {
+        switch (ComponentStr[0]) {
+        case 'x':
+          Component = 0;
+          break;
+        case 'y':
+          Component = 1;
+          break;
+        case 'z':
+          Component = 2;
+          break;
+        case 'w':
+          Component = 3;
+          break;
+        default:
+          Diag(SpaceLoc, diag::err_hlsl_unsupported_component) << ComponentStr;
+          SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+          return;
+        }
+      }
+    }
+    unsigned Offset = SubComponent * 4 + Component;
+    ASTContext &Ctx = Actions.getASTContext();
+    ArgExprs.push_back(IntegerLiteral::Create(
+        Ctx, llvm::APInt(Ctx.getTypeSize(Ctx.getSizeType()), Offset),
+        Ctx.getSizeType(),
+        SourceLocation())); // Dummy location for integer literal.
+    if (ExpectAndConsume(tok::r_paren, diag::err_expected)) {
+      SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+      return;
+    }
+  } break;
   case ParsedAttr::UnknownAttribute:
     Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
     return;
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 363ae93cb62df1..373f2e8f50cdb5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7314,6 +7314,47 @@ static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
   D->addAttr(::new (S.Context) HLSLSV_DispatchThreadIDAttr(S.Context, AL));
 }
 
+static void handleHLSLPackOffsetAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
+  if (!isa<VarDecl>(D)) {
+    S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
+        << AL << "cbuffer constant";
+    return;
+  }
+  auto *BufDecl = dyn_cast<HLSLBufferDecl>(D->getDeclContext());
+  if (!BufDecl) {
+    S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
+        << AL << "cbuffer constant";
+    return;
+  }
+
+  uint32_t Offset;
+  if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(0), Offset))
+    return;
+
+  QualType T = cast<VarDecl>(D)->getType().getCanonicalType();
+  // Check if T is an array or struct type.
+  // TODO: mark matrix type as aggregate type.
+  bool IsAggregateTy = (T->isArrayType() || T->isStructureType());
+
+  unsigned ComponentNum = Offset & 0x3;
+  // Check ComponentNum is valid for T.
+  if (IsAggregateTy) {
+    if (ComponentNum != 0) {
+      S.Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
+      return;
+    }
+  } else {
+    unsigned size = S.getASTContext().getTypeSize(T);
+    // Make sure ComponentNum + sizeof(T) <= 4.
+    if ((ComponentNum * 32 + size) > 128) {
+      S.Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);
+      return;
+    }
+  }
+
+  D->addAttr(::new (S.Context) HLSLPackOffsetAttr(S.Context, AL, Offset));
+}
+
 static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   StringRef Str;
   SourceLocation ArgLoc;
@@ -9735,6 +9776,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
     handleHLSLSV_DispatchThreadIDAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLPackOffset:
+    handleHLSLPackOffsetAttr(S, D, AL);
+    break;
   case ParsedAttr::AT_HLSLShader:
     handleHLSLShaderAttr(S, D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bb9e37f18d370c..fa62cab54e6902 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -42,6 +42,54 @@ Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
 void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
   BufDecl->setRBraceLoc(RBrace);
+
+  // Validate packoffset.
+  llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;
+  bool HasPackOffset = false;
+  bool HasNonPackOffset = false;
+  for (auto *Field : BufDecl->decls()) {
+    VarDecl *Var = dyn_cast<VarDecl>(Field);
+    if (!Var)
+      continue;
+    if (Field->hasAttr<HLSLPackOffsetAttr>()) {
+      PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
+      HasPackOffset = true;
+    } else {
+      HasNonPackOffset = true;
+    }
+  }
+
+  if (HasPackOffset && HasNonPackOffset) {
+    Diag(BufDecl->getLocation(), diag::err_hlsl_packoffset_mix);
+  } else if (HasPackOffset) {
+    ASTContext &Context = getASTContext();
+    // Make sure no overlap in packoffset.
+    llvm::SmallDenseMap<VarDecl *, std::pair<unsigned, unsigned>>
+        PackOffsetRanges;
+    for (auto &Pair : PackOffsetVec) {
+      VarDecl *Var = Pair.first;
+      HLSLPackOffsetAttr *Attr = Pair.second;
+      unsigned Size = Context.getTypeSize(Var->getType());
+      unsigned Begin = Attr->getOffset() * 32;
+      unsigned End = Begin + Size;
+      for (auto &Range : PackOffsetRanges) {
+        VarDecl *OtherVar = Range.first;
+        unsigned OtherBegin = Range.second.first;
+        unsigned OtherEnd = Range.second.second;
+        if (Begin < OtherEnd && OtherBegin < Begin) {
+          Diag(Var->getLocation(), diag::err_hlsl_packoffset_overlap)
+              << Var << OtherVar;
+          break;
+        } else if (OtherBegin < End && Begin < OtherBegin) {
+          Diag(Var->getLocation(), diag::err_hlsl_packoffset_overlap)
+              << Var << OtherVar;
+          break;
+        }
+      }
+      PackOffsetRanges[Var] = std::make_pair(Begin, End);
+    }
+  }
+
   SemaRef.PopDeclContext();
 }
 
diff --git a/clang/test/AST/HLSL/packoffset.hlsl b/clang/test/AST/HLSL/packoffset.hlsl
new file mode 100644
index 00000000000000..d3cf798c995758
--- /dev/null
+++ b/clang/test/AST/HLSL/packoffset.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -triple dxil-unknown-shadermodel6.3-library -S -finclude-default-header  -ast-dump  -x hlsl %s | FileCheck %s
+
+
+// CHECK: HLSLBufferDecl {{.*}} cbuffer A
+cbuffer A
+{
+    // CHECK-NEXT: VarDecl {{.*}} C1 'float4'
+    // CHECK-NEXT: HLSLPackOffsetAttr {{.*}} 0
+    float4 C1 : packoffset(c);
+    // CHECK-NEXT: VarDecl {{.*}} col:11 C2 'float'
+    // CHECK-NEXT: HLSLPackOffsetAttr {{.*}} 4
+    float C2 : packoffset(c1);
+    // CHECK-NEXT: VarDecl {{.*}} col:11 C3 'float'
+    // CHECK-NEXT: HLSLPackOffsetAttr {{.*}} 5
+    float C3 : packoffset(c1.y);
+}
diff --git a/clang/test/SemaHLSL/packoffset-invalid.hlsl b/clang/test/SemaHLSL/packoffset-invalid.hlsl
new file mode 100644
index 00000000000000..a2c7bb5a1e05cd
--- /dev/null
+++ b/clang/test/SemaHLSL/packoffset-invalid.hlsl
@@ -0,0 +1,55 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -verify %s
+
+// expected-error@+1{{cannot mix packoffset elements with nonpackoffset elements in a cbuffer}}
+cbuffer Mix
+{
+    float4 M1 : packoffset(c0);
+    float M2;
+    float M3 : packoffset(c1.y);
+}
+
+// expected-error@+1{{cannot mix packoffset elements with nonpackoffset elements in a cbuffer}}
+cbuffer Mix2
+{
+    float4 M4;
+    float M5 : packoffset(c1.y);
+    float M6 ;
+}
+
+// expected-error@+1{{attribute 'packoffset' only applies to cbuffer constant}}
+float4 g : packoffset(c0);
+
+cbuffer IllegalOffset
+{
+    // expected-error@+1{{invalid resource class specifier 't2' for packoffset, expected 'c'}}
+    float4 i1 : packoffset(t2);
+    // expected-error@+1{{invalid component 'm' used; expected 'x', 'y', 'z', or 'w'}}
+    float i2 : packoffset(c1.m);
+}
+
+cbuffer Overlap
+{
+    float4 o1 : packoffset(c0);
+    // expected-error@+1{{packoffset overlap between 'o2', 'o1'}}
+    float2 o2 : packoffset(c0.z);
+}
+
+cbuffer CrossReg
+{
+    // expected-error@+1{{packoffset cannot cross register boundary}}
+    float4 c1 : packoffset(c0.y);
+    // expected-error@+1{{packoffset cannot cross register boundary}}
+    float2 c2 : packoffset(c1.w);
+}
+
+struct ST {
+  float s;
+};
+
+cbuffer Aggregate
+{
+    // expected-error@+1{{packoffset cannot cross register boundary}}
+    ST A1 : packoffset(c0.y);
+    // expected-error@+1{{packoffset cannot cross register boundary}}
+    float A2[2] : packoffset(c1.w);
+}

>From c356db64f307bd042d40cbade7f09824c7bf238c Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Wed, 24 Apr 2024 09:37:45 -0400
Subject: [PATCH 2/3] Fix doc format.

---
 clang/include/clang/Basic/AttrDocs.td | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 6bc7813bd43cb4..20e58897f20b2d 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7406,7 +7406,9 @@ Attribute spelling in HLSL is: ``packoffset(c[Subcomponent[.component]])``.
 A subcomponent is a register number, which is an integer. A component is in the form of [.xyzw].
 
 Here're packoffset examples with and without component:
+
 .. code-block:: c++
+
   cbuffer A {
     float3 a : packoffset(c0.y);
     float4 b : packoffset(c4);

>From 2e406f5e644dcdd04259551079953e6339060330 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Thu, 25 Apr 2024 13:58:18 -0400
Subject: [PATCH 3/3] Update per comment.

---
 clang/include/clang/Basic/AttrDocs.td |  2 +-
 clang/lib/Parse/ParseHLSL.cpp         | 38 +++++++++++++--------------
 clang/lib/Sema/SemaDeclAttr.cpp       |  8 +-----
 clang/lib/Sema/SemaHLSL.cpp           | 34 +++++++++++-------------
 4 files changed, 36 insertions(+), 46 deletions(-)

diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 20e58897f20b2d..ac08799faa91fb 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7402,7 +7402,7 @@ def HLSLPackOffsetDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
 The packoffset attribute is used to change the layout of a cbuffer.
-Attribute spelling in HLSL is: ``packoffset(c[Subcomponent[.component]])``.
+Attribute spelling in HLSL is: ``packoffset( c[Subcomponent][.component] )``.
 A subcomponent is a register number, which is an integer. A component is in the form of [.xyzw].
 
 Here're packoffset examples with and without component:
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index eac4876ccab49a..ae9ceaa700af88 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -211,6 +211,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
       if (OffsetStr.getAsInteger(10, SubComponent)) {
         Diag(OffsetLoc.getLocWithOffset(1),
              diag::err_hlsl_unsupported_register_number);
+        SkipUntil(tok::r_paren, StopAtSemi); // skip through )
         return;
       }
     }
@@ -231,25 +232,24 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
         Diag(SpaceLoc, diag::err_hlsl_unsupported_component) << ComponentStr;
         SkipUntil(tok::r_paren, StopAtSemi); // skip through )
         return;
-      } else {
-        switch (ComponentStr[0]) {
-        case 'x':
-          Component = 0;
-          break;
-        case 'y':
-          Component = 1;
-          break;
-        case 'z':
-          Component = 2;
-          break;
-        case 'w':
-          Component = 3;
-          break;
-        default:
-          Diag(SpaceLoc, diag::err_hlsl_unsupported_component) << ComponentStr;
-          SkipUntil(tok::r_paren, StopAtSemi); // skip through )
-          return;
-        }
+      }
+      switch (ComponentStr[0]) {
+      case 'x':
+        Component = 0;
+        break;
+      case 'y':
+        Component = 1;
+        break;
+      case 'z':
+        Component = 2;
+        break;
+      case 'w':
+        Component = 3;
+        break;
+      default:
+        Diag(SpaceLoc, diag::err_hlsl_unsupported_component) << ComponentStr;
+        SkipUntil(tok::r_paren, StopAtSemi); // skip through )
+        return;
       }
     }
     unsigned Offset = SubComponent * 4 + Component;
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 373f2e8f50cdb5..cf7d690f62883a 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7315,13 +7315,7 @@ static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
 }
 
 static void handleHLSLPackOffsetAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  if (!isa<VarDecl>(D)) {
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
-        << AL << "cbuffer constant";
-    return;
-  }
-  auto *BufDecl = dyn_cast<HLSLBufferDecl>(D->getDeclContext());
-  if (!BufDecl) {
+  if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
     S.Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
         << AL << "cbuffer constant";
     return;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fa62cab54e6902..b2eb69dd5a1450 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -64,29 +64,25 @@ void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
   } else if (HasPackOffset) {
     ASTContext &Context = getASTContext();
     // Make sure no overlap in packoffset.
-    llvm::SmallDenseMap<VarDecl *, std::pair<unsigned, unsigned>>
-        PackOffsetRanges;
-    for (auto &Pair : PackOffsetVec) {
-      VarDecl *Var = Pair.first;
-      HLSLPackOffsetAttr *Attr = Pair.second;
+    // Sort PackOffsetVec by offset.
+    std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
+              [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
+                 const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
+                return LHS.second->getOffset() < RHS.second->getOffset();
+              });
+
+    for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
+      VarDecl *Var = PackOffsetVec[i].first;
+      HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
       unsigned Size = Context.getTypeSize(Var->getType());
       unsigned Begin = Attr->getOffset() * 32;
       unsigned End = Begin + Size;
-      for (auto &Range : PackOffsetRanges) {
-        VarDecl *OtherVar = Range.first;
-        unsigned OtherBegin = Range.second.first;
-        unsigned OtherEnd = Range.second.second;
-        if (Begin < OtherEnd && OtherBegin < Begin) {
-          Diag(Var->getLocation(), diag::err_hlsl_packoffset_overlap)
-              << Var << OtherVar;
-          break;
-        } else if (OtherBegin < End && Begin < OtherBegin) {
-          Diag(Var->getLocation(), diag::err_hlsl_packoffset_overlap)
-              << Var << OtherVar;
-          break;
-        }
+      unsigned NextBegin = PackOffsetVec[i + 1].second->getOffset() * 32;
+      if (End > NextBegin) {
+        VarDecl *NextVar = PackOffsetVec[i + 1].first;
+        Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
+            << NextVar << Var;
       }
-      PackOffsetRanges[Var] = std::make_pair(Begin, End);
     }
   }
 



More information about the cfe-commits mailing list