[clang] [clang][SPIR-V] Always add convergence intrinsics (PR #88918)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue May 7 06:36:08 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/88918

>From a8bf6fe83a1c145ef81ee30471dc51de1b5354ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 15 Apr 2024 17:05:40 +0200
Subject: [PATCH 1/5] [clang][SPIR-V] Always add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR #80680 added logic to CodeGen to lazily emit convergence intrinsics
when required. This logic relied on the LoopStack. The issue is that
when the loop condition is emitted, the LoopStack does not yet reflect
the loop, which is expected since codegen has not entered the loop yet.

However, convergent calls in the loop condition already need the loop's
convergence token. The simplest solution is to eagerly generate the
tokens whenever we target SPIR-V.

Fixes #88144

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  88 +------------
 clang/lib/CodeGen/CGCall.cpp                  |   3 +
 clang/lib/CodeGen/CGStmt.cpp                  |  94 ++++++++++++++
 clang/lib/CodeGen/CodeGenFunction.cpp         |   9 ++
 clang/lib/CodeGen/CodeGenFunction.h           |   9 +-
 .../builtins/RWBuffer-constructor.hlsl        |   1 -
 .../CodeGenHLSL/convergence/do.while.hlsl     |  90 +++++++++++++
 clang/test/CodeGenHLSL/convergence/for.hlsl   | 121 ++++++++++++++++++
 clang/test/CodeGenHLSL/convergence/while.hlsl | 119 +++++++++++++++++
 9 files changed, 445 insertions(+), 89 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/do.while.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/for.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/while.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 8e31652f4dabef..fb5904558bbae6 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1141,91 +1141,8 @@ struct BitTest {
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
 
-// Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
-  for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
-  }
-  return nullptr;
-}
-
 } // namespace
 
-llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
-  llvm::Value *bundleArgs[] = {ParentToken};
-  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
-      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
-  Input->replaceAllUsesWith(Output);
-  Input->eraseFromParent();
-  return Output;
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto CB = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_loop, {}, {});
-  Builder.restoreIP(IP);
-
-  auto I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
-
-  // Adding a convergence token requires the function to be marked as
-  // convergent.
-  F->setConvergent();
-
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_entry, {}, {});
-  assert(isa<llvm::IntrinsicInst>(I));
-  Builder.restoreIP(IP);
-
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
-  assert(LI != nullptr);
-
-  auto *token = getConvergenceToken(LI->getHeader());
-  if (token)
-    return token;
-
-  llvm::IntrinsicInst *PII =
-      LI->getParent()
-          ? emitConvergenceLoopToken(
-                LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
-          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
-
-  return emitConvergenceLoopToken(LI->getHeader(), PII);
-}
-
-llvm::CallBase *
-CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
-  llvm::Value *ParentToken =
-      LoopStack.hasInfo()
-          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
-          : getOrEmitConvergenceEntryToken(Input->getFunction());
-  return addConvergenceControlToken(Input, ParentToken);
-}
-
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -18400,12 +18317,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
-    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
-    if (getTarget().getTriple().isSPIRVLogical())
-      CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
-    return CI;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 0c7eef59db53c9..17248c80780884 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4830,6 +4830,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   llvm::CallInst *call = Builder.CreateCall(
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
+
+  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 479945e3b4cb56..127a72c8cbcdd9 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -978,6 +978,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
+        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+
   // Create an exit block for when the condition fails, which will
   // also become the break target.
   JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1079,6 +1083,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1098,6 +1105,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, S.getBody());
   else
     EmitBlockWithFallThrough(LoopBody, &S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+
   {
     RunCleanupsScope BodyScope(*this);
     EmitStmt(S.getBody());
@@ -1151,6 +1163,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1170,6 +1185,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1279,6 +1298,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void
@@ -1301,6 +1323,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1369,6 +1395,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3158,3 +3187,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
 
   return F;
 }
+
+namespace {
+// Returns the first convergence entry/loop/anchor intrinsic found in |BB|,
+// or nullptr if |BB| contains none.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return nullptr;
+}
+
+} // namespace
+
+llvm::CallBase *
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
+                                          llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+
+  if (BB->empty())
+    Builder.SetInsertPoint(BB);
+  else
+    Builder.SetInsertPoint(&BB->front());
+
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = addConvergenceControlToken(CB, ParentToken);
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(BB, BB->begin()); // Also valid for an empty block.
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return cast<llvm::IntrinsicInst>(I);
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 477814140a9e2f..7228a028237770 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -353,6 +353,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(DeferredDeactivationCleanupStack.empty() &&
          "mismatched activate/deactivate of cleanups!");
 
+  if (getTarget().getTriple().isSPIRVLogical()) {
+    ConvergenceTokenStack.pop_back();
+    assert(ConvergenceTokenStack.empty() &&
+           "mismatched push/pop in convergence stack!");
+  }
+
   bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
     && NumSimpleReturnExprs == NumReturnExprs
     && ReturnBlock.getBlock()->use_empty();
@@ -1277,6 +1283,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (CurFuncDecl)
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
 void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e1e687af6a781b..c642c4c9e74978 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -315,6 +315,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Stack to track the Logical Operator recursion nest for MC/DC.
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
+  /// Stack to track the controlled convergence tokens.
+  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
   int ExpectedOMPLoopDepth = 0;
@@ -5076,7 +5079,11 @@ class CodeGenFunction : public CodeGenTypeCache {
                                      const llvm::Twine &Name = "");
   // Adds a convergence_ctrl token to |Input| and emits the required parent
   // convergence instructions.
-  llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
+  template <typename CallType>
+  CallType *addControlledConvergenceToken(CallType *Input) {
+    return dyn_cast<CallType>(
+        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
+  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 74b3f59bf7600f..e51eac7f57c2d3 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,3 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
 
 RWBuffer<float> Buf;
diff --git a/clang/test/CodeGenHLSL/convergence/do.while.hlsl b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
new file mode 100644
index 00000000000000..ea5a45ba8fd780
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  do {
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  do {
+    foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  do {
+    if (cond())
+      foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  do {
+    if (cond()) {
+      foo();
+      break;
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  do {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: while.cond:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/for.hlsl b/clang/test/CodeGenHLSL/convergence/for.hlsl
new file mode 100644
index 00000000000000..180fae74ba7514
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/for.hlsl
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+bool cond2();
+void foo();
+
+void test1() {
+  for (;;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  for (;cond();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  for (cond();;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  for (cond();cond2();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  for (cond();cond2();foo()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test6() {
+  for (cond();cond2();foo()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:   [[C1:%[a-zA-Z0-9]+]] = call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br i1 [[C1]], label %if.then, label %if.end
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %for.end
+// CHECK: if.end:
+// CHECK:   br label %for.inc
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test7() {
+  for (cond();;) {
+    for (cond();;) {
+      foo();
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test7v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.cond3:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T2]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/while.hlsl b/clang/test/CodeGenHLSL/convergence/while.hlsl
new file mode 100644
index 00000000000000..92777000190d22
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/while.hlsl
@@ -0,0 +1,119 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  while (cond()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  while (cond()) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.body:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  while (cond()) {
+    if (cond())
+      break;
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.cond
+
+void test4() {
+  while (cond()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   br label %while.cond
+
+void test5() {
+  while (cond()) {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK:   br label %while.end
+
+void test6() {
+  while (cond()) {
+    while (cond()) {
+    }
+
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }

>From a733a2ca616445a1236f668932b6e29506c4549d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 16 Apr 2024 18:40:09 +0200
Subject: [PATCH 2/5] Address review feedback

---
 clang/lib/CodeGen/CGCall.cpp        | 2 +-
 clang/lib/CodeGen/CodeGenFunction.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 17248c80780884..cd1d2b32ead44d 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4832,7 +4832,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   call->setCallingConv(getRuntimeCC());
 
   if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
-    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
+    return addControlledConvergenceToken(call);
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index c642c4c9e74978..362f4a5fe72a63 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -5081,7 +5081,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   // convergence instructions.
   template <typename CallType>
   CallType *addControlledConvergenceToken(CallType *Input) {
-    return dyn_cast<CallType>(
+    return cast<CallType>(
         addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
   }
 

>From b70b014e081e3d5474abd693a015fbfa7514870d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 17 Apr 2024 14:12:19 +0200
Subject: [PATCH 3/5] Address review feedback: factor out convergence target check
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGCall.cpp          |  4 ++--
 clang/lib/CodeGen/CGStmt.cpp          | 16 ++++++++--------
 clang/lib/CodeGen/CodeGenFunction.cpp |  4 ++--
 clang/lib/CodeGen/CodeGenModule.h     |  6 ++++++
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index cd1d2b32ead44d..1b4ca2a8b2fe84 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4831,7 +4831,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
 
-  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
     return addControlledConvergenceToken(call);
   return call;
 }
@@ -5733,7 +5733,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   if (!CI->getType()->isVoidTy())
     CI->setName("call");
 
-  if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
     CI = addControlledConvergenceToken(CI);
 
   // Update largest vector width from the return type.
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 127a72c8cbcdd9..e2f5f17ac02db8 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -978,7 +978,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
         LoopHeader.getBlock(), ConvergenceTokenStack.back()));
 
@@ -1084,7 +1084,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1106,7 +1106,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   else
     EmitBlockWithFallThrough(LoopBody, &S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
 
@@ -1164,7 +1164,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1185,7 +1185,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1299,7 +1299,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1323,7 +1323,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1396,7 +1396,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 7228a028237770..051e6764b26597 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -353,7 +353,7 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(DeferredDeactivationCleanupStack.empty() &&
          "mismatched activate/deactivate of cleanups!");
 
-  if (getTarget().getTriple().isSPIRVLogical()) {
+  if (CGM.shouldEmitConvergenceTokens()) {
     ConvergenceTokenStack.pop_back();
     assert(ConvergenceTokenStack.empty() &&
            "mismatched push/pop in convergence stack!");
@@ -1284,7 +1284,7 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index be43a18fc60856..dcf55df4f07d80 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1586,6 +1586,12 @@ class CodeGenModule : public CodeGenTypeCache {
   void AddGlobalDtor(llvm::Function *Dtor, int Priority = 65535,
                      bool IsDtorAttrFunc = false);
 
+  // Return whether structured convergence control tokens should be emitted
+  // for this target; currently this is true only for logical SPIR-V.
+  bool shouldEmitConvergenceTokens() const {
+    return getTriple().isSPIRVLogical();
+  }
+
 private:
   llvm::Constant *GetOrCreateLLVMFunction(
       StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,

>From 2c01eeaee89e337edac002f513d068e540756ae6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 7 May 2024 10:33:10 +0200
Subject: [PATCH 4/5] review feedback

---
 clang/lib/CodeGen/CGStmt.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index e2f5f17ac02db8..36776846cd4464 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -3218,26 +3218,25 @@ llvm::IntrinsicInst *
 CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
                                           llvm::Value *ParentToken) {
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
-
   if (BB->empty())
     Builder.SetInsertPoint(BB);
   else
-    Builder.SetInsertPoint(&BB->front());
+    Builder.SetInsertPoint(BB->getFirstInsertionPt());
 
-  auto CB = Builder.CreateIntrinsic(
+  llvm::CallBase *CB = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_loop, {}, {});
   Builder.restoreIP(IP);
 
-  auto I = addConvergenceControlToken(CB, ParentToken);
+  llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
   return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
 CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
+  llvm::BasicBlock *BB = &F->getEntryBlock();
+  llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+  if (Token)
+    return Token;
 
   // Adding a convergence token requires the function to be marked as
   // convergent.
@@ -3245,7 +3244,7 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
 
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
   Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
+  llvm::CallBase *I = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_entry, {}, {});
   assert(isa<llvm::IntrinsicInst>(I));
   Builder.restoreIP(IP);

>From 131ea26719e41cf1f98a58fe26e6b7f7b1e63fac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 7 May 2024 15:10:52 +0200
Subject: [PATCH 5/5] make tests robust to changing register names
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/test/CodeGenHLSL/builtins/lerp.hlsl |  82 +++++++++--------
 clang/test/CodeGenHLSL/builtins/mad.hlsl  | 104 +++++++++++++++++-----
 2 files changed, 127 insertions(+), 59 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 87b2e3af576565..bbb419acaf3ba2 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -14,88 +14,98 @@
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
 
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call half @llvm.spv.lerp.f16(half %0, half %1, half %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call half @llvm.dx.lerp.f16(half %{{.*}}, half %{{.*}}, half %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call half @llvm.spv.lerp.f16(half %{{.*}}, half %{{.*}}, half %{{.*}})
 // NATIVE_HALF: ret half %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
-// SPIR_NO_HALF: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %0, float %1, float %2)
+// DXIL_NO_HALF: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
 // NO_HALF: ret float %hlsl.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.spv.lerp.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.dx.lerp.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <2 x half> @llvm.spv.lerp.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}})
 // NATIVE_HALF: ret <2 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
 // NO_HALF: ret <2 x float> %hlsl.lerp
 half2 test_lerp_half2(half2 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.spv.lerp.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.dx.lerp.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, <3 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <3 x half> @llvm.spv.lerp.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, <3 x half> %{{.*}})
 // NATIVE_HALF: ret <3 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
 // NO_HALF: ret <3 x float> %hlsl.lerp
 half3 test_lerp_half3(half3 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
-// SPIR_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.spv.lerp.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// DXIL_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.dx.lerp.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x half> %{{.*}})
+// SPIR_NATIVE_HALF: %hlsl.lerp = call <4 x half> @llvm.spv.lerp.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x half> %{{.*}})
 // NATIVE_HALF: ret <4 x half> %hlsl.lerp
-// DXIL_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
-// SPIR_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// DXIL_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
+// SPIR_NO_HALF: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // NO_HALF: ret <4 x float> %hlsl.lerp
 half4 test_lerp_half4(half4 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
-// SPIR_CHECK: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %0, float %1, float %2)
+// DXIL_CHECK: %hlsl.lerp = call float @llvm.dx.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call float @llvm.spv.lerp.f32(float %{{.*}}, float %{{.*}}, float %{{.*}})
 // CHECK: ret float %hlsl.lerp
 float test_lerp_float(float p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}})
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2(float2 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, <3 x float> %{{.*}})
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
+// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
 
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
 
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
 
-// DXIL_CHECK:  %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
-// SPIR_CHECK:  %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
+// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
 // CHECK:  ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[a:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[b:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %conv = sitofp i32 {{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %[[a]], <2 x float> %[[b]], <2 x float> %splat.splat)
+// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %[[a]], <2 x float> %[[b]], <2 x float> %splat.splat)
 // CHECK: ret <2 x float> %hlsl.lerp
 float2 test_lerp_float2_int_splat(float2 p0, int p1) {
   return lerp(p0, p0, p1);
 }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[a:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[b:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %conv = sitofp i32 {{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// DXIL_CHECK:  %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
-// SPIR_CHECK:  %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// DXIL_CHECK:  %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %[[a]], <3 x float> %[[b]], <3 x float> %splat.splat)
+// SPIR_CHECK:  %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %[[a]], <3 x float> %[[b]], <3 x float> %splat.splat)
 // CHECK: ret <3 x float> %hlsl.lerp
 float3 test_lerp_float3_int_splat(float3 p0, int p1) {
   return lerp(p0, p0, p1);
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index b4dc636b00b71b..559e1d1dd3903a 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -64,59 +64,107 @@ int16_t3 test_mad_int16_t3(int16_t3 p0, int16_t3 p1, int16_t3 p2) { return mad(p
 int16_t4 test_mad_int16_t4(int16_t4 p0, int16_t4 p1, int16_t4 p2) { return mad(p0, p1, p2); }
 #endif // __HLSL_ENABLE_16_BIT
 
-// NATIVE_HALF: %hlsl.fmad = call half @llvm.fmuladd.f16(half %0, half %1, half %2)
+// NATIVE_HALF: %[[p0:.*]] = load half, ptr %p0.addr, align 2
+// NATIVE_HALF: %[[p1:.*]] = load half, ptr %p1.addr, align 2
+// NATIVE_HALF: %[[p2:.*]] = load half, ptr %p2.addr, align 2
+// NATIVE_HALF: %hlsl.fmad = call half @llvm.fmuladd.f16(half %[[p0]], half %[[p1]], half %[[p2]])
 // NATIVE_HALF: ret half %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// NO_HALF: %[[p0:.*]] = load float, ptr %p0.addr, align 4
+// NO_HALF: %[[p1:.*]] = load float, ptr %p1.addr, align 4
+// NO_HALF: %[[p2:.*]] = load float, ptr %p2.addr, align 4
+// NO_HALF: %hlsl.fmad = call float @llvm.fmuladd.f32(float %[[p0]], float %[[p1]], float %[[p2]])
 // NO_HALF: ret float %hlsl.fmad
 half test_mad_half(half p0, half p1, half p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <2 x half>, ptr %p0.addr, align 4
+// NATIVE_HALF: %[[p1:.*]] = load <2 x half>, ptr %p1.addr, align 4
+// NATIVE_HALF: %[[p2:.*]] = load <2 x half>, ptr %p2.addr, align 4
+// NATIVE_HALF: %hlsl.fmad = call <2 x half>  @llvm.fmuladd.v2f16(<2 x half> %[[p0]], <2 x half> %[[p1]], <2 x half> %[[p2]])
 // NATIVE_HALF: ret <2 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// NO_HALF: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// NO_HALF: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// NO_HALF: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %[[p2]])
 // NO_HALF: ret <2 x float> %hlsl.fmad
 half2 test_mad_half2(half2 p0, half2 p1, half2 p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %0, <3 x half> %1, <3 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <3 x half>, ptr %p0.addr, align 8
+// NATIVE_HALF: %[[p1:.*]] = load <3 x half>, ptr %p1.addr, align 8
+// NATIVE_HALF: %[[p2:.*]] = load <3 x half>, ptr %p2.addr, align 8
+// NATIVE_HALF: %hlsl.fmad = call <3 x half>  @llvm.fmuladd.v3f16(<3 x half> %[[p0]], <3 x half> %[[p1]], <3 x half> %[[p2]])
 // NATIVE_HALF: ret <3 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// NO_HALF: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// NO_HALF: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// NO_HALF: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %[[p2]])
 // NO_HALF: ret <3 x float> %hlsl.fmad
 half3 test_mad_half3(half3 p0, half3 p1, half3 p2) { return mad(p0, p1, p2); }
 
-// NATIVE_HALF: %hlsl.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %0, <4 x half> %1, <4 x half> %2)
+// NATIVE_HALF: %[[p0:.*]] = load <4 x half>, ptr %p0.addr, align 8
+// NATIVE_HALF: %[[p1:.*]] = load <4 x half>, ptr %p1.addr, align 8
+// NATIVE_HALF: %[[p2:.*]] = load <4 x half>, ptr %p2.addr, align 8
+// NATIVE_HALF: %hlsl.fmad = call <4 x half>  @llvm.fmuladd.v4f16(<4 x half> %[[p0]], <4 x half> %[[p1]], <4 x half> %[[p2]])
 // NATIVE_HALF: ret <4 x half> %hlsl.fmad
-// NO_HALF: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// NO_HALF: %[[p0:.*]] = load <4 x float>, ptr %p0.addr, align 16
+// NO_HALF: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// NO_HALF: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// NO_HALF: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %[[p0]], <4 x float> %[[p1]], <4 x float> %[[p2]])
 // NO_HALF: ret <4 x float> %hlsl.fmad
 half4 test_mad_half4(half4 p0, half4 p1, half4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call float @llvm.fmuladd.f32(float %0, float %1, float %2)
+// CHECK: %[[p0:.*]] = load float, ptr %p0.addr, align 4
+// CHECK: %[[p1:.*]] = load float, ptr %p1.addr, align 4
+// CHECK: %[[p2:.*]] = load float, ptr %p2.addr, align 4
+// CHECK: %hlsl.fmad = call float @llvm.fmuladd.f32(float %[[p0]], float %[[p1]], float %[[p2]])
 // CHECK: ret float %hlsl.fmad
 float test_mad_float(float p0, float p1, float p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %[[p2]])
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2(float2 p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %[[p2]])
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3(float3 p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[p0:.*]] = load <4 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %[[p0]], <4 x float> %[[p1]], <4 x float> %[[p2]])
 // CHECK: ret <4 x float> %hlsl.fmad
 float4 test_mad_float4(float4 p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call double @llvm.fmuladd.f64(double %0, double %1, double %2)
+// CHECK: %[[p0:.*]] = load double, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load double, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load double, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call double @llvm.fmuladd.f64(double %[[p0]], double %[[p1]], double %[[p2]])
 // CHECK: ret double %hlsl.fmad
 double test_mad_double(double p0, double p1, double p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %0, <2 x double> %1, <2 x double> %2)
+// CHECK: %[[p0:.*]] = load <2 x double>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <2 x double>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <2 x double>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <2 x double>  @llvm.fmuladd.v2f64(<2 x double> %[[p0]], <2 x double> %[[p1]], <2 x double> %[[p2]])
 // CHECK: ret <2 x double> %hlsl.fmad
 double2 test_mad_double2(double2 p0, double2 p1, double2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %0, <3 x double> %1, <3 x double> %2)
+// CHECK: %[[p0:.*]] = load <3 x double>, ptr %p0.addr, align 32
+// CHECK: %[[p1:.*]] = load <3 x double>, ptr %p1.addr, align 32
+// CHECK: %[[p2:.*]] = load <3 x double>, ptr %p2.addr, align 32
+// CHECK: %hlsl.fmad = call <3 x double>  @llvm.fmuladd.v3f64(<3 x double> %[[p0]], <3 x double> %[[p1]], <3 x double> %[[p2]])
 // CHECK: ret <3 x double> %hlsl.fmad
 double3 test_mad_double3(double3 p0, double3 p1, double3 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %0, <4 x double> %1, <4 x double> %2)
+// CHECK: %[[p0:.*]] = load <4 x double>, ptr %p0.addr, align 32
+// CHECK: %[[p1:.*]] = load <4 x double>, ptr %p1.addr, align 32
+// CHECK: %[[p2:.*]] = load <4 x double>, ptr %p2.addr, align 32
+// CHECK: %hlsl.fmad = call <4 x double>  @llvm.fmuladd.v4f64(<4 x double> %[[p0]], <4 x double> %[[p1]], <4 x double> %[[p2]])
 // CHECK: ret <4 x double> %hlsl.fmad
 double4 test_mad_double4(double4 p0, double4 p1, double4 p2) { return mad(p0, p1, p2); }
 
@@ -216,31 +264,41 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %1, <2 x float> %2)
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %1, <3 x float> %2)
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
+// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
 
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %1, <4 x float> %2)
+// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
+// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
+// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
 // CHECK:  ret <4 x float> %hlsl.fmad
 float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[p0:.*]] = load <2 x float>, ptr %p0.addr, align 8
+// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
+// CHECK: %conv = sitofp i32 %{{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %0, <2 x float> %1, <2 x float> %splat.splat)
+// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %[[p0]], <2 x float> %[[p1]], <2 x float> %splat.splat)
 // CHECK: ret <2 x float> %hlsl.fmad
 float2 test_mad_float2_int_splat(float2 p0, float2 p1, int p2) {
   return mad(p0, p1, p2);
 }
 
-// CHECK: %conv = sitofp i32 %2 to float
+// CHECK: %[[p0:.*]] = load <3 x float>, ptr %p0.addr, align 16
+// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
+// CHECK: %conv = sitofp i32 %{{.*}} to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK:  %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %0, <3 x float> %1, <3 x float> %splat.splat)
+// CHECK:  %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %[[p0]], <3 x float> %[[p1]], <3 x float> %splat.splat)
 // CHECK: ret <3 x float> %hlsl.fmad
 float3 test_mad_float3_int_splat(float3 p0, float3 p1, int p2) {
   return mad(p0, p1, p2);



More information about the cfe-commits mailing list