[clang] 92fc1eb - [HLSL] add loop unroll (#93879)

via cfe-commits cfe-commits at lists.llvm.org
Thu Jul 11 14:08:17 PDT 2024


Author: Farzon Lotfi
Date: 2024-07-11T17:08:13-04:00
New Revision: 92fc1eb0c1ae3813f2ac9208e2c74207aae9d23f

URL: https://github.com/llvm/llvm-project/commit/92fc1eb0c1ae3813f2ac9208e2c74207aae9d23f
DIFF: https://github.com/llvm/llvm-project/commit/92fc1eb0c1ae3813f2ac9208e2c74207aae9d23f.diff

LOG: [HLSL] add loop unroll (#93879)

spec: https://github.com/microsoft/hlsl-specs/pull/263

- `Attr.td` - Define the HLSL loop attribute hints (unroll and loop)
- `AttrDocs.td` - Add documentation for unroll and loop
- `CGLoopInfo.cpp` - Add codegen that maps the HLSL loop hints onto Clang's
existing unroll loop-hint states (`llvm.loop.unroll.*` metadata)
- `ParseStmt.cpp` - When parsing statements in HLSL, parse Microsoft-style
`[...]` attributes via `MaybeParseMicrosoftAttributes`
- `SemaStmtAttr.cpp` - Add the HLSL loop unroll handling

resolves #70114

dxc examples: 
- for loop: https://hlsl.godbolt.org/z/8EK6Pa139
- while loop:  https://hlsl.godbolt.org/z/ebr5MvEcK
- do while: https://hlsl.godbolt.org/z/be8cedoTs 

Documentation:

![Screenshot_20240531_143000](https://github.com/llvm/llvm-project/assets/1802579/9da9df9b-68a6-49eb-9d4f-e080aa2eff7f)

Added: 
    clang/test/CodeGenHLSL/loops/unroll.hlsl
    clang/test/SemaHLSL/Loops/unroll.hlsl

Modified: 
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/AttrDocs.td
    clang/lib/CodeGen/CGLoopInfo.cpp
    clang/lib/Parse/ParseStmt.cpp
    clang/lib/Sema/SemaStmtAttr.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index d2d9dd24536cb..6d80f0a0a6586 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4172,6 +4172,18 @@ def LoopHint : Attr {
   let HasCustomParsing = 1;
 }
 
+/// The HLSL loop attributes
+def HLSLLoopHint: StmtAttr {
+  /// [unroll(directive)]
+  /// [loop]
+  let Spellings = [Microsoft<"unroll">, Microsoft<"loop">];
+  let Args = [UnsignedArgument<"directive", /*opt*/1>];
+  let Subjects = SubjectList<[ForStmt, WhileStmt, DoStmt],
+                              ErrorDiag, "'for', 'while', and 'do' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];

diff  --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ab4bd003541fa..09cf4f80bd999 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7343,6 +7343,100 @@ where shaders must be compiled into a library and linked at runtime.
   }];
 }
 
+def HLSLLoopHintDocs : Documentation {
+  let Category = DocCatStmt;
+  let Heading = "[loop]";
+  let Content = [{
+The ``[loop]`` directive provides a loop optimization hint for the
+subsequent loop: it disables unrolling of that loop and is not
+compatible with ``[unroll(x)]``.
+
+Specifying ``[loop]`` (which takes no arguments) directs the
+unroller not to unroll the loop.
+
+.. code-block:: hlsl
+
+  [loop]
+  for (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [loop]
+  while (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [loop]
+  do {
+    ...
+  } while (...);
+
+See `HLSL loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
+for details.
+  }];
+}
+
+def HLSLUnrollHintDocs : Documentation {
+  let Category = DocCatStmt;
+  let Heading = "[unroll(x)], [unroll]";
+  let Content = [{
+Loop unrolling optimization hints can be specified with ``[unroll(x)]``
+or ``[unroll]``. The attribute is placed immediately before a ``for``,
+``while``, or ``do-while`` statement. Specifying ``[unroll]`` with no
+argument enables unrolling; specifying ``[unroll(_value_)]``, where ``_value_`` is a positive integer constant, directs
+the unroller to unroll the loop ``_value_`` times. Note: ``[unroll(x)]`` is not compatible with ``[loop]``.
+
+.. code-block:: hlsl
+
+  [unroll(4)]
+  for (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [unroll]
+  for (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [unroll(4)]
+  while (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [unroll]
+  while (...) {
+    ...
+  }
+
+.. code-block:: hlsl
+
+  [unroll(4)]
+  do {
+    ...
+  } while (...);
+
+.. code-block:: hlsl
+
+  [unroll]
+  do {
+    ...
+  } while (...);
+
+See `HLSL loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
+for details.
+  }];
+}
+
 def ClangRandomizeLayoutDocs : Documentation {
   let Category = DocCatDecl;
   let Heading = "randomize_layout, no_randomize_layout";
@@ -7402,7 +7496,8 @@ b for constant buffer views (CBV).
 
 Register space is specified in the format ``space[number]`` and defaults to ``space0`` if omitted.
 Here're resource binding examples with and without space:
-.. code-block:: c++
+
+.. code-block:: hlsl
 
   RWBuffer<float> Uav : register(u3, space1);
   Buffer<float> Buf : register(t1);
@@ -7420,7 +7515,7 @@ A subcomponent is a register number, which is an integer. A component is in the
 
 Examples:
 
-.. code-block:: c++
+.. code-block:: hlsl
 
   cbuffer A {
     float3 a : packoffset(c0.y);

diff  --git a/clang/lib/CodeGen/CGLoopInfo.cpp b/clang/lib/CodeGen/CGLoopInfo.cpp
index 0d4800b90a2f2..6b886bd6b6d2c 100644
--- a/clang/lib/CodeGen/CGLoopInfo.cpp
+++ b/clang/lib/CodeGen/CGLoopInfo.cpp
@@ -612,9 +612,9 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
     const LoopHintAttr *LH = dyn_cast<LoopHintAttr>(Attr);
     const OpenCLUnrollHintAttr *OpenCLHint =
         dyn_cast<OpenCLUnrollHintAttr>(Attr);
-
+    const HLSLLoopHintAttr *HLSLLoopHint = dyn_cast<HLSLLoopHintAttr>(Attr);
     // Skip non loop hint attributes
-    if (!LH && !OpenCLHint) {
+    if (!LH && !OpenCLHint && !HLSLLoopHint) {
       continue;
     }
 
@@ -635,6 +635,17 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
         Option = LoopHintAttr::UnrollCount;
         State = LoopHintAttr::Numeric;
       }
+    } else if (HLSLLoopHint) {
+      ValueInt = HLSLLoopHint->getDirective();
+      if (HLSLLoopHint->getSemanticSpelling() ==
+          HLSLLoopHintAttr::Spelling::Microsoft_unroll) {
+        if (ValueInt == 0)
+          State = LoopHintAttr::Enable;
+        if (ValueInt > 0) {
+          Option = LoopHintAttr::UnrollCount;
+          State = LoopHintAttr::Numeric;
+        }
+      }
     } else if (LH) {
       auto *ValueExpr = LH->getValue();
       if (ValueExpr) {

diff  --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp
index 16a5b7483ec1c..22d38adc28ebe 100644
--- a/clang/lib/Parse/ParseStmt.cpp
+++ b/clang/lib/Parse/ParseStmt.cpp
@@ -114,18 +114,21 @@ Parser::ParseStatementOrDeclaration(StmtVector &Stmts,
   // here because we don't want to allow arbitrary orderings.
   ParsedAttributes CXX11Attrs(AttrFactory);
   MaybeParseCXX11Attributes(CXX11Attrs, /*MightBeObjCMessageSend*/ true);
-  ParsedAttributes GNUAttrs(AttrFactory);
+  ParsedAttributes GNUOrMSAttrs(AttrFactory);
   if (getLangOpts().OpenCL)
-    MaybeParseGNUAttributes(GNUAttrs);
+    MaybeParseGNUAttributes(GNUOrMSAttrs);
+
+  if (getLangOpts().HLSL)
+    MaybeParseMicrosoftAttributes(GNUOrMSAttrs);
 
   StmtResult Res = ParseStatementOrDeclarationAfterAttributes(
-      Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUAttrs);
+      Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUOrMSAttrs);
   MaybeDestroyTemplateIds();
 
   // Attributes that are left should all go on the statement, so concatenate the
   // two lists.
   ParsedAttributes Attrs(AttrFactory);
-  takeAndConcatenateAttrs(CXX11Attrs, GNUAttrs, Attrs);
+  takeAndConcatenateAttrs(CXX11Attrs, GNUOrMSAttrs, Attrs);
 
   assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) &&
          "attributes on empty statement");

diff  --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 6f538ed55cb72..7f452d177c16f 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -16,6 +16,7 @@
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/DelayedDiagnostic.h"
 #include "clang/Sema/Lookup.h"
+#include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/ScopeInfo.h"
 #include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/StringExtras.h"
@@ -584,6 +585,39 @@ static Attr *handleOpenCLUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) OpenCLUnrollHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
+                                    SourceRange Range) {
+
+  if (A.getSemanticSpelling() == HLSLLoopHintAttr::Spelling::Microsoft_loop &&
+      !A.checkAtMostNumArgs(S, 0))
+    return nullptr;
+
+  unsigned UnrollFactor = 0;
+  if (A.getNumArgs() == 1) {
+
+    if (A.isArgIdent(0)) {
+      S.Diag(A.getLoc(), diag::err_attribute_argument_type)
+          << A << AANT_ArgumentIntegerConstant << A.getRange();
+      return nullptr;
+    }
+
+    Expr *E = A.getArgAsExpr(0);
+
+    if (S.CheckLoopHintExpr(E, St->getBeginLoc(),
+                            /*AllowZero=*/false))
+      return nullptr;
+
+    std::optional<llvm::APSInt> ArgVal = E->getIntegerConstantExpr(S.Context);
+    // CheckLoopHintExpr handles non int const cases
+    assert(ArgVal != std::nullopt && "ArgVal should be an integer constant.");
+    int Val = ArgVal->getSExtValue();
+    // CheckLoopHintExpr handles negative and zero cases
+    assert(Val > 0 && "Val should be a positive integer greater than zero.");
+    UnrollFactor = static_cast<unsigned>(Val);
+  }
+  return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -618,6 +652,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleFallThroughAttr(S, St, A, Range);
   case ParsedAttr::AT_LoopHint:
     return handleLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLLoopHint:
+    return handleHLSLLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:

diff  --git a/clang/test/CodeGenHLSL/loops/unroll.hlsl b/clang/test/CodeGenHLSL/loops/unroll.hlsl
new file mode 100644
index 0000000000000..7389f21dd3472
--- /dev/null
+++ b/clang/test/CodeGenHLSL/loops/unroll.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library -disable-llvm-passes %s -emit-llvm -o - | FileCheck %s
+
+/*** for ***/
+void for_count()
+{
+// CHECK-LABEL: for_count
+    [unroll(8)]
+    for( int i = 0; i < 1000; ++i);
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISTINCT:.*]]
+}
+
+void for_disable()
+{
+// CHECK-LABEL: for_disable
+    [loop]
+    for( int i = 0; i < 1000; ++i);
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISABLE:.*]]
+}
+
+void for_enable()
+{
+// CHECK-LABEL: for_enable
+    [unroll]
+    for( int i = 0; i < 1000; ++i);
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_ENABLE:.*]]
+}
+
+void for_nested_one_unroll_enable()
+{
+// CHECK-LABEL: for_nested_one_unroll_enable
+    int s = 0;
+    [unroll]
+    for( int i = 0; i < 1000; ++i) {
+        for( int j = 0; j < 10; ++j)
+            s += i + j;
+    }
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_ENABLE:.*]]
+// CHECK-NOT: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_1_ENABLE:.*]]
+}
+
+void for_nested_two_unroll_enable()
+{
+// CHECK-LABEL: for_nested_two_unroll_enable
+    int s = 0;
+    [unroll]
+    for( int i = 0; i < 1000; ++i) {
+        [unroll]
+        for( int j = 0; j < 10; ++j)
+            s += i + j;
+    }
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_ENABLE:.*]]
+// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_1_ENABLE:.*]]
+}
+
+
+/*** while ***/
+void while_count()
+{
+// CHECK-LABEL: while_count
+    int i = 1000;
+    [unroll(4)]
+    while(i-->0);
+// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISTINCT:.*]]
+}
+
+void while_disable()
+{
+// CHECK-LABEL: while_disable
+    int i = 1000;
+    [loop]
+    while(i-->0);
+// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISABLE:.*]]
+}
+
+void while_enable()
+{
+// CHECK-LABEL: while_enable
+    int i = 1000;
+    [unroll]
+    while(i-->0);
+// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_ENABLE:.*]]
+}
+
+/*** do ***/
+void do_count()
+{
+// CHECK-LABEL: do_count
+    int i = 1000;
+    [unroll(16)]
+    do {} while(i--> 0);
+// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISTINCT:.*]]
+}
+
+void do_disable()
+{
+// CHECK-LABEL: do_disable
+    int i = 1000;
+    [loop]
+    do {} while(i--> 0);
+// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISABLE:.*]]
+}
+
+void do_enable()
+{
+// CHECK-LABEL: do_enable
+    int i = 1000;
+    [unroll]
+    do {} while(i--> 0);
+// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_ENABLE:.*]]
+}
+
+
+// CHECK: ![[FOR_DISTINCT]]     =  distinct !{![[FOR_DISTINCT]],  ![[FOR_COUNT:.*]]}
+// CHECK: ![[FOR_COUNT]]         =  !{!"llvm.loop.unroll.count", i32 8}
+// CHECK: ![[FOR_DISABLE]]   =  distinct !{![[FOR_DISABLE]],  ![[DISABLE:.*]]}
+// CHECK: ![[DISABLE]]       =  !{!"llvm.loop.unroll.disable"}
+// CHECK: ![[FOR_ENABLE]]      =  distinct !{![[FOR_ENABLE]],  ![[ENABLE:.*]]}
+// CHECK: ![[ENABLE]]          =  !{!"llvm.loop.unroll.enable"}
+// CHECK: ![[FOR_NESTED_ENABLE]] =  distinct !{![[FOR_NESTED_ENABLE]], ![[ENABLE]]}
+// CHECK: ![[FOR_NESTED2_ENABLE]] =  distinct !{![[FOR_NESTED2_ENABLE]], ![[ENABLE]]}
+// CHECK: ![[FOR_NESTED2_1_ENABLE]] =  distinct !{![[FOR_NESTED2_1_ENABLE]], ![[ENABLE]]}
+// CHECK: ![[WHILE_DISTINCT]]   =  distinct !{![[WHILE_DISTINCT]],    ![[WHILE_COUNT:.*]]}
+// CHECK: ![[WHILE_COUNT]]         =  !{!"llvm.loop.unroll.count", i32 4}
+// CHECK: ![[WHILE_DISABLE]] =  distinct !{![[WHILE_DISABLE]],  ![[DISABLE]]}
+// CHECK: ![[WHILE_ENABLE]]    =  distinct !{![[WHILE_ENABLE]],     ![[ENABLE]]}
+// CHECK: ![[DO_DISTINCT]]      =  distinct !{![[DO_DISTINCT]],       ![[DO_COUNT:.*]]}
+// CHECK: ![[DO_COUNT]]         =  !{!"llvm.loop.unroll.count", i32 16}
+// CHECK: ![[DO_DISABLE]]    =  distinct !{![[DO_DISABLE]],     ![[DISABLE]]}
+// CHECK: ![[DO_ENABLE]]       =  distinct !{![[DO_ENABLE]],        ![[ENABLE]]}

diff  --git a/clang/test/SemaHLSL/Loops/unroll.hlsl b/clang/test/SemaHLSL/Loops/unroll.hlsl
new file mode 100644
index 0000000000000..2e2be319e4666
--- /dev/null
+++ b/clang/test/SemaHLSL/Loops/unroll.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -O0 -finclude-default-header -fsyntax-only -triple dxil-pc-shadermodel6.6-library %s -verify
+void unroll_no_vars() {
+  int I = 3;
+  [unroll(I)]  // expected-error {{'unroll' attribute requires an integer constant}}
+  while (I--);
+}
+
+void unroll_arg_count() {
+   [unroll(2,4)] // expected-error {{'unroll' attribute takes no more than 1 argument}}
+  for(int i=0; i<100; i++);
+}
+
+void loop_arg_count() {
+   [loop(2)] // expected-error {{'loop' attribute takes no more than 0 argument}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_no_negative() {
+  [unroll(-1)] // expected-error {{invalid value '-1'; must be positive}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_no_zero() {
+  [unroll(0)] // expected-error {{invalid value '0'; must be positive}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_no_float() {
+  [unroll(2.1)] // expected-error {{invalid argument of type 'float'; expected an integer type}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_no_bool_false() {
+  [unroll(false)] // expected-error {{invalid argument of type 'bool'; expected an integer type}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_no_bool_true() {
+  [unroll(true)] // expected-error {{invalid argument of type 'bool'; expected an integer type}}
+  for(int i=0; i<100; i++);
+}
+
+void unroll_loop_enforcement() {
+  int x[10];
+  [unroll(4)] // expected-error {{'unroll' attribute only applies to 'for', 'while', and 'do' statements}}
+  if (x[0])
+    x[0] = 15;
+}


        


More information about the cfe-commits mailing list