[clang] [llvm] [HLSL] Adding Flatten and Branch if attributes (PR #116331)

via cfe-commits cfe-commits at lists.llvm.org
Thu Dec 5 12:16:50 PST 2024


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/116331

>From 3c792216f88e87b69b3ea7415c2fd74b7f5d7469 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 25 Oct 2024 17:48:41 +0000
Subject: [PATCH 01/12] adding comments

---
 clang/include/clang/Basic/Attr.td     | 10 ++++++++++
 clang/lib/CodeGen/CGStmt.cpp          |  2 ++
 clang/lib/CodeGen/CodeGenFunction.cpp |  2 ++
 clang/lib/Sema/SemaStmtAttr.cpp       |  8 ++++++++
 4 files changed, 22 insertions(+)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b3c357ec906a23..6d3e07ce83100b 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4302,6 +4302,16 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
+def HLSLBranchHint: StmtAttr {
+  /// [branch]  - hint value 1; SPIR-V lowers it to SelectionControl DontFlatten
+  /// [flatten] - hint value 2; SPIR-V lowers it to SelectionControl Flatten
+  let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
+  let Subjects = SubjectList<[IfStmt],
+                              ErrorDiag, "'if' statements">;
+  let LangOpts = [HLSL];
+  let Documentation = [InternalOnly]; // NOTE(review): user-visible; needs docs
+}
+
 def CapturedRecord : InheritableAttr {
   // This attribute has no spellings as it is only ever created implicitly.
   let Spellings = [];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 698baf853507f4..7b01dc84b55365 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -761,6 +761,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
+    // [jderezende] TODO: Add HLSLBranchHint, to mark if flatten/branch is present.
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -768,6 +769,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
+  // [jderezende] TODO: Save HLSLBranchHint information
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 6ead45793742d6..fc4fde81984d64 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -43,6 +43,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/CRC.h"
 #include "llvm/Support/xxhash.h"
@@ -2076,6 +2077,7 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
+  // [jderezende] TODO: Emit branch metadata marking it as flatten/branch, if exists.
   Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
 }
 
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index f801455596fe6f..68323092cb564d 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -623,6 +623,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
+static Attr *handleHLSLBranchHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                  SourceRange Range) {
+  // Argument-less: which hint applies is carried by the attribute spelling.
+  return ::new (S.Context) HLSLBranchHintAttr(S.Context, A);
+}
+
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
                                   SourceRange Range) {
   if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -659,6 +665,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
+  case ParsedAttr::AT_HLSLBranchHint:
+    return handleHLSLBranchHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:

>From f48e382b7ff396fbf1b2ce7dcea9529a06fa9b12 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 28 Oct 2024 23:58:10 +0000
Subject: [PATCH 02/12] continue exploration

---
 llvm/include/llvm/IR/FixedMetadataKinds.def | 3 +++
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp | 1 +
 2 files changed, 4 insertions(+)

diff --git a/llvm/include/llvm/IR/FixedMetadataKinds.def b/llvm/include/llvm/IR/FixedMetadataKinds.def
index df572e8791e13b..02a986d42f1933 100644
--- a/llvm/include/llvm/IR/FixedMetadataKinds.def
+++ b/llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -53,3 +53,6 @@ LLVM_FIXED_MD_KIND(MD_DIAssignID, "DIAssignID", 38)
 LLVM_FIXED_MD_KIND(MD_coro_outside_frame, "coro.outside.frame", 39)
 LLVM_FIXED_MD_KIND(MD_mmra, "mmra", 40)
 LLVM_FIXED_MD_KIND(MD_noalias_addrspace, "noalias.addrspace", 41)
+// [jderezende] TODO: this will likelly be placed somewhere else,
+// so we don't mix dxil/hlsl/spirv and clang metadata
+LLVM_FIXED_MD_KIND(MD_dxil_controlflow_hints, "dx.controlflow.hints", 42)
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 13e05b67927518..bff8cc1cfc5821 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -645,6 +645,7 @@ class SPIRVStructurizer : public FunctionPass {
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
       SmallVector<Value *, 1> Args = {MergeAddress};
+      // [jderezende] TODO: Pass metadata from Header->getTerminator() to modify the intrinsic
       Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
 
       Modified = true;

>From feb9d2dfc020e27a69637774f896d16f74e20625 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 14 Nov 2024 02:13:12 +0000
Subject: [PATCH 03/12] adding attribute lowering

---
 clang/lib/CodeGen/CGStmt.cpp                  | 10 ++++--
 clang/lib/CodeGen/CodeGenFunction.cpp         | 15 +++++++-
 clang/lib/CodeGen/CodeGenFunction.h           |  7 ++++
 clang/lib/CodeGen/CodeGenPGO.cpp              | 17 ++++++++++
 llvm/include/llvm/IR/IRBuilder.h              | 11 ++++--
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 llvm/include/llvm/IR/ProfDataUtils.h          |  7 ++++
 llvm/lib/IR/ProfDataUtils.cpp                 |  7 ++++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 34 +++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVMetadata.cpp       |  2 +-
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp   | 14 ++++++--
 11 files changed, 109 insertions(+), 17 deletions(-)

diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 7b01dc84b55365..2f95ceb297f7d8 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -16,6 +16,7 @@
 #include "CodeGenModule.h"
 #include "TargetInfo.h"
 #include "clang/AST/Attr.h"
+#include "clang/AST/Attrs.inc"
 #include "clang/AST/Expr.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
@@ -730,6 +731,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
+  HLSLBranchHintAttr::Spelling flattenOrBranch =
+      HLSLBranchHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -761,7 +764,9 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
-    // [jderezende] TODO: Add HLSLBranchHint, to mark if flatten/branch is present.
+    case attr::HLSLBranchHint: {
+      flattenOrBranch = cast<HLSLBranchHintAttr>(A)->getSemanticSpelling();
+    } break;
     }
   }
   SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
@@ -769,7 +774,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
-  // [jderezende] TODO: Save HLSLBranchHint information
+  SaveAndRestore save_flattenOrBranch(HLSLBranchHintAttributedSpelling,
+                                      flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index fc4fde81984d64..a799e017f46c5e 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/Attr.h"
+#include "clang/AST/Attrs.inc"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
@@ -2053,6 +2054,7 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
 
   llvm::MDNode *Weights = nullptr;
   llvm::MDNode *Unpredictable = nullptr;
+  llvm::MDNode *ControlFlowHint = nullptr;
 
   // If the branch has a condition wrapped by __builtin_unpredictable,
   // create metadata that specifies that the branch is unpredictable.
@@ -2077,8 +2079,19 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
+  switch (HLSLBranchHintAttributedSpelling) {
+
+  case HLSLBranchHintAttr::Microsoft_branch:
+  case HLSLBranchHintAttr::Microsoft_flatten:
+    ControlFlowHint = createControlFlowHint(HLSLBranchHintAttributedSpelling);
+    break;
+  case HLSLBranchHintAttr::SpellingNotCalculated:
+    break;
+  }
+
   // [jderezende] TODO: Emit branch metadata marking it as flatten/branch, if exists.
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
+  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable,
+                       ControlFlowHint);
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fcc1013d7361ec..f1001dc21e5e5d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -21,6 +21,7 @@
 #include "CodeGenPGO.h"
 #include "EHScopeStack.h"
 #include "VarBypassDetector.h"
+#include "clang/AST/Attrs.inc"
 #include "clang/AST/CharUnits.h"
 #include "clang/AST/CurrentSourceLocExprScope.h"
 #include "clang/AST/ExprCXX.h"
@@ -615,6 +616,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// True if the current statement has noconvergent attribute.
   bool InNoConvergentAttributedStmt = false;
 
+  /// HLSL [branch]/[flatten] spelling on the current statement, if any.
+  HLSLBranchHintAttr::Spelling HLSLBranchHintAttributedSpelling =
+      HLSLBranchHintAttr::SpellingNotCalculated;
+
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
   const CallExpr *MustTailCall = nullptr;
@@ -1612,6 +1617,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Bitmap used by MC/DC to track condition outcomes of a boolean expression.
   Address MCDCCondBitmapAddr = Address::invalid();
 
+  llvm::MDNode *createControlFlowHint(HLSLBranchHintAttr::Spelling S) const;
+
   /// Calculate branch weights appropriate for PGO data
   llvm::MDNode *createProfileWeights(uint64_t TrueCount,
                                      uint64_t FalseCount) const;
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp
index 820bb521ccf850..d8d72f91f28708 100644
--- a/clang/lib/CodeGen/CodeGenPGO.cpp
+++ b/clang/lib/CodeGen/CodeGenPGO.cpp
@@ -13,10 +13,14 @@
 #include "CodeGenPGO.h"
 #include "CodeGenFunction.h"
 #include "CoverageMappingGen.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Attrs.inc"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtVisitor.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/FileSystem.h"
@@ -1451,6 +1455,19 @@ static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
   return Scaled;
 }
 
+llvm::MDNode *
+CodeGenFunction::createControlFlowHint(HLSLBranchHintAttr::Spelling S) const {
+  // Encoding read back by the SPIR-V selector: 1 = [branch], 2 = [flatten].
+  const unsigned HintValue =
+      S == HLSLBranchHintAttr::Spelling::Microsoft_branch ? 1 : 2;
+  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+  llvm::MDBuilder MDHelper(Ctx);
+  llvm::Metadata *Ops[] = {
+      MDHelper.createString("dx.controlflow.hints"),
+      MDHelper.createConstant(llvm::ConstantInt::get(Int32Ty, HintValue))};
+  return llvm::MDNode::get(Ctx, Ops);
+}
+
 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                     uint64_t FalseCount) const {
   // Check for empty weights.
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 23fd8350a29b3d..215d3362bd4f6c 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1101,11 +1101,14 @@ class IRBuilderBase {
   /// instruction.
   /// \returns The annotated instruction.
   template <typename InstTy>
-  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable) {
+  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable,
+                            MDNode *ControlFlowHint = nullptr) {
     if (Weights)
       I->setMetadata(LLVMContext::MD_prof, Weights);
     if (Unpredictable)
       I->setMetadata(LLVMContext::MD_unpredictable, Unpredictable);
+    if (ControlFlowHint)
+      I->setMetadata(LLVMContext::MD_dxil_controlflow_hints, ControlFlowHint);
     return I;
   }
 
@@ -1143,9 +1146,11 @@ class IRBuilderBase {
   /// instruction.
   BranchInst *CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False,
                            MDNode *BranchWeights = nullptr,
-                           MDNode *Unpredictable = nullptr) {
+                           MDNode *Unpredictable = nullptr,
+                           MDNode *ControlFlowHint = nullptr) {
     return Insert(addBranchMetadata(BranchInst::Create(True, False, Cond),
-                                    BranchWeights, Unpredictable));
+                                    BranchWeights, Unpredictable,
+                                    ControlFlowHint));
   }
 
   /// Create a conditional 'br Cond, TrueDest, FalseDest'
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index f29eb7ee22b2d2..4a5f744c74d9be 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_metadata_ty, llvm_vararg_ty]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 0bea517df832e3..49f7da480f928d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -48,6 +48,13 @@ bool hasValidBranchWeightMD(const Instruction &I);
 /// Nullptr otherwise.
 MDNode *getBranchWeightMDNode(const Instruction &I);
 
+/// Get the HLSL control-flow hint ([branch]/[flatten]) metadata.
+///
+/// \param I The Instruction to get the control-flow hint from.
+/// \returns A pointer to I's "dx.controlflow.hints" metadata node, if it
+/// exists and passes the isTargetMD check. Nullptr otherwise.
+MDNode *getDxBranchHint(const Instruction &I);
+
 /// Get the valid branch weights metadata node
 ///
 /// \param I The Instruction to get the weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 5441228b3291ee..47a5059017a48f 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -150,6 +150,13 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
   return ProfileData;
 }
 
+MDNode *getDxBranchHint(const Instruction &I) {
+  // Hand the node back only when isTargetMD recognizes it as a hint node.
+  MDNode *Hint = I.getMetadata(LLVMContext::MD_dxil_controlflow_hints);
+  return isTargetMD(Hint, "dx.controlflow.hints", 2) ? Hint : nullptr;
+}
+}
+
 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   auto *ProfileData = getBranchWeightMDNode(I);
   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 8a8835e0269200..f3f78fb4f2bdbb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -32,6 +32,7 @@
 #include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "spirv-isel"
@@ -2694,12 +2695,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
-  case Intrinsic::spv_loop_merge:
-  case Intrinsic::spv_selection_merge: {
-    const auto Opcode = IID == Intrinsic::spv_selection_merge
-                            ? SPIRV::OpSelectionMerge
-                            : SPIRV::OpLoopMerge;
-    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+  case Intrinsic::spv_loop_merge: {
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
     for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
@@ -2707,6 +2704,31 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     MIB.addImm(SPIRV::SelectionControl::None);
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_selection_merge: {
+    // Operand 1 is the hint node from the structurizer (may be null/empty);
+    // operands 2.. are the merge block(s).
+    auto SelectionControl = SPIRV::SelectionControl::None;
+    const MDNode *MDOp = I.getOperand(1).getMetadata();
+    // Node layout is {!"dx.controlflow.hints", i32 Hint}: require operand 1.
+    if (MDOp && MDOp->getNumOperands() > 1) {
+      auto *BranchHint =
+          mdconst::dyn_extract_or_null<ConstantInt>(MDOp->getOperand(1));
+      // 1 == [branch], 2 == [flatten]. Other values may come from arbitrary
+      // input IR, so they are ignored instead of treated as unreachable.
+      if (BranchHint && BranchHint->equalsInt(2))
+        SelectionControl = SPIRV::SelectionControl::Flatten;
+      else if (BranchHint && BranchHint->equalsInt(1))
+        SelectionControl = SPIRV::SelectionControl::DontFlatten;
+    }
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
+    for (unsigned i = 2; i < I.getNumExplicitOperands(); ++i) {
+      assert(I.getOperand(i).isMBB());
+      MIB.addMBB(I.getOperand(i).getMBB());
+    }
+    MIB.addImm(SelectionControl);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
index 3800aac70df327..7d5617919df87c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVMetadata.h"
+#include "llvm/IR/Metadata.h"
 
 using namespace llvm;
 
@@ -81,5 +82,4 @@ MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
       "Kernel attributes are attached/belong only to OpenCL kernel functions");
   return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
 }
-
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index bff8cc1cfc5821..6fed0b81bf61e7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -13,6 +13,7 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
+#include "llvm-c/Core.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -23,6 +24,8 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
@@ -95,7 +98,7 @@ BasicBlock *getDesignatedMergeBlock(Instruction *I) {
       II->getIntrinsicID() != Intrinsic::spv_selection_merge)
     return nullptr;
 
-  BlockAddress *BA = cast<BlockAddress>(II->getOperand(0));
+  BlockAddress *BA = cast<BlockAddress>(II->getOperand(1));
   return BA->getBasicBlock();
 }
 
@@ -644,8 +647,13 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      // [jderezende] TODO: Pass metadata from Header->getTerminator() to modify the intrinsic
+
+      // Attach the terminator's [branch]/[flatten] hint (if any) to the merge.
+      MDNode *Hint = getDxBranchHint(*Header->getTerminator());
+      if (!Hint) // No hint: pass an empty node, never a null metadata operand.
+        Hint = MDNode::get(Builder.getContext(), {});
+      Value *MDNodeValue = MetadataAsValue::get(Builder.getContext(), Hint);
+      SmallVector<Value *, 2> Args = {MDNodeValue, MergeAddress};
       Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
 
       Modified = true;

>From 9beb5f17bfefcea148c1954b1b7bd458c4372389 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 15 Nov 2024 00:08:26 +0000
Subject: [PATCH 04/12] adding tests

---
 clang/lib/CodeGen/CGStmt.cpp                  |  1 -
 clang/lib/CodeGen/CodeGenFunction.cpp         |  3 -
 clang/lib/CodeGen/CodeGenFunction.h           |  1 -
 clang/lib/CodeGen/CodeGenPGO.cpp              |  4 -
 clang/test/AST/HLSL/HLSLBranchHint.hlsl       | 43 +++++++++
 clang/test/CodeGenHLSL/HLSLBranchHint.hlsl    | 48 ++++++++++
 llvm/include/llvm/IR/FixedMetadataKinds.def   |  2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  1 -
 llvm/lib/Target/SPIRV/SPIRVMetadata.cpp       |  2 +-
 llvm/test/CodeGen/DirectX/HLSLBranchHint.ll   | 95 +++++++++++++++++++
 llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll     | 91 ++++++++++++++++++
 11 files changed, 279 insertions(+), 12 deletions(-)
 create mode 100644 clang/test/AST/HLSL/HLSLBranchHint.hlsl
 create mode 100644 clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll

diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 2f95ceb297f7d8..bf82e50c8b59e0 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -16,7 +16,6 @@
 #include "CodeGenModule.h"
 #include "TargetInfo.h"
 #include "clang/AST/Attr.h"
-#include "clang/AST/Attrs.inc"
 #include "clang/AST/Expr.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index a799e017f46c5e..a790d1fde0e104 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -24,7 +24,6 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/Attr.h"
-#include "clang/AST/Attrs.inc"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
@@ -44,7 +43,6 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/CRC.h"
 #include "llvm/Support/xxhash.h"
@@ -2089,7 +2087,6 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     break;
   }
 
-  // [jderezende] TODO: Emit branch metadata marking it as flatten/branch, if exists.
   Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable,
                        ControlFlowHint);
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index f1001dc21e5e5d..d3f1af6f5a5ce9 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -21,7 +21,6 @@
 #include "CodeGenPGO.h"
 #include "EHScopeStack.h"
 #include "VarBypassDetector.h"
-#include "clang/AST/Attrs.inc"
 #include "clang/AST/CharUnits.h"
 #include "clang/AST/CurrentSourceLocExprScope.h"
 #include "clang/AST/ExprCXX.h"
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp
index d8d72f91f28708..ff3826ca6da854 100644
--- a/clang/lib/CodeGen/CodeGenPGO.cpp
+++ b/clang/lib/CodeGen/CodeGenPGO.cpp
@@ -13,14 +13,10 @@
 #include "CodeGenPGO.h"
 #include "CodeGenFunction.h"
 #include "CoverageMappingGen.h"
-#include "clang/AST/ASTContext.h"
-#include "clang/AST/Attrs.inc"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtVisitor.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/FileSystem.h"
diff --git a/clang/test/AST/HLSL/HLSLBranchHint.hlsl b/clang/test/AST/HLSL/HLSLBranchHint.hlsl
new file mode 100644
index 00000000000000..907d6c5ff580c0
--- /dev/null
+++ b/clang/test/AST/HLSL/HLSLBranchHint.hlsl
@@ -0,0 +1,43 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+export int branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
+// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
+// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+export int flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
+// CHECK-NOT: AttributedStmt
+// CHECK-NOT: HLSLBranchHintAttr
+export int no_attr(int X){
+    int resp;
+    if (X > 0) { // no hint attribute on this 'if'
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
diff --git a/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl b/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
new file mode 100644
index 00000000000000..11b14b1ec43175
--- /dev/null
+++ b/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
+
+// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+export int test_branch(int X){
+    int resp;
+    [branch] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
+// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_FLATTEN:![0-9]+]]
+export int test_flatten(int X){
+    int resp;
+    [flatten] if (X > 0) {
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
+// CHECK-NOT: !dx.controlflow.hints
+export int test_no_attr(int X){
+    int resp;
+    if (X > 0) { // no hint attribute: the branch must carry no hint metadata
+        resp = -X;
+    } else {
+        resp = X * 2;
+    }
+
+    return resp;
+}
+
+//CHECK: [[HINT_BRANCH]] = !{!"dx.controlflow.hints", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"dx.controlflow.hints", i32 2}
diff --git a/llvm/include/llvm/IR/FixedMetadataKinds.def b/llvm/include/llvm/IR/FixedMetadataKinds.def
index 02a986d42f1933..a239a076d131f7 100644
--- a/llvm/include/llvm/IR/FixedMetadataKinds.def
+++ b/llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -53,6 +53,6 @@ LLVM_FIXED_MD_KIND(MD_DIAssignID, "DIAssignID", 38)
 LLVM_FIXED_MD_KIND(MD_coro_outside_frame, "coro.outside.frame", 39)
 LLVM_FIXED_MD_KIND(MD_mmra, "mmra", 40)
 LLVM_FIXED_MD_KIND(MD_noalias_addrspace, "noalias.addrspace", 41)
-// [jderezende] TODO: this will likelly be placed somewhere else,
+// TODO: this will likely be placed somewhere else,
 // so we don't mix dxil/hlsl/spirv and clang metadata
 LLVM_FIXED_MD_KIND(MD_dxil_controlflow_hints, "dx.controlflow.hints", 42)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f3f78fb4f2bdbb..4903185309328e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -32,7 +32,6 @@
 #include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "spirv-isel"
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
index 7d5617919df87c..3800aac70df327 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVMetadata.h"
-#include "llvm/IR/Metadata.h"
 
 using namespace llvm;
 
@@ -82,4 +81,5 @@ MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
       "Kernel attributes are attached/belong only to OpenCL kernel functions");
   return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
 }
+
 } // namespace llvm
diff --git a/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll b/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
new file mode 100644
index 00000000000000..e7128d19283e80
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
@@ -0,0 +1,95 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; This test make sure LLVM metadata is propagating to DXIL.
+
+
+; CHECK: define i32 @test_branch(i32 %X)
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+define i32 @test_branch(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+
+; CHECK: define i32 @test_flatten(i32 %X)
+; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_FLATTEN:![0-9]+]]
+define i32 @test_flatten(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+
+; CHECK: define i32 @test_no_attr(i32 %X)
+; CHECK-NO: !dx.controlflow.hints
+define i32 @test_no_attr(i32 %X) {
+entry:
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+; CHECK: [[HINT_BRANCH]] = !{!"dx.controlflow.hints", i32 1}
+; CHECK: [[HINT_FLATTEN]] = !{!"dx.controlflow.hints", i32 2}
+!0 = !{!"dx.controlflow.hints", i32 1}
+!1 = !{!"dx.controlflow.hints", i32 2}
diff --git a/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll b/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
new file mode 100644
index 00000000000000..771f5603cf526f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
@@ -0,0 +1,91 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+
+define spir_func noundef i32 @test_branch(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_branch
+; OpSelectionMerge %[[#]] DontFlatten
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+
+define spir_func noundef i32 @test_flatten(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_flatten
+; OpSelectionMerge %[[#]] Flatten
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+define spir_func noundef i32 @test_no_attr(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_no_attr
+; OpSelectionMerge %[[#]] None
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+!0 = !{!"dx.controlflow.hints", i32 1}
+!1 = !{!"dx.controlflow.hints", i32 2}

>From 4c33ffb6a243a278583e79eb0694a5cd8a016e6a Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 15 Nov 2024 18:48:14 +0000
Subject: [PATCH 05/12] fixing spirv failures

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  5 +--
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp   | 35 ++++++++++---------
 .../{ => structurizer}/HLSLBranchHint.ll      |  0
 4 files changed, 22 insertions(+), 20 deletions(-)
 rename llvm/test/CodeGen/SPIRV/{ => structurizer}/HLSLBranchHint.ll (100%)

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 4a5f744c74d9be..f29eb7ee22b2d2 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_metadata_ty, llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 4903185309328e..5576f46cf89f9a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2706,7 +2706,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_selection_merge: {
 
     auto SelectionControl = SPIRV::SelectionControl::None;
-    const MDNode *MDOp = I.getOperand(1).getMetadata();
+    const MDNode *MDOp =
+        I.getOperand(I.getNumExplicitOperands() - 1).getMetadata();
     if (MDOp->getNumOperands() > 0) {
       ConstantInt *BranchHint =
           mdconst::extract<ConstantInt>(MDOp->getOperand(1));
@@ -2721,7 +2722,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
 
     auto MIB =
         BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
-    for (unsigned i = 2; i < I.getNumExplicitOperands(); ++i) {
+    for (unsigned i = 1; i < I.getNumExplicitOperands() - 1; ++i) {
       assert(I.getOperand(i).isMBB());
       MIB.addMBB(I.getOperand(i).getMBB());
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 6fed0b81bf61e7..1d837296ddc647 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -13,18 +13,17 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
-#include "llvm-c/Core.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -98,7 +97,7 @@ BasicBlock *getDesignatedMergeBlock(Instruction *I) {
       II->getIntrinsicID() != Intrinsic::spv_selection_merge)
     return nullptr;
 
-  BlockAddress *BA = cast<BlockAddress>(II->getOperand(1));
+  BlockAddress *BA = cast<BlockAddress>(II->getOperand(0));
   return BA->getBasicBlock();
 }
 
@@ -647,14 +646,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
-
-      MDNode *BranchMdNode = getDxBranchHint(*Header->getTerminator());
-      Value *MDNodeValue =
-          MetadataAsValue::get(Builder.getContext(), BranchMdNode);
-
-      SmallVector<Value *, 2> Args = {MDNodeValue, MergeAddress};
-
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
 
       Modified = true;
     }
@@ -776,10 +768,9 @@ class SPIRVStructurizer : public FunctionPass {
       BasicBlock *Merge = Candidates[0];
 
       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
       IRBuilder<> Builder(&BB);
       Builder.SetInsertPoint(BB.getTerminator());
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1112,8 +1103,7 @@ class SPIRVStructurizer : public FunctionPass {
         Builder.SetInsertPoint(Header->getTerminator());
 
         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
-        SmallVector<Value *, 1> Args = {MergeAddress};
-        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        createOpSelectMerge(&Builder, MergeAddress);
         continue;
       }
 
@@ -1127,8 +1117,7 @@ class SPIRVStructurizer : public FunctionPass {
       Builder.SetInsertPoint(Header->getTerminator());
 
       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
-      SmallVector<Value *, 1> Args = {MergeAddress};
-      Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+      createOpSelectMerge(&Builder, MergeAddress);
     }
 
     return Modified;
@@ -1215,6 +1204,18 @@ class SPIRVStructurizer : public FunctionPass {
     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
+
+  void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
+    Instruction *BBTerminatoInst = Builder->GetInsertBlock()->getTerminator();
+
+    MDNode *BranchMdNode = getDxBranchHint(*BBTerminatoInst);
+    Value *MDNodeValue =
+        MetadataAsValue::get(Builder->getContext(), BranchMdNode);
+
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, MDNodeValue};
+
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+  }
 };
 } // namespace llvm
 
diff --git a/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll b/llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll
similarity index 100%
rename from llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
rename to llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll

>From 6d01f396b71282b3be72633d25caa8dda10dc353 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 15 Nov 2024 19:29:29 +0000
Subject: [PATCH 06/12] removing headers

---
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 1d837296ddc647..57285869b2e338 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -17,7 +17,6 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
-#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"

>From a0d11883d2f506a89323a5804bc44bf3fc130bb5 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 15 Nov 2024 19:37:47 +0000
Subject: [PATCH 07/12] fixing format

---
 clang/lib/Sema/SemaStmtAttr.cpp           |  2 +-
 llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll | 91 +++++++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll

diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 68323092cb564d..827a674ece4bc0 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -624,7 +624,7 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
 }
 
 static Attr *handleHLSLBranchHint(Sema &S, Stmt *St, const ParsedAttr &A,
-                                    SourceRange Range) {
+                                  SourceRange Range) {
 
   return ::new (S.Context) HLSLBranchHintAttr(S.Context, A);
 }
diff --git a/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll b/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
new file mode 100644
index 00000000000000..771f5603cf526f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
@@ -0,0 +1,91 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+
+define spir_func noundef i32 @test_branch(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_branch
+; OpSelectionMerge %[[#]] DontFlatten
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+
+define spir_func noundef i32 @test_flatten(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_flatten
+; OpSelectionMerge %[[#]] Flatten
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+define spir_func noundef i32 @test_no_attr(i32 noundef %X) {
+entry:
+; CHECK-LABEL: ; -- Begin function test_no_attr
+; OpSelectionMerge %[[#]] None
+  %X.addr = alloca i32, align 4
+  %resp = alloca i32, align 4
+  store i32 %X, ptr %X.addr, align 4
+  %0 = load i32, ptr %X.addr, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %1 = load i32, ptr %X.addr, align 4
+  %sub = sub nsw i32 0, %1
+  store i32 %sub, ptr %resp, align 4
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %2 = load i32, ptr %X.addr, align 4
+  %mul = mul nsw i32 %2, 2
+  store i32 %mul, ptr %resp, align 4
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %3 = load i32, ptr %resp, align 4
+  ret i32 %3
+}
+
+!0 = !{!"dx.controlflow.hints", i32 1}
+!1 = !{!"dx.controlflow.hints", i32 2}

>From f0ab2617ebf2a46ea7f46a6e2695d0722a516a28 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 21 Nov 2024 01:35:43 +0000
Subject: [PATCH 08/12] removing metadata from IR Files

---
 clang/include/clang/Basic/Attr.td             |  2 +-
 clang/lib/CodeGen/CGStmt.cpp                  | 11 +++----
 clang/lib/CodeGen/CodeGenFunction.cpp         | 32 ++++++++++++-------
 clang/lib/CodeGen/CodeGenFunction.h           |  6 ++--
 clang/lib/CodeGen/CodeGenPGO.cpp              | 13 --------
 clang/lib/Sema/SemaStmtAttr.cpp               | 10 +++---
 ...anchHint.hlsl => HLSLControlFlowHint.hlsl} |  6 ++--
 ...anchHint.hlsl => HLSLControlFlowHint.hlsl} | 10 +++---
 llvm/include/llvm/IR/FixedMetadataKinds.def   |  3 --
 llvm/include/llvm/IR/IRBuilder.h              | 11 ++-----
 llvm/include/llvm/IR/ProfDataUtils.h          |  7 ----
 llvm/lib/IR/ProfDataUtils.cpp                 |  7 ----
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 28 ++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp   | 12 ++++---
 ...SLBranchHint.ll => HLSLControlFlowHint.ll} | 17 ++++++----
 ...SLBranchHint.ll => HLSLControlFlowHint.ll} |  8 ++---
 .../HLSLControlFlowHint.ll}                   |  0
 17 files changed, 94 insertions(+), 89 deletions(-)
 rename clang/test/AST/HLSL/{HLSLBranchHint.hlsl => HLSLControlFlowHint.hlsl} (83%)
 rename clang/test/CodeGenHLSL/{HLSLBranchHint.hlsl => HLSLControlFlowHint.hlsl} (76%)
 rename llvm/test/CodeGen/DirectX/{HLSLBranchHint.ll => HLSLControlFlowHint.ll} (83%)
 rename llvm/test/CodeGen/SPIRV/{structurizer/HLSLBranchHint.ll => HLSLControlFlowHint.ll} (92%)
 rename llvm/test/CodeGen/SPIRV/{HLSLBranchHint.ll => structurizer/HLSLControlFlowHint.ll} (100%)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 6d3e07ce83100b..18cc8e2e38e1b0 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4302,7 +4302,7 @@ def HLSLLoopHint: StmtAttr {
   let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
 }
 
-def HLSLBranchHint: StmtAttr {
+def HLSLControlFlowHint: StmtAttr {
   /// [branch]
   /// [flatten]
   let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index bf82e50c8b59e0..f438bec756e47a 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -730,8 +730,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   bool noinline = false;
   bool alwaysinline = false;
   bool noconvergent = false;
-  HLSLBranchHintAttr::Spelling flattenOrBranch =
-      HLSLBranchHintAttr::SpellingNotCalculated;
+  HLSLControlFlowHintAttr::Spelling flattenOrBranch =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
   const CallExpr *musttail = nullptr;
 
   for (const auto *A : S.getAttrs()) {
@@ -763,8 +763,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
         Builder.CreateAssumption(AssumptionVal);
       }
     } break;
-    case attr::HLSLBranchHint: {
-      flattenOrBranch = cast<HLSLBranchHintAttr>(A)->getSemanticSpelling();
+    case attr::HLSLControlFlowHint: {
+      flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
     } break;
     }
   }
@@ -773,8 +773,7 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
   SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
   SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
   SaveAndRestore save_musttail(MustTailCall, musttail);
-  SaveAndRestore save_flattenOrBranch(HLSLBranchHintAttributedSpelling,
-                                      flattenOrBranch);
+  SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
   EmitStmt(S.getSubStmt(), S.getAttrs());
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index a790d1fde0e104..ab6659b89c597e 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -43,6 +43,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/CRC.h"
 #include "llvm/Support/xxhash.h"
@@ -2052,7 +2053,6 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
 
   llvm::MDNode *Weights = nullptr;
   llvm::MDNode *Unpredictable = nullptr;
-  llvm::MDNode *ControlFlowHint = nullptr;
 
   // If the branch has a condition wrapped by __builtin_unpredictable,
   // create metadata that specifies that the branch is unpredictable.
@@ -2077,18 +2077,28 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
     Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
   }
 
-  switch (HLSLBranchHintAttributedSpelling) {
-
-  case HLSLBranchHintAttr::Microsoft_branch:
-  case HLSLBranchHintAttr::Microsoft_flatten:
-    ControlFlowHint = createControlFlowHint(HLSLBranchHintAttributedSpelling);
-    break;
-  case HLSLBranchHintAttr::SpellingNotCalculated:
+  auto *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights,
+                                      Unpredictable);
+  switch (HLSLControlFlowAttr) {
+  case HLSLControlFlowHintAttr::Microsoft_branch:
+  case HLSLControlFlowHintAttr::Microsoft_flatten: {
+    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
+
+    llvm::ConstantInt *BranchHintConstant =
+        HLSLControlFlowAttr ==
+                HLSLControlFlowHintAttr::Spelling::Microsoft_branch
+            ? llvm::ConstantInt::get(CGM.Int32Ty, 1)
+            : llvm::ConstantInt::get(CGM.Int32Ty, 2);
+
+    SmallVector<llvm::Metadata *, 2> Vals(
+        {MDHelper.createString("hlsl.controlflow.hint"),
+         MDHelper.createConstant(BranchHintConstant)});
+    BrInst->setMetadata("hlsl.controlflow.hint",
+                        llvm::MDNode::get(CGM.getLLVMContext(), Vals));
+  } break;
+  case HLSLControlFlowHintAttr::SpellingNotCalculated:
     break;
   }
-
-  Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable,
-                       ControlFlowHint);
 }
 
 /// ErrorUnsupported - Print out an error that codegen doesn't support the
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index d3f1af6f5a5ce9..be280ce887e3f8 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -616,8 +616,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   bool InNoConvergentAttributedStmt = false;
 
   /// HLSL Branch attribute.
-  HLSLBranchHintAttr::Spelling HLSLBranchHintAttributedSpelling =
-      HLSLBranchHintAttr::SpellingNotCalculated;
+  HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
+      HLSLControlFlowHintAttr::SpellingNotCalculated;
 
   // The CallExpr within the current statement that the musttail attribute
   // applies to.  nullptr if there is no 'musttail' on the current statement.
@@ -1616,8 +1616,6 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Bitmap used by MC/DC to track condition outcomes of a boolean expression.
   Address MCDCCondBitmapAddr = Address::invalid();
 
-  llvm::MDNode *createControlFlowHint(HLSLBranchHintAttr::Spelling S) const;
-
   /// Calculate branch weights appropriate for PGO data
   llvm::MDNode *createProfileWeights(uint64_t TrueCount,
                                      uint64_t FalseCount) const;
diff --git a/clang/lib/CodeGen/CodeGenPGO.cpp b/clang/lib/CodeGen/CodeGenPGO.cpp
index ff3826ca6da854..820bb521ccf850 100644
--- a/clang/lib/CodeGen/CodeGenPGO.cpp
+++ b/clang/lib/CodeGen/CodeGenPGO.cpp
@@ -1451,19 +1451,6 @@ static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
   return Scaled;
 }
 
-llvm::MDNode *
-CodeGenFunction::createControlFlowHint(HLSLBranchHintAttr::Spelling S) const {
-  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
-
-  SmallVector<llvm::Metadata *, 2> Vals(llvm::ArrayRef<llvm::Metadata *>{
-      MDHelper.createString("dx.controlflow.hints"),
-      S == HLSLBranchHintAttr::Spelling::Microsoft_branch
-          ? MDHelper.createConstant(llvm::ConstantInt::get(Int32Ty, 1))
-          : MDHelper.createConstant(llvm::ConstantInt::get(Int32Ty, 2))});
-
-  return llvm::MDNode::get(CGM.getLLVMContext(), Vals);
-}
-
 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                     uint64_t FalseCount) const {
   // Check for empty weights.
diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp
index 827a674ece4bc0..824f1ea3ccd65e 100644
--- a/clang/lib/Sema/SemaStmtAttr.cpp
+++ b/clang/lib/Sema/SemaStmtAttr.cpp
@@ -623,10 +623,10 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
   return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
 }
 
-static Attr *handleHLSLBranchHint(Sema &S, Stmt *St, const ParsedAttr &A,
-                                  SourceRange Range) {
+static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
+                                       SourceRange Range) {
 
-  return ::new (S.Context) HLSLBranchHintAttr(S.Context, A);
+  return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
 }
 
 static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
@@ -665,8 +665,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
     return handleLoopHintAttr(S, St, A, Range);
   case ParsedAttr::AT_HLSLLoopHint:
     return handleHLSLLoopHintAttr(S, St, A, Range);
-  case ParsedAttr::AT_HLSLBranchHint:
-    return handleHLSLBranchHint(S, St, A, Range);
+  case ParsedAttr::AT_HLSLControlFlowHint:
+    return handleHLSLControlFlowHint(S, St, A, Range);
   case ParsedAttr::AT_OpenCLUnrollHint:
     return handleOpenCLUnrollHint(S, St, A, Range);
   case ParsedAttr::AT_Suppress:
diff --git a/clang/test/AST/HLSL/HLSLBranchHint.hlsl b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
similarity index 83%
rename from clang/test/AST/HLSL/HLSLBranchHint.hlsl
rename to clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
index 907d6c5ff580c0..754d9de57d83bb 100644
--- a/clang/test/AST/HLSL/HLSLBranchHint.hlsl
+++ b/clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
@@ -2,7 +2,7 @@
 
 // CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
 // CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
-// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
 export int branch(int X){
     int resp;
     [branch] if (X > 0) {
@@ -16,7 +16,7 @@ export int branch(int X){
 
 // CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
 // CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
-// CHECK-NEXT: HLSLBranchHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
+// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
 export int flatten(int X){
     int resp;
     [flatten] if (X > 0) {
@@ -30,7 +30,7 @@ export int flatten(int X){
 
 // CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
 // CHECK-NO: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
-// CHECK-NO: HLSLBranchHintAttr
+// CHECK-NO: -HLSLControlFlowHintAttr
 export int no_attr(int X){
     int resp;
     if (X > 0) {
diff --git a/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
similarity index 76%
rename from clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
rename to clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
index 11b14b1ec43175..4f9982fabb84f7 100644
--- a/clang/test/CodeGenHLSL/HLSLBranchHint.hlsl
+++ b/clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
@@ -4,7 +4,7 @@
 // CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
 // CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
 // CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
-// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
 export int test_branch(int X){
     int resp;
     [branch] if (X > 0) {
@@ -19,7 +19,7 @@ export int test_branch(int X){
 // CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
 // CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
 // CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
-// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !dx.controlflow.hints [[HINT_FLATTEN:![0-9]+]]
+// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
 export int test_flatten(int X){
     int resp;
     [flatten] if (X > 0) {
@@ -32,7 +32,7 @@ export int test_flatten(int X){
 }
 
 // CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
-// CHECK-NO: !dx.controlflow.hints
+// CHECK-NO: !hlsl.controlflow.hint
 export int test_no_attr(int X){
     int resp;
     if (X > 0) {
@@ -44,5 +44,5 @@ export int test_no_attr(int X){
     return resp;
 }
 
-//CHECK: [[HINT_BRANCH]] = !{!"dx.controlflow.hints", i32 1}
-//CHECK: [[HINT_FLATTEN]] = !{!"dx.controlflow.hints", i32 2}
+//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
+//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/include/llvm/IR/FixedMetadataKinds.def b/llvm/include/llvm/IR/FixedMetadataKinds.def
index a239a076d131f7..df572e8791e13b 100644
--- a/llvm/include/llvm/IR/FixedMetadataKinds.def
+++ b/llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -53,6 +53,3 @@ LLVM_FIXED_MD_KIND(MD_DIAssignID, "DIAssignID", 38)
 LLVM_FIXED_MD_KIND(MD_coro_outside_frame, "coro.outside.frame", 39)
 LLVM_FIXED_MD_KIND(MD_mmra, "mmra", 40)
 LLVM_FIXED_MD_KIND(MD_noalias_addrspace, "noalias.addrspace", 41)
-// TODO: this will likelly be placed somewhere else,
-// so we don't mix dxil/hlsl/spirv and clang metadata
-LLVM_FIXED_MD_KIND(MD_dxil_controlflow_hints, "dx.controlflow.hints", 42)
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 215d3362bd4f6c..23fd8350a29b3d 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1101,14 +1101,11 @@ class IRBuilderBase {
   /// instruction.
   /// \returns The annotated instruction.
   template <typename InstTy>
-  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable,
-                            MDNode *ControlFlowHint = nullptr) {
+  InstTy *addBranchMetadata(InstTy *I, MDNode *Weights, MDNode *Unpredictable) {
     if (Weights)
       I->setMetadata(LLVMContext::MD_prof, Weights);
     if (Unpredictable)
       I->setMetadata(LLVMContext::MD_unpredictable, Unpredictable);
-    if (ControlFlowHint)
-      I->setMetadata(LLVMContext::MD_dxil_controlflow_hints, ControlFlowHint);
     return I;
   }
 
@@ -1146,11 +1143,9 @@ class IRBuilderBase {
   /// instruction.
   BranchInst *CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False,
                            MDNode *BranchWeights = nullptr,
-                           MDNode *Unpredictable = nullptr,
-                           MDNode *ControlFlowHint = nullptr) {
+                           MDNode *Unpredictable = nullptr) {
     return Insert(addBranchMetadata(BranchInst::Create(True, False, Cond),
-                                    BranchWeights, Unpredictable,
-                                    ControlFlowHint));
+                                    BranchWeights, Unpredictable));
   }
 
   /// Create a conditional 'br Cond, TrueDest, FalseDest'
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 49f7da480f928d..0bea517df832e3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -48,13 +48,6 @@ bool hasValidBranchWeightMD(const Instruction &I);
 /// Nullptr otherwise.
 MDNode *getBranchWeightMDNode(const Instruction &I);
 
-/// Get the branching metadata information
-///
-/// \param I The Instruction to get the weights from.
-/// \returns A pointer to I's branch weights metadata node, if it exists.
-/// Nullptr otherwise.
-MDNode *getDxBranchHint(const Instruction &I);
-
 /// Get the valid branch weights metadata node
 ///
 /// \param I The Instruction to get the weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 47a5059017a48f..5441228b3291ee 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -150,13 +150,6 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
   return ProfileData;
 }
 
-MDNode *getDxBranchHint(const Instruction &I) {
-  MDNode *Node = I.getMetadata(LLVMContext::MD_dxil_controlflow_hints);
-  if (!isTargetMD(Node, "dx.controlflow.hints", 2))
-    return nullptr;
-  return Node;
-}
-
 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   auto *ProfileData = getBranchWeightMDNode(I);
   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index be370e10df6943..4c09dcf56c99bf 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -11,6 +11,7 @@
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -21,6 +22,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
@@ -355,6 +357,32 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
       M.getOrInsertNamedMetadata("dx.entryPoints");
   for (auto *Entry : EntryFnMDNodes)
     EntryPointsNamedMD->addOperand(Entry);
+
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      auto *BBTerminatorInst = BB.getTerminator();
+
+      auto *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD || HlslControlFlowMD->getNumOperands() < 2)
+        continue;
+
+      MDBuilder MDHelper(M.getContext());
+      auto *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      auto *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+    F.clearMetadata();
+  }
 }
 
 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 57285869b2e338..5028e4ff51222b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -23,8 +23,8 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
-#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -1205,11 +1205,13 @@ class SPIRVStructurizer : public FunctionPass {
   }
 
   void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
-    Instruction *BBTerminatoInst = Builder->GetInsertBlock()->getTerminator();
+    Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
 
-    MDNode *BranchMdNode = getDxBranchHint(*BBTerminatoInst);
-    Value *MDNodeValue =
-        MetadataAsValue::get(Builder->getContext(), BranchMdNode);
+    MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+    if (MDNode && MDNode->getNumOperands() != 2)
+      llvm_unreachable("invalid metadata hlsl.controlflow.hint");
+
+    Value *MDNodeValue = MetadataAsValue::get(Builder->getContext(), MDNode);
 
     llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, MDNodeValue};
 
diff --git a/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
similarity index 83%
rename from llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
rename to llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
index e7128d19283e80..fe66e481359bb7 100644
--- a/llvm/test/CodeGen/DirectX/HLSLBranchHint.ll
+++ b/llvm/test/CodeGen/DirectX/HLSLControlFlowHint.ll
@@ -1,9 +1,10 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -dxil-translate-metadata -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-; This test make sure LLVM metadata is propagating to DXIL.
+; This test makes sure LLVM metadata is being translated into DXIL.
 
 
 ; CHECK: define i32 @test_branch(i32 %X)
+; CHECK-NO: hlsl.controlflow.hint
 ; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_BRANCH:![0-9]+]]
 define i32 @test_branch(i32 %X) {
 entry:
@@ -12,7 +13,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !0
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -33,6 +34,7 @@ if.end:                                           ; preds = %if.else, %if.then
 
 
 ; CHECK: define i32 @test_flatten(i32 %X)
+; CHECK-NO: hlsl.controlflow.hint
 ; CHECK: br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints [[HINT_FLATTEN:![0-9]+]]
 define i32 @test_flatten(i32 %X) {
 entry:
@@ -41,7 +43,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !1
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -62,6 +64,7 @@ if.end:                                           ; preds = %if.else, %if.then
 
 
 ; CHECK: define i32 @test_no_attr(i32 %X)
+; CHECK-NO: hlsl.controlflow.hint
 ; CHECK-NO: !dx.controlflow.hints
 define i32 @test_no_attr(i32 %X) {
 entry:
@@ -88,8 +91,8 @@ if.end:                                           ; preds = %if.else, %if.then
   %3 = load i32, ptr %resp, align 4
   ret i32 %3
 }
-
+; CHECK-NO: hlsl.controlflow.hint
 ; CHECK: [[HINT_BRANCH]] = !{!"dx.controlflow.hints", i32 1}
 ; CHECK: [[HINT_FLATTEN]] = !{!"dx.controlflow.hints", i32 2}
-!0 = !{!"dx.controlflow.hints", i32 1}
-!1 = !{!"dx.controlflow.hints", i32 2}
+!0 = !{!"hlsl.controlflow.hint", i32 1}
+!1 = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll b/llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll
similarity index 92%
rename from llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll
rename to llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll
index 771f5603cf526f..848eaf70f5a199 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/HLSLBranchHint.ll
+++ b/llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll
@@ -11,7 +11,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !0
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -40,7 +40,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !1
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -87,5 +87,5 @@ if.end:                                           ; preds = %if.else, %if.then
   ret i32 %3
 }
 
-!0 = !{!"dx.controlflow.hints", i32 1}
-!1 = !{!"dx.controlflow.hints", i32 2}
+!0 = !{!"hlsl.controlflow.hint", i32 1}
+!1 = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll b/llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll
similarity index 100%
rename from llvm/test/CodeGen/SPIRV/HLSLBranchHint.ll
rename to llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll

>From 7213095752c06ada834c33942df420c9845826a5 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 21 Nov 2024 01:48:22 +0000
Subject: [PATCH 09/12] removing unnecessary headers

---
 clang/lib/CodeGen/CodeGenFunction.cpp             | 1 -
 llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 1 -
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp       | 1 -
 3 files changed, 3 deletions(-)

diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index ab6659b89c597e..657559f3d915eb 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -43,7 +43,6 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
-#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/CRC.h"
 #include "llvm/Support/xxhash.h"
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 4c09dcf56c99bf..abad9e1c11b22d 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -11,7 +11,6 @@
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 5028e4ff51222b..bd6352b2089adf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -24,7 +24,6 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/InitializePasses.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

>From 1ea547b037bdebcedc6a4f87de521eb1c772e9ff Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 21 Nov 2024 17:44:00 +0000
Subject: [PATCH 10/12] removing unnecessary test

---
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 58 ++++++------
 .../test/CodeGen/SPIRV/HLSLControlFlowHint.ll | 91 -------------------
 .../SPIRV/structurizer/HLSLControlFlowHint.ll |  8 +-
 3 files changed, 36 insertions(+), 121 deletions(-)
 delete mode 100644 llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index abad9e1c11b22d..b3791b52b4ca19 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -301,6 +301,36 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
 }
 
+// TODO: We might need to refactor this to be more generic,
+// in case we need more metadata to be replaced.
+static void replaceMetadata(Module &M) {
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      auto *BBTerminatorInst = BB.getTerminator();
+
+      auto *HlslControlFlowMD =
+          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
+
+      if (!HlslControlFlowMD || HlslControlFlowMD->getNumOperands() < 2)
+        continue;
+
+      MDBuilder MDHelper(M.getContext());
+      auto *Op1 =
+          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
+
+      SmallVector<llvm::Metadata *, 2> Vals(
+          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
+                               MDHelper.createConstant(Op1)});
+
+      auto *MDNode = llvm::MDNode::get(M.getContext(), Vals);
+
+      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
+      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
+    }
+    F.clearMetadata();
+  }
+}
+
 static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                               const Resources &MDResources,
                               const ComputedShaderFlags &ShaderFlags,
@@ -356,32 +386,6 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
       M.getOrInsertNamedMetadata("dx.entryPoints");
   for (auto *Entry : EntryFnMDNodes)
     EntryPointsNamedMD->addOperand(Entry);
-
-  for (auto &F : M) {
-    for (auto &BB : F) {
-      auto *BBTerminatorInst = BB.getTerminator();
-
-      auto *HlslControlFlowMD =
-          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
-
-      if (!HlslControlFlowMD || HlslControlFlowMD->getNumOperands() < 2)
-        continue;
-
-      MDBuilder MDHelper(M.getContext());
-      auto *Op1 =
-          mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
-
-      SmallVector<llvm::Metadata *, 2> Vals(
-          ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
-                               MDHelper.createConstant(Op1)});
-
-      auto *MDNode = llvm::MDNode::get(M.getContext(), Vals);
-
-      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
-      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
-    }
-    F.clearMetadata();
-  }
 }
 
 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
@@ -393,6 +397,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
+  replaceMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -426,6 +431,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
+    replaceMetadata(M);
     return true;
   }
 };
diff --git a/llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll b/llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll
deleted file mode 100644
index 848eaf70f5a199..00000000000000
--- a/llvm/test/CodeGen/SPIRV/HLSLControlFlowHint.ll
+++ /dev/null
@@ -1,91 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-
-
-define spir_func noundef i32 @test_branch(i32 noundef %X) {
-entry:
-; CHECK-LABEL: ; -- Begin function test_branch
-; OpSelectionMerge %[[#]] DontFlatten
-  %X.addr = alloca i32, align 4
-  %resp = alloca i32, align 4
-  store i32 %X, ptr %X.addr, align 4
-  %0 = load i32, ptr %X.addr, align 4
-  %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !0
-
-if.then:                                          ; preds = %entry
-  %1 = load i32, ptr %X.addr, align 4
-  %sub = sub nsw i32 0, %1
-  store i32 %sub, ptr %resp, align 4
-  br label %if.end
-
-if.else:                                          ; preds = %entry
-  %2 = load i32, ptr %X.addr, align 4
-  %mul = mul nsw i32 %2, 2
-  store i32 %mul, ptr %resp, align 4
-  br label %if.end
-
-if.end:                                           ; preds = %if.else, %if.then
-  %3 = load i32, ptr %resp, align 4
-  ret i32 %3
-}
-
-
-define spir_func noundef i32 @test_flatten(i32 noundef %X) {
-entry:
-; CHECK-LABEL: ; -- Begin function test_flatten
-; OpSelectionMerge %[[#]] Flatten
-  %X.addr = alloca i32, align 4
-  %resp = alloca i32, align 4
-  store i32 %X, ptr %X.addr, align 4
-  %0 = load i32, ptr %X.addr, align 4
-  %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !1
-
-if.then:                                          ; preds = %entry
-  %1 = load i32, ptr %X.addr, align 4
-  %sub = sub nsw i32 0, %1
-  store i32 %sub, ptr %resp, align 4
-  br label %if.end
-
-if.else:                                          ; preds = %entry
-  %2 = load i32, ptr %X.addr, align 4
-  %mul = mul nsw i32 %2, 2
-  store i32 %mul, ptr %resp, align 4
-  br label %if.end
-
-if.end:                                           ; preds = %if.else, %if.then
-  %3 = load i32, ptr %resp, align 4
-  ret i32 %3
-}
-
-define spir_func noundef i32 @test_no_attr(i32 noundef %X) {
-entry:
-; CHECK-LABEL: ; -- Begin function test_no_attr
-; OpSelectionMerge %[[#]] None
-  %X.addr = alloca i32, align 4
-  %resp = alloca i32, align 4
-  store i32 %X, ptr %X.addr, align 4
-  %0 = load i32, ptr %X.addr, align 4
-  %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else
-
-if.then:                                          ; preds = %entry
-  %1 = load i32, ptr %X.addr, align 4
-  %sub = sub nsw i32 0, %1
-  store i32 %sub, ptr %resp, align 4
-  br label %if.end
-
-if.else:                                          ; preds = %entry
-  %2 = load i32, ptr %X.addr, align 4
-  %mul = mul nsw i32 %2, 2
-  store i32 %mul, ptr %resp, align 4
-  br label %if.end
-
-if.end:                                           ; preds = %if.else, %if.then
-  %3 = load i32, ptr %resp, align 4
-  ret i32 %3
-}
-
-!0 = !{!"hlsl.controlflow.hint", i32 1}
-!1 = !{!"hlsl.controlflow.hint", i32 2}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll b/llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll
index 771f5603cf526f..848eaf70f5a199 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/HLSLControlFlowHint.ll
@@ -11,7 +11,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !0
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !0
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -40,7 +40,7 @@ entry:
   store i32 %X, ptr %X.addr, align 4
   %0 = load i32, ptr %X.addr, align 4
   %cmp = icmp sgt i32 %0, 0
-  br i1 %cmp, label %if.then, label %if.else, !dx.controlflow.hints !1
+  br i1 %cmp, label %if.then, label %if.else, !hlsl.controlflow.hint !1
 
 if.then:                                          ; preds = %entry
   %1 = load i32, ptr %X.addr, align 4
@@ -87,5 +87,5 @@ if.end:                                           ; preds = %if.else, %if.then
   ret i32 %3
 }
 
-!0 = !{!"dx.controlflow.hints", i32 1}
-!1 = !{!"dx.controlflow.hints", i32 2}
+!0 = !{!"hlsl.controlflow.hint", i32 1}
+!1 = !{!"hlsl.controlflow.hint", i32 2}

>From bcc72ce0039838d41031a09b8d68acb03f5e975e Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 25 Nov 2024 23:15:32 +0000
Subject: [PATCH 11/12] fixing tests

---
 .../Target/DirectX/DXILTranslateMetadata.cpp  |  6 ++--
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 32 ++++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp   |  5 +--
 3 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index b3791b52b4ca19..4081e139044ed1 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -303,7 +303,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
 
 // TODO: We might need to refactor this to be more generic,
 // in case we need more metadata to be replaced.
-static void replaceMetadata(Module &M) {
+static void translateBranchMetadata(Module &M) {
   for (auto &F : M) {
     for (auto &BB : F) {
       auto *BBTerminatorInst = BB.getTerminator();
@@ -397,7 +397,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
   translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
-  replaceMetadata(M);
+  translateBranchMetadata(M);
 
   return PreservedAnalyses::all();
 }
@@ -431,7 +431,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
     translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
-    replaceMetadata(M);
+    translateBranchMetadata(M);
     return true;
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5576f46cf89f9a..de24d0a916730b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2706,18 +2706,26 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_selection_merge: {
 
     auto SelectionControl = SPIRV::SelectionControl::None;
-    const MDNode *MDOp =
-        I.getOperand(I.getNumExplicitOperands() - 1).getMetadata();
-    if (MDOp->getNumOperands() > 0) {
-      ConstantInt *BranchHint =
-          mdconst::extract<ConstantInt>(MDOp->getOperand(1));
-
-      if (BranchHint->equalsInt(2))
-        SelectionControl = SPIRV::SelectionControl::Flatten;
-      else if (BranchHint->equalsInt(1))
-        SelectionControl = SPIRV::SelectionControl::DontFlatten;
-      else
-        llvm_unreachable("Invalid value for SelectionControl");
+    auto LastOp = I.getOperand(I.getNumExplicitOperands() - 1);
+
+    assert((LastOp.isMBB() || LastOp.isMetadata()) &&
+           "Invalid type for last Machine Operand");
+
+    if (LastOp.isMetadata()) {
+      const MDNode *MDOp = LastOp.getMetadata();
+      if (MDOp->getNumOperands() == 2) {
+        if (ConstantInt *BranchHint =
+                mdconst::extract<ConstantInt>(MDOp->getOperand(1))) {
+          if (BranchHint->equalsInt(2))
+            SelectionControl = SPIRV::SelectionControl::Flatten;
+          else if (BranchHint->equalsInt(1))
+            SelectionControl = SPIRV::SelectionControl::DontFlatten;
+          else
+            llvm_unreachable("Invalid value for SelectionControl");
+        } else {
+          llvm_unreachable("Invalid value for SelectionControl");
+        }
+      }
     }
 
     auto MIB =
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index bd6352b2089adf..f76f724de48649 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -1207,8 +1207,9 @@ class SPIRVStructurizer : public FunctionPass {
     Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
 
     MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
-    if (MDNode && MDNode->getNumOperands() != 2)
-      llvm_unreachable("invalid metadata hlsl.controlflow.hint");
+    if (MDNode)
+      assert(MDNode->getNumOperands() == 2 &&
+             "invalid metadata hlsl.controlflow.hint");
 
     Value *MDNodeValue = MetadataAsValue::get(Builder->getContext(), MDNode);
 

>From 9db2081895196cad386680891153614a5e4db948 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 3 Dec 2024 20:12:29 +0000
Subject: [PATCH 12/12] Addressing PR comments

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 llvm/lib/IR/Verifier.cpp                      |  1 +
 .../Target/DirectX/DXILTranslateMetadata.cpp  |  5 +++-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 30 +++++--------------
 llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp   | 14 ++++++---
 5 files changed, 24 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index f29eb7ee22b2d2..9b635cac95e5db 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
-  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 85e60452b75c3d..9c5824d9146109 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -7237,6 +7237,7 @@ struct VerifierLegacyPass : public FunctionPass {
 
   bool runOnFunction(Function &F) override {
     if (!V->verify(F) && FatalErrors) {
+      auto x = V->verify(F);
       errs() << "in function " << F.getName() << '\n';
       report_fatal_error("Broken function found, compilation aborted!");
     }
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 4081e139044ed1..93ecc05aaf6087 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -311,9 +311,12 @@ static void translateBranchMetadata(Module &M) {
       auto *HlslControlFlowMD =
           BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
 
-      if (!HlslControlFlowMD || HlslControlFlowMD->getNumOperands() < 2)
+      if (!HlslControlFlowMD)
         continue;
 
+      assert(HlslControlFlowMD->getNumOperands() == 2 &&
+             "invalid operands for hlsl.controlflow.hint");
+
       MDBuilder MDHelper(M.getContext());
       auto *Op1 =
           mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index de24d0a916730b..529a1b5f263671 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2705,28 +2705,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   }
   case Intrinsic::spv_selection_merge: {
 
-    auto SelectionControl = SPIRV::SelectionControl::None;
-    auto LastOp = I.getOperand(I.getNumExplicitOperands() - 1);
-
-    assert((LastOp.isMBB() || LastOp.isMetadata()) &&
-           "Invalid type for last Machine Operand");
-
-    if (LastOp.isMetadata()) {
-      const MDNode *MDOp = LastOp.getMetadata();
-      if (MDOp->getNumOperands() == 2) {
-        if (ConstantInt *BranchHint =
-                mdconst::extract<ConstantInt>(MDOp->getOperand(1))) {
-          if (BranchHint->equalsInt(2))
-            SelectionControl = SPIRV::SelectionControl::Flatten;
-          else if (BranchHint->equalsInt(1))
-            SelectionControl = SPIRV::SelectionControl::DontFlatten;
-          else
-            llvm_unreachable("Invalid value for SelectionControl");
-        } else {
-          llvm_unreachable("Invalid value for SelectionControl");
-        }
-      }
-    }
+    int64_t SelectionControl = SPIRV::SelectionControl::None;
+    auto LastOp = I.getOperand(I.getNumOperands() - 1);
+
+    auto BranchHint = LastOp.getImm();
+    if (BranchHint == 2)
+      SelectionControl = SPIRV::SelectionControl::Flatten;
+    else if (BranchHint == 1)
+      SelectionControl = SPIRV::SelectionControl::DontFlatten;
 
     auto MIB =
         BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index f76f724de48649..d60f1abf676783 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -1207,15 +1207,21 @@ class SPIRVStructurizer : public FunctionPass {
     Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
 
     MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
-    if (MDNode)
+
+    ConstantInt *BranchHint = llvm::ConstantInt::get(Builder->getInt32Ty(), 0);
+
+    if (MDNode) {
       assert(MDNode->getNumOperands() == 2 &&
              "invalid metadata hlsl.controlflow.hint");
+      BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
 
-    Value *MDNodeValue = MetadataAsValue::get(Builder->getContext(), MDNode);
+      assert(BranchHint && "invalid metadata value for hlsl.controlflow.hint");
+    }
 
-    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, MDNodeValue};
+    llvm::SmallVector<llvm::Value *, 2> Args = {MergeAddress, BranchHint};
 
-    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+    Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
+                             {MergeAddress->getType()}, {Args});
   }
 };
 } // namespace llvm



More information about the cfe-commits mailing list