[clang] [llvm] [HLSL] Adding Flatten and Branch if attributes (PR #116331)

via cfe-commits cfe-commits at lists.llvm.org
Fri Nov 15 22:02:02 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v

@llvm/pr-subscribers-clang-codegen

Author: None (joaosaffran)

<details>
<summary>Changes</summary>

- Adding the `[flatten]` and `[branch]` attributes to `if` statements.
- Adding DXIL control flow hint metadata generation (`dx.controlflow.hints`).
- Modifying the SPIR-V `OpSelectionMerge` emission to account for these attributes.

Closes #<!-- -->70112

---

Patch is 28.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116331.diff


17 Files Affected:

- (modified) clang/include/clang/Basic/Attr.td (+10) 
- (modified) clang/lib/CodeGen/CGStmt.cpp (+7) 
- (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+13-1) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+6) 
- (modified) clang/lib/CodeGen/CodeGenPGO.cpp (+13) 
- (modified) clang/lib/Sema/SemaStmtAttr.cpp (+8) 
- (added) clang/test/AST/HLSL/HLSLBranchHint.hlsl (+43) 
- (added) clang/test/CodeGenHLSL/HLSLBranchHint.hlsl (+48) 
- (modified) llvm/include/llvm/IR/FixedMetadataKinds.def (+3) 
- (modified) llvm/include/llvm/IR/IRBuilder.h (+8-3) 
- (modified) llvm/include/llvm/IR/ProfDataUtils.h (+7) 
- (modified) llvm/lib/IR/ProfDataUtils.cpp (+7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+28-6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+17-8) 
- (added) llvm/test/CodeGen/DirectX/HLSLBranchHint.ll (+95) 
- (added) llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll (+91) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll (+91) 


``````````diff
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b3c357ec906a23..6d3e07ce83100b 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4302,6 +4302,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLBranchHint: StmtAttr {
+  /// [branch]
+  /// [flatten]
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly];
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 698baf853507f4..bf82e50c8b59e0 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -730,6 +730,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLBranchHintAttr::Spelling flattenOrBranch =
+      HLSLBranchHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -761,6 +763,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    case attr::HLSLBranchHint: {
+      flattenOrBranch = cast<HLSLBranchHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -768,6 +773,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  SaveAndRestore save_flattenOrBranch(HLSLBranchHintAttributedSpelling,
+                                      flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 6ead45793742d6..a790d1fde0e104 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2052,6 +2052,7 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
 
   llvm::MDNode *Weights = nullptr;
   llvm::MDNode *Unpredictable = nullptr;
+  llvm::MDNode *ControlFlowHint = nullptr;
 
   // If the branch has a condition wrapped by __builtin_unpredictable,
   // create metadata that specifies that the branch is unpredictable.
@@ -2076,7 +2077,18 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  switch (HLSLBranchHintAttributedSpelling) {
+
+  case HLSLBranchHintAttr::Microsoft_branch:
+  case HLSLBranchHintAttr::Microsoft_flatten:
+    ControlFlowHint = createControlFlowHint(HLSLBranchHintAttributedSpelling);
+    break;
+  case HLSLBranchHintAttr::SpellingNotCalculated:
+    break;
+  }
+
+  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable,
+                       ControlFlowHint);
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fcc1013d7361ec..d3f1af6f5a5ce9 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL Branch attribute.
+  HLSLBranchHintAttr::Spelling HLSLBranchHintAttributedSpelling =
+      HLSLBranchHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
@@ -1612,6 +1616,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Bitmap used by MC/DC to track condition outcomes of a boolean expression.
   Address MCDCCondBitmapAddr = Address::invalid();
 
+  llvm::MDNode *createControlFlowHint(HLSLBranchHintAttr::Spelling S) const;
+
   /// Calculate branch weights appropriate for PGO data
   llvm::MDNode *createProfileWeights(uint64_t TrueCount,
                                      uint64_t FalseCount) const;
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp
index 820bb521ccf850..ff3826ca6da854 100644
--- a/clang/lib/CodeGen/CodeGenPGO.cpp
+++ b/clang/lib/CodeGen/CodeGenPGO.cpp
@@ -1451,6 +1451,19 @@ static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
   return Scaled;
 }
 
+llvm::MDNode *
+CodeGenFunction::createControlFlowHint(HLSLBranchHintAttr::Spelling S) const {
+  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+  SmallVector<llvm::Metadata *, 2> Vals(llvm::ArrayRef<llvm::Metadata *>{
+      MDHelper.createString("dx.controlflow.hints"),
+      S == HLSLBranchHintAttr::Spelling::Microsoft_branch
+          ? MDHelper.createConstant(llvm::ConstantInt::get(Int32Ty, 1))
+          : MDHelper.createConstant(llvm::ConstantInt::get(Int32Ty, 2))});
+
+  return llvm::MDNode::get(CGM.getLLVMContext(), Vals);
+}
+
 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                     uint64_t FalseCount) const {
   // Check for empty weights.
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index f801455596fe6f..827a674ece4bc0 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -623,6 +623,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLBranchHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                  SourceRange Range) {
+
+  return ::new (S.Context) HLSLBranchHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -659,6 +665,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLBranchHint:
+    return handleHLSLBranchHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLBranchHint.hlsl b/clang/test/AST/HLSL/HLSLBranchHint.hlsl
new file mode 100644
index 00000000000000..907d6c5ff580c0
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLBranchHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NO: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NO: HLSLBranchHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl b/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
new file mode 100644
index 00000000000000..11b14b1ec43175
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NO: !dx.controlflow.hints
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"dx.controlflow.hints", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"dx.controlflow.hints", i32 2}
diff --git a/llvm/include/llvm/IR/FixedMetadataKinds.def b/llvm/include/llvm/IR/FixedMetadataKinds.def
index df572e8791e13b..a239a076d131f7 100644
--- a/llvm/include/llvm/IR/FixedMetadataKinds.def
+++ b/llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -53,3 +53,6 @@ LLVM_FIXED_MD_KIND(MD_DIAssignID, "DIAssignID", 38)
 LLVM_FIXED_MD_KIND(MD_coro_outside_frame, "coro.outside.frame", 39)
 LLVM_FIXED_MD_KIND(MD_mmra, "mmra", 40)
 LLVM_FIXED_MD_KIND(MD_noalias_addrspace, "noalias.addrspace", 41)
+// TODO: this will likelly be placed somewhere else,
+// so we don't mix dxil/hlsl/spirv and clang metadata
+LLVM_FIXED_MD_KIND(MD_dxil_controlflow_hints, "dx.controlflow.hints", 42)
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 23fd8350a29b3d..215d3362bd4f6c 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1101,11 +1101,14 @@ class IRBuilderBase {
   /// instruction.
   /// \returns The annotated instruction.
   template <typename InstTy>
-  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable) {
+  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable,
+                            MDNode *ControlFlowHint = nullptr) {
     if (Weights)
       I->setMetadata(LLVMContext::MD_prof, Weights);
     if (Unpredictable)
       I->setMetadata(LLVMContext::MD_unpredictable, Unpredictable);
+    if (ControlFlowHint)
+      I->setMetadata(LLVMContext::MD_dxil_controlflow_hints, ControlFlowHint);
     return I;
   }
 
@@ -1143,9 +1146,11 @@ class IRBuilderBase {
   /// instruction.
   BranchInst *CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False,
                            MDNode *BranchWeights = nullptr,
-                           MDNode *Unpredictable = nullptr) {
+                           MDNode *Unpredictable = nullptr,
+                           MDNode *ControlFlowHint = nullptr) {
     return Insert(addBranchMetadata(BranchInst::Create(True, False, Cond),
-                                    BranchWeights, Unpredictable));
+                                    BranchWeights, Unpredictable,
+                                    ControlFlowHint));
   }
 
   /// Create a conditional 'br Cond, TrueDest, FalseDest'
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 0bea517df832e3..49f7da480f928d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -48,6 +48,13 @@ bool hasValidBranchWeightMD(const Instruction &I);
 /// Nullptr otherwise.
 MDNode *getBranchWeightMDNode(const Instruction &I);
 
+/// Get the branching metadata information
+///
+/// \param I The Instruction to get the weights from.
+/// \returns A pointer to I's branch weights metadata node, if it exists.
+/// Nullptr otherwise.
+MDNode *getDxBranchHint(const Instruction &I);
+
 /// Get the valid branch weights metadata node
 ///
 /// \param I The Instruction to get the weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 5441228b3291ee..47a5059017a48f 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -150,6 +150,13 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
   return ProfileData;
 }
 
+MDNode *getDxBranchHint(const Instruction &I) {
+  MDNode *Node = I.getMetadata(LLVMContext::MD_dxil_controlflow_hints);
+  if (!isTargetMD(Node, "dx.controlflow.hints", 2))
+    return nullptr;
+  return Node;
+}
+
 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   auto *ProfileData = getBranchWeightMDNode(I);
   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8a8835e0269200..5576f46cf89f9a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2694,12 +2694,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2707,6 +2703,32 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+
+    auto SelectionControl = SPIRV::SelectionControl::None;
+    const MDNode *MDOp =
+        I.getOperand(I.getNumExplicitOperands() - 1).getMetadata();
+    if (MDOp->getNumOperands() > 0) {
+      ConstantInt *BranchHint =
+          mdconst::extract<ConstantInt>(MDOp->getOperand(1));
+
+      if (BranchHint->equalsInt(2))
+        SelectionControl = SPIRV::SelectionControl::Flatten;
+      else if (BranchHint->equalsInt(1))
+        SelectionControl = SPIRV::SelectionControl::DontFlatten;
+      else
+        llvm_unreachable("Invalid value for SelectionControl");
+    }
+
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    for (unsigned i = 1; i < I.getNumExplicitOperands() - 1; ++i) {
+      assert(I.getOperand(i).isMBB());
+      MIB.addMBB(I.getOperand(i).getMBB());
+    }
+    MIB.addImm(SelectionControl);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 13e05b67927518..57285869b2e338 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -23,6 +23,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
@@ -644,8 +645,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -767,10 +767,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1103,8 +1102,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1118,8 +1116,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1206,6 +1203,18 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatoInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *BranchMdNode = getDxBranchHint(*BBTerminatoInst);
+    Value *MDNodeValue =
+        MetadataAsValue::get(Builder->getContext(), BranchMdNode);
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, MDNodeValue};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+  }
 };
 } // namespace llvm
 
diff --git a/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll b/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
new file mode 100644
index 00000000000000..e7128d19283e80
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
@@ -0,0 +1,95 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is propagating to DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+
+...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116331


More information about the cfe-commits mailing list