[clang] [clang][SPIR-V] Always add convergence intrinsics (PR #88918)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue May 14 02:32:41 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/88918

>From 440cdfa4132a969702348c32f2810924012c5ea6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 15 Apr 2024 17:05:40 +0200
Subject: [PATCH 1/6] [clang][SPIR-V] Always add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR #80680 added bits in the codegen to lazily add convergence intrinsics
when required. This logic relied on the LoopStack. The issue is that
when the loop condition is emitted, the LoopStack doesn't yet reflect
the correct values, which is expected since we are not yet in the loop.

However, convergence tokens should sometimes already be available at
that point. The simplest solution seemed to be to eagerly generate the
tokens whenever we generate SPIR-V.

Fixes #88144

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  88 +------------
 clang/lib/CodeGen/CGCall.cpp                  |   3 +
 clang/lib/CodeGen/CGStmt.cpp                  |  94 ++++++++++++++
 clang/lib/CodeGen/CodeGenFunction.cpp         |   9 ++
 clang/lib/CodeGen/CodeGenFunction.h           |   9 +-
 .../builtins/RWBuffer-constructor.hlsl        |   1 -
 .../CodeGenHLSL/convergence/do.while.hlsl     |  90 +++++++++++++
 clang/test/CodeGenHLSL/convergence/for.hlsl   | 121 ++++++++++++++++++
 clang/test/CodeGenHLSL/convergence/while.hlsl | 119 +++++++++++++++++
 9 files changed, 445 insertions(+), 89 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/do.while.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/for.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/while.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f9ee93049b12d..e251091c6ce3e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1141,91 +1141,8 @@ struct BitTest {
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
 
-// Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
-  for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
-  }
-  return nullptr;
-}
-
 } // namespace
 
-llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
-  llvm::Value *bundleArgs[] = {ParentToken};
-  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
-      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
-  Input->replaceAllUsesWith(Output);
-  Input->eraseFromParent();
-  return Output;
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto CB = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_loop, {}, {});
-  Builder.restoreIP(IP);
-
-  auto I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
-
-  // Adding a convergence token requires the function to be marked as
-  // convergent.
-  F->setConvergent();
-
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_entry, {}, {});
-  assert(isa<llvm::IntrinsicInst>(I));
-  Builder.restoreIP(IP);
-
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
-  assert(LI != nullptr);
-
-  auto *token = getConvergenceToken(LI->getHeader());
-  if (token)
-    return token;
-
-  llvm::IntrinsicInst *PII =
-      LI->getParent()
-          ? emitConvergenceLoopToken(
-                LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
-          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
-
-  return emitConvergenceLoopToken(LI->getHeader(), PII);
-}
-
-llvm::CallBase *
-CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
-  llvm::Value *ParentToken =
-      LoopStack.hasInfo()
-          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
-          : getOrEmitConvergenceEntryToken(Input->getFunction());
-  return addConvergenceControlToken(Input, ParentToken);
-}
-
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -18402,12 +18319,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
-    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
-    if (getTarget().getTriple().isSPIRVLogical())
-      CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
-    return CI;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 0c7eef59db53c..17248c8078088 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4830,6 +4830,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   llvm::CallInst *call = Builder.CreateCall(
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
+
+  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 479945e3b4cb5..127a72c8cbcdd 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -978,6 +978,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
+        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+
   // Create an exit block for when the condition fails, which will
   // also become the break target.
   JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1079,6 +1083,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1098,6 +1105,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, S.getBody());
   else
     EmitBlockWithFallThrough(LoopBody, &S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+
   {
     RunCleanupsScope BodyScope(*this);
     EmitStmt(S.getBody());
@@ -1151,6 +1163,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1170,6 +1185,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1279,6 +1298,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void
@@ -1301,6 +1323,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1369,6 +1395,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3158,3 +3187,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
 
   return F;
 }
+
+namespace {
+// Returns the first convergence entry/loop/anchor instruction found in |BB|.
+// std::nullptr otherwise.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return nullptr;
+}
+
+} // namespace
+
+llvm::CallBase *
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
+                                          llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+
+  if (BB->empty())
+    Builder.SetInsertPoint(BB);
+  else
+    Builder.SetInsertPoint(&BB->front());
+
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = addConvergenceControlToken(CB, ParentToken);
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return cast<llvm::IntrinsicInst>(I);
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 9f16fcb438557..2e37a1540b6f2 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -353,6 +353,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(DeferredDeactivationCleanupStack.empty() &&
          "mismatched activate/deactivate of cleanups!");
 
+  if (getTarget().getTriple().isSPIRVLogical()) {
+    ConvergenceTokenStack.pop_back();
+    assert(ConvergenceTokenStack.empty() &&
+           "mismatched push/pop in convergence stack!");
+  }
+
   bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
     && NumSimpleReturnExprs == NumReturnExprs
     && ReturnBlock.getBlock()->use_empty();
@@ -1277,6 +1283,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (CurFuncDecl)
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
 void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e1e687af6a781..c642c4c9e7497 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -315,6 +315,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Stack to track the Logical Operator recursion nest for MC/DC.
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
+  /// Stack to track the controlled convergence tokens.
+  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
   int ExpectedOMPLoopDepth = 0;
@@ -5076,7 +5079,11 @@ class CodeGenFunction : public CodeGenTypeCache {
                                      const llvm::Twine &Name = "");
   // Adds a convergence_ctrl token to |Input| and emits the required parent
   // convergence instructions.
-  llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
+  template <typename CallType>
+  CallType *addControlledConvergenceToken(CallType *Input) {
+    return dyn_cast<CallType>(
+        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
+  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 74b3f59bf7600..e51eac7f57c2d 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,3 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
 
 RWBuffer<float> Buf;
diff --git a/clang/test/CodeGenHLSL/convergence/do.while.hlsl b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
new file mode 100644
index 0000000000000..ea5a45ba8fd78
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  do {
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  do {
+    foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  do {
+    if (cond())
+      foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  do {
+    if (cond()) {
+      foo();
+      break;
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  do {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: while.cond:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/for.hlsl b/clang/test/CodeGenHLSL/convergence/for.hlsl
new file mode 100644
index 0000000000000..180fae74ba751
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/for.hlsl
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+bool cond2();
+void foo();
+
+void test1() {
+  for (;;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  for (;cond();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  for (cond();;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  for (cond();cond2();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  for (cond();cond2();foo()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test6() {
+  for (cond();cond2();foo()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:   [[C1:%[a-zA-Z0-9]+]] = call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br i1 [[C1]], label %if.then, label %if.end
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %for.end
+// CHECK: if.end:
+// CHECK:   br label %for.inc
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test7() {
+  for (cond();;) {
+    for (cond();;) {
+      foo();
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test7v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.cond3:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T2]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/while.hlsl b/clang/test/CodeGenHLSL/convergence/while.hlsl
new file mode 100644
index 0000000000000..92777000190d2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/while.hlsl
@@ -0,0 +1,119 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  while (cond()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  while (cond()) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.body:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  while (cond()) {
+    if (cond())
+      break;
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.cond
+
+void test4() {
+  while (cond()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   br label %while.cond
+
+void test5() {
+  while (cond()) {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK:   br label %while.end
+
+void test6() {
+  while (cond()) {
+    while (cond()) {
+    }
+
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }

>From 221f2abd9782cf09085840b6657a24c62673b5fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 16 Apr 2024 18:40:09 +0200
Subject: [PATCH 2/6] feedback review

---
 clang/lib/CodeGen/CGCall.cpp        | 2 +-
 clang/lib/CodeGen/CodeGenFunction.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 17248c8078088..cd1d2b32ead44 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4832,7 +4832,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   call->setCallingConv(getRuntimeCC());
 
   if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
-    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
+    return addControlledConvergenceToken(call);
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index c642c4c9e7497..362f4a5fe72a6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -5081,7 +5081,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   // convergence instructions.
   template <typename CallType>
   CallType *addControlledConvergenceToken(CallType *Input) {
-    return dyn_cast<CallType>(
+    return cast<CallType>(
         addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
   }
 

>From 0efa77215fd0d21c01488a4907fc912fd64e97ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 17 Apr 2024 14:12:19 +0200
Subject: [PATCH 3/6] review feedback target check
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGCall.cpp          |  4 ++--
 clang/lib/CodeGen/CGStmt.cpp          | 16 ++++++++--------
 clang/lib/CodeGen/CodeGenFunction.cpp |  4 ++--
 clang/lib/CodeGen/CodeGenModule.h     |  6 ++++++
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index cd1d2b32ead44..1b4ca2a8b2fe8 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4831,7 +4831,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
 
-  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
     return addControlledConvergenceToken(call);
   return call;
 }
@@ -5733,7 +5733,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   if (!CI->getType()->isVoidTy())
     CI->setName("call");
 
-  if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
     CI = addControlledConvergenceToken(CI);
 
   // Update largest vector width from the return type.
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 127a72c8cbcdd..e2f5f17ac02db 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -978,7 +978,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
         LoopHeader.getBlock(), ConvergenceTokenStack.back()));
 
@@ -1084,7 +1084,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1106,7 +1106,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   else
     EmitBlockWithFallThrough(LoopBody, &S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
 
@@ -1164,7 +1164,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1185,7 +1185,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1299,7 +1299,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1323,7 +1323,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1396,7 +1396,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 2e37a1540b6f2..34dc0bd5c15bc 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -353,7 +353,7 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(DeferredDeactivationCleanupStack.empty() &&
          "mismatched activate/deactivate of cleanups!");
 
-  if (getTarget().getTriple().isSPIRVLogical()) {
+  if (CGM.shouldEmitConvergenceTokens()) {
     ConvergenceTokenStack.pop_back();
     assert(ConvergenceTokenStack.empty() &&
            "mismatched push/pop in convergence stack!");
@@ -1284,7 +1284,7 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index be43a18fc6085..dcf55df4f07d8 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1586,6 +1586,12 @@ class CodeGenModule : public CodeGenTypeCache {
   void AddGlobalDtor(llvm::Function *Dtor, int Priority = 65535,
                      bool IsDtorAttrFunc = false);
 
+  // Return whether structured convergence intrinsics should be generated for
+  // this target.
+  bool shouldEmitConvergenceTokens() const {
+    return getTriple().isSPIRVLogical();
+  }
+
 private:
   llvm::Constant *GetOrCreateLLVMFunction(
       StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,

>From 94609c5f7c78770add820700ce1a254b21ff72d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 7 May 2024 10:33:10 +0200
Subject: [PATCH 4/6] review feedback

---
 clang/lib/CodeGen/CGStmt.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index e2f5f17ac02db..36776846cd446 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -3218,26 +3218,25 @@ llvm::IntrinsicInst *
 CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
                                           llvm::Value *ParentToken) {
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
-
   if (BB->empty())
     Builder.SetInsertPoint(BB);
   else
-    Builder.SetInsertPoint(&BB->front());
+    Builder.SetInsertPoint(BB->getFirstInsertionPt());
 
-  auto CB = Builder.CreateIntrinsic(
+  llvm::CallBase *CB = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_loop, {}, {});
   Builder.restoreIP(IP);
 
-  auto I = addConvergenceControlToken(CB, ParentToken);
+  llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
   return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
 CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
+  llvm::BasicBlock *BB = &F->getEntryBlock();
+  llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+  if (Token)
+    return Token;
 
   // Adding a convergence token requires the function to be marked as
   // convergent.
@@ -3245,7 +3244,7 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
 
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
   Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
+  llvm::CallBase *I = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_entry, {}, {});
   assert(isa<llvm::IntrinsicInst>(I));
   Builder.restoreIP(IP);

>From eca3e6090f68c450cb81faf849110f906282e0b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 7 May 2024 15:10:52 +0200
Subject: [PATCH 5/6] fix changing register name in tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/test/CodeGenHLSL/builtins/lerp.hlsl |  82 +++++++++--------
 clang/test/CodeGenHLSL/builtins/mad.hlsl  | 104 +++++++++++++++++-----
 2 files changed, 127 insertions(+), 59 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 87b2e3af57656..bbb419acaf3ba 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -14,88 +14,98 @@
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
 
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call half @llvm.spv.lerp.f16(half %0, half %1, half %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call half @llvm.dx.lerp.f16(half %{{.*}}, half %{{.*}}, half %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call half @llvm.spv.lerp.f16(half %{{.*}}, half %{{.*}}, half %{{.*}})
 // NATIVE_HALF: ret half %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
-// SPIR_NO_HALF: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %0, float %1, float %2)
+// DXIL_NO_HALF: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
 // NO_HALF: ret float %hlsl.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.spv.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.spv.lerp.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}})
 // NATIVE_HALF: ret <2 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
 // NO_HALF: ret <2 x float> %hlsl.lerp
 half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.spv.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, <3 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.spv.lerp.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, <3 x half> %{{.*}})
 // NATIVE_HALF: ret <3 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
 // NO_HALF: ret <3 x float> %hlsl.lerp
 half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.spv.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.spv.lerp.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x half> %{{.*}})
 // NATIVE_HALF: ret <4 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // NO_HALF: ret <4 x float> %hlsl.lerp
 half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
-// SPIR_CHECK: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %0, float %1, float %2)
+// DXIL_CHECK: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
 // CHECK: ret float %hlsl.lerp
 float test_lerp_float(float p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
 
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
 
-// DXIL_CHECK:  %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
-// SPIR_CHECK:  %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
 // CHECK:  ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[a:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[b:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %conv = sitofp i32 {{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %[[a]], <2 x float> %[[b]], <2 x float> %splat.splat)
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %[[a]], <2 x float> %[[b]], <2 x float> %splat.splat)
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2_int_splat(float2 p0, int p1) {
   return lerp(p0, p0, p1);
 }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[a:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[b:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %conv = sitofp i32 {{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// DXIL_CHECK:  %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
-// SPIR_CHECK:  %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// DXIL_CHECK:  %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %[[a]], <3 x float> %[[b]], <3 x float> %splat.splat)
+// SPIR_CHECK:  %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %[[a]], <3 x float> %[[b]], <3 x float> %splat.splat)
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3_int_splat(float3 p0, int p1) {
   return lerp(p0, p0, p1);
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index b4dc636b00b71..559e1d1dd3903 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -64,59 +64,107 @@ int16_t3 test_mad_int16_t3(int16_t3 p0, int16_t3 p1, int16_t3 p2) { return mad(p
 int16_t4 test_mad_int16_t4(int16_t4 p0, int16_t4 p1, int16_t4 p2) { return mad(p0, p1, p2); }
 #endif // __HLSL_ENABLE_16_BIT
 
-// NATIVE_HALF: %hlsl.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+// NATIVE_HALF: %[[p0:.*]] = load half, ptr %p0.addr, align 2
+// NATIVE_HALF: %[[p1:.*]] = load half, ptr %p1.addr, align 2
+// NATIVE_HALF: %[[p2:.*]] = load half, ptr %p2.addr, align 2
+// NATIVE_HALF: %hlsl.fmad = call half @llvm.fmuladd.f16(half %[[p0]], half %[[p1]], half %[[p2]])
 // NATIVE_HALF: ret half %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// NO_HALF: %[[p0:.*]] = load float, ptr %p0.addr, align 4
+// NO_HALF: %[[p1:.*]] = load float, ptr %p1.addr, align 4
+// NO_HALF: %[[p2:.*]] = load float, ptr %p2.addr, align 4
+// NO_HALF: %hlsl.fmad = call float @llvm.fmuladd.f32(float %[[p0]], float %[[p1]], float %[[p2]])
 // NO_HALF: ret float %hlsl.fmad
 half test_mad_half(half p0, half p1, half p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <2 x half>, ptr %p0.addr, align 4
+// NATIVE_HALF: %[[p1:.*]] = load <2 x half>, ptr %p1.addr, align 4
+// NATIVE_HALF: %[[p2:.*]] = load <2 x half>, ptr %p2.addr, align 4
+// NATIVE_HALF: %hlsl.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %[[p0]], <2 x half> %[[p1]], <2 x half> %[[p2]])
 // NATIVE_HALF: ret <2 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// NO_HALF: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// NO_HALF: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// NO_HALF: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %[[p2]])
 // NO_HALF: ret <2 x float> %hlsl.fmad
 half2 test_mad_half2(half2 p0, half2 p1, half2 p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <3 x half>, ptr %p0.addr, align 8
+// NATIVE_HALF: %[[p1:.*]] = load <3 x half>, ptr %p1.addr, align 8
+// NATIVE_HALF: %[[p2:.*]] = load <3 x half>, ptr %p2.addr, align 8
+// NATIVE_HALF: %hlsl.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %[[p0]], <3 x half> %[[p1]], <3 x half> %[[p2]])
 // NATIVE_HALF: ret <3 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// NO_HALF: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// NO_HALF: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// NO_HALF: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %[[p2]])
 // NO_HALF: ret <3 x float> %hlsl.fmad
 half3 test_mad_half3(half3 p0, half3 p1, half3 p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <4 x half>, ptr %p0.addr, align 8
+// NATIVE_HALF: %[[p1:.*]] = load <4 x half>, ptr %p1.addr, align 8
+// NATIVE_HALF: %[[p2:.*]] = load <4 x half>, ptr %p2.addr, align 8
+// NATIVE_HALF: %hlsl.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %[[p0]], <4 x half> %[[p1]], <4 x half> %[[p2]])
 // NATIVE_HALF: ret <4 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <4 x float>, ptr %p0.addr, align 16
+// NO_HALF: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// NO_HALF: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// NO_HALF: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %[[p0]], <4 x float> %[[p1]], <4 x float> %[[p2]])
 // NO_HALF: ret <4 x float> %hlsl.fmad
 half4 test_mad_half4(half4 p0, half4 p1, half4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// CHECK: %[[p0:.*]] = load float, ptr %p0.addr, align 4
+// CHECK: %[[p1:.*]] = load float, ptr %p1.addr, align 4
+// CHECK: %[[p2:.*]] = load float, ptr %p2.addr, align 4
+// CHECK: %hlsl.fmad = call float @llvm.fmuladd.f32(float %[[p0]], float %[[p1]], float %[[p2]])
 // CHECK: ret float %hlsl.fmad
 float test_mad_float(float p0, float p1, float p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %[[p2]])
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2(float2 p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %[[p2]])
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3(float3 p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[p0:.*]] = load <4 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %[[p0]], <4 x float> %[[p1]], <4 x float> %[[p2]])
 // CHECK: ret <4 x float> %hlsl.fmad
 float4 test_mad_float4(float4 p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+// CHECK: %[[p0:.*]] = load double, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load double, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load double, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call double @llvm.fmuladd.f64(double %[[p0]], double %[[p1]], double %[[p2]])
 // CHECK: ret double %hlsl.fmad
 double test_mad_double(double p0, double p1, double p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %0, <2 x double> %1, <2 x double> %2)
+// CHECK: %[[p0:.*]] = load <2 x double>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <2 x double>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <2 x double>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %[[p0]], <2 x double> %[[p1]], <2 x double> %[[p2]])
 // CHECK: ret <2 x double> %hlsl.fmad
 double2 test_mad_double2(double2 p0, double2 p1, double2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %0, <3 x double> %1, <3 x double> %2)
+// CHECK: %[[p0:.*]] = load <3 x double>, ptr %p0.addr, align 32
+// CHECK: %[[p1:.*]] = load <3 x double>, ptr %p1.addr, align 32
+// CHECK: %[[p2:.*]] = load <3 x double>, ptr %p2.addr, align 32
+// CHECK: %hlsl.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %[[p0]], <3 x double> %[[p1]], <3 x double> %[[p2]])
 // CHECK: ret <3 x double> %hlsl.fmad
 double3 test_mad_double3(double3 p0, double3 p1, double3 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %0, <4 x double> %1, <4 x double> %2)
+// CHECK: %[[p0:.*]] = load <4 x double>, ptr %p0.addr, align 32
+// CHECK: %[[p1:.*]] = load <4 x double>, ptr %p1.addr, align 32
+// CHECK: %[[p2:.*]] = load <4 x double>, ptr %p2.addr, align 32
+// CHECK: %hlsl.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %[[p0]], <4 x double> %[[p1]], <4 x double> %[[p2]])
 // CHECK: ret <4 x double> %hlsl.fmad
 double4 test_mad_double4(double4 p0, double4 p1, double4 p2) { return mad(p0, p1, p2); }
 
@@ -216,31 +264,41 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
 
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
 // CHECK:  ret <4 x float> %hlsl.fmad
 float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %conv = sitofp i32 %{{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %splat.splat)
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
   return mad(p0, p1, p2);
 }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %conv = sitofp i32 %{{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK:  %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// CHECK:  %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %splat.splat)
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
   return mad(p0, p1, p2);

>From 166afdbd4918e70c84a6aa10442c9bd30833ee6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 13 May 2024 20:06:40 +0200
Subject: [PATCH 6/6] add TODO comment

---
 clang/lib/CodeGen/CodeGenModule.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index dcf55df4f07d8..0f68418130ead 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1589,6 +1589,8 @@ class CodeGenModule : public CodeGenTypeCache {
   // Return whether structured convergence intrinsics should be generated for
   // this target.
   bool shouldEmitConvergenceTokens() const {
+    // TODO: this should probably become unconditional once the controlled
+    // convergence becomes the norm.
     return getTriple().isSPIRVLogical();
   }
 



More information about the cfe-commits mailing list