[clang] [clang][SPIR-V] Always add convergence intrinsics (PR #88918)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue May 7 01:48:13 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/88918

>From 94d76dcdfac88d1d50fe705406c0280c33766e15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 15 Apr 2024 17:05:40 +0200
Subject: [PATCH 1/4] [clang][SPIR-V] Always add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR #80680 added bits in the codegen to lazily add convergence intrinsics
when required. This logic relied on the LoopStack. The issue is that
when emitting the loop condition, the LoopStack does not yet reflect
the loop being emitted, which is expected since we are not yet inside
the loop.

However, the loop's convergence token must already be available when
the condition is emitted. The simplest solution is to eagerly generate
the tokens when targeting SPIR-V.

Fixes #88144

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  88 +------------
 clang/lib/CodeGen/CGCall.cpp                  |   3 +
 clang/lib/CodeGen/CGStmt.cpp                  |  94 ++++++++++++++
 clang/lib/CodeGen/CodeGenFunction.cpp         |   9 ++
 clang/lib/CodeGen/CodeGenFunction.h           |   9 +-
 .../builtins/RWBuffer-constructor.hlsl        |   1 -
 .../CodeGenHLSL/convergence/do.while.hlsl     |  90 +++++++++++++
 clang/test/CodeGenHLSL/convergence/for.hlsl   | 121 ++++++++++++++++++
 clang/test/CodeGenHLSL/convergence/while.hlsl | 119 +++++++++++++++++
 9 files changed, 445 insertions(+), 89 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/do.while.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/for.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/while.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index df7502b8def5314..f5d40a1555fcb57 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1133,91 +1133,8 @@ struct BitTest {
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
 
-// Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
-  for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
-  }
-  return nullptr;
-}
-
 } // namespace
 
-llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
-  llvm::Value *bundleArgs[] = {ParentToken};
-  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
-      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
-  Input->replaceAllUsesWith(Output);
-  Input->eraseFromParent();
-  return Output;
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto CB = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_loop, {}, {});
-  Builder.restoreIP(IP);
-
-  auto I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
-
-  // Adding a convergence token requires the function to be marked as
-  // convergent.
-  F->setConvergent();
-
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_entry, {}, {});
-  assert(isa<llvm::IntrinsicInst>(I));
-  Builder.restoreIP(IP);
-
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
-  assert(LI != nullptr);
-
-  auto *token = getConvergenceToken(LI->getHeader());
-  if (token)
-    return token;
-
-  llvm::IntrinsicInst *PII =
-      LI->getParent()
-          ? emitConvergenceLoopToken(
-                LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
-          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
-
-  return emitConvergenceLoopToken(LI->getHeader(), PII);
-}
-
-llvm::CallBase *
-CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
-  llvm::Value *ParentToken =
-      LoopStack.hasInfo()
-          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
-          : getOrEmitConvergenceEntryToken(Input->getFunction());
-  return addConvergenceControlToken(Input, ParentToken);
-}
-
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -18306,12 +18223,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
-    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
-    if (getTarget().getTriple().isSPIRVLogical())
-      CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
-    return CI;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f12765b826935b2..06d4bceacfd34b9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4824,6 +4824,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   llvm::CallInst *call = Builder.CreateCall(
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
+
+  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 576fe2f7a2d46f4..f8287e100f4bd55 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -915,6 +915,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
+        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+
   // Create an exit block for when the condition fails, which will
   // also become the break target.
   JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1017,6 +1021,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1036,6 +1043,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, S.getBody());
   else
     EmitBlockWithFallThrough(LoopBody, &S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+
   {
     RunCleanupsScope BodyScope(*this);
     EmitStmt(S.getBody());
@@ -1090,6 +1102,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1109,6 +1124,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   Expr::EvalResult Result;
   bool CondIsConstInt =
       !S.getCond() || S.getCond()->EvaluateAsInt(Result, getContext());
@@ -1222,6 +1241,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void
@@ -1244,6 +1266,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1312,6 +1338,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3101,3 +3130,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
 
   return F;
 }
+
+namespace {
+// Returns the first convergence entry/loop/anchor intrinsic found in |BB|,
+// or nullptr if |BB| contains no convergence control intrinsic.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return nullptr;
+}
+
+} // namespace
+
+llvm::CallBase *
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
+                                          llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+
+  if (BB->empty())
+    Builder.SetInsertPoint(BB);
+  else
+    Builder.SetInsertPoint(&BB->front());
+
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = addConvergenceControlToken(CB, ParentToken);
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return cast<llvm::IntrinsicInst>(I);
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 6474d6c8c1d1e42..8f3327bf12a4b33 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -347,6 +347,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(BreakContinueStack.empty() &&
          "mismatched push/pop in break/continue stack!");
 
+  if (getTarget().getTriple().isSPIRVLogical()) {
+    ConvergenceTokenStack.pop_back();
+    assert(ConvergenceTokenStack.empty() &&
+           "mismatched push/pop in convergence stack!");
+  }
+
   bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
     && NumSimpleReturnExprs == NumReturnExprs
     && ReturnBlock.getBlock()->use_empty();
@@ -1271,6 +1277,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (CurFuncDecl)
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
 void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e2a7e28c8211ea7..12c5e71bf6af60f 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -314,6 +314,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Stack to track the Logical Operator recursion nest for MC/DC.
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
+  /// Stack to track the controlled convergence tokens.
+  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
   int ExpectedOMPLoopDepth = 0;
@@ -4987,7 +4990,11 @@ class CodeGenFunction : public CodeGenTypeCache {
                                      const llvm::Twine &Name = "");
   // Adds a convergence_ctrl token to |Input| and emits the required parent
   // convergence instructions.
-  llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
+  template <typename CallType>
+  CallType *addControlledConvergenceToken(CallType *Input) {
+    return dyn_cast<CallType>(
+        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
+  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 74b3f59bf7600fd..e51eac7f57c2d31 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,3 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
 
 RWBuffer<float> Buf;
diff --git a/clang/test/CodeGenHLSL/convergence/do.while.hlsl b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
new file mode 100644
index 000000000000000..ea5a45ba8fd780a
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  do {
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  do {
+    foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  do {
+    if (cond())
+      foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  do {
+    if (cond()) {
+      foo();
+      break;
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  do {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: while.cond:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/for.hlsl b/clang/test/CodeGenHLSL/convergence/for.hlsl
new file mode 100644
index 000000000000000..180fae74ba7514e
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/for.hlsl
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+bool cond2();
+void foo();
+
+void test1() {
+  for (;;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  for (;cond();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  for (cond();;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  for (cond();cond2();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  for (cond();cond2();foo()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test6() {
+  for (cond();cond2();foo()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:   [[C1:%[a-zA-Z0-9]+]] = call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br i1 [[C1]], label %if.then, label %if.end
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %for.end
+// CHECK: if.end:
+// CHECK:   br label %for.inc
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test7() {
+  for (cond();;) {
+    for (cond();;) {
+      foo();
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test7v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.cond3:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T2]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/while.hlsl b/clang/test/CodeGenHLSL/convergence/while.hlsl
new file mode 100644
index 000000000000000..92777000190d22a
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/while.hlsl
@@ -0,0 +1,119 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  while (cond()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  while (cond()) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.body:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  while (cond()) {
+    if (cond())
+      break;
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.cond
+
+void test4() {
+  while (cond()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   br label %while.cond
+
+void test5() {
+  while (cond()) {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK:   br label %while.end
+
+void test6() {
+  while (cond()) {
+    while (cond()) {
+    }
+
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }

>From 6f3b0ec5353d819380d8e454c528232e82e1f22a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 16 Apr 2024 18:40:09 +0200
Subject: [PATCH 2/4] Address review feedback

---
 clang/lib/CodeGen/CGCall.cpp        | 2 +-
 clang/lib/CodeGen/CodeGenFunction.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 06d4bceacfd34b9..95cf25f4a0b3c01 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4826,7 +4826,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   call->setCallingConv(getRuntimeCC());
 
   if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
-    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
+    return addControlledConvergenceToken(call);
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 12c5e71bf6af60f..f348a1e866b6f62 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4992,7 +4992,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   // convergence instructions.
   template <typename CallType>
   CallType *addControlledConvergenceToken(CallType *Input) {
-    return dyn_cast<CallType>(
+    return cast<CallType>(
         addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
   }
 

>From 501e5665dba695eefc6e2c4751c158e5ecf70865 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 17 Apr 2024 14:12:19 +0200
Subject: [PATCH 3/4] Address review feedback: factor out the target check
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGCall.cpp          |  4 ++--
 clang/lib/CodeGen/CGStmt.cpp          | 16 ++++++++--------
 clang/lib/CodeGen/CodeGenFunction.cpp |  4 ++--
 clang/lib/CodeGen/CodeGenModule.h     |  6 ++++++
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 95cf25f4a0b3c01..0bac8dc918369e7 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4825,7 +4825,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
 
-  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
     return addControlledConvergenceToken(call);
   return call;
 }
@@ -5720,7 +5720,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   if (!CI->getType()->isVoidTy())
     CI->setName("call");
 
-  if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
+  if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
     CI = addControlledConvergenceToken(CI);
 
   // Update largest vector width from the return type.
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index f8287e100f4bd55..c2f1ee505912d03 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -915,7 +915,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
         LoopHeader.getBlock(), ConvergenceTokenStack.back()));
 
@@ -1022,7 +1022,7 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1044,7 +1044,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   else
     EmitBlockWithFallThrough(LoopBody, &S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
 
@@ -1103,7 +1103,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1124,7 +1124,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1242,7 +1242,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
@@ -1266,7 +1266,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(
         emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
 
@@ -1339,7 +1339,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.pop_back();
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 8f3327bf12a4b33..d4142b02d973f8b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -347,7 +347,7 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(BreakContinueStack.empty() &&
          "mismatched push/pop in break/continue stack!");
 
-  if (getTarget().getTriple().isSPIRVLogical()) {
+  if (CGM.shouldEmitConvergenceTokens()) {
     ConvergenceTokenStack.pop_back();
     assert(ConvergenceTokenStack.empty() &&
            "mismatched push/pop in convergence stack!");
@@ -1278,7 +1278,7 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
 
-  if (getTarget().getTriple().isSPIRVLogical())
+  if (CGM.shouldEmitConvergenceTokens())
     ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index 1cc447765e2c977..ecf19d20223752f 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1583,6 +1583,12 @@ class CodeGenModule : public CodeGenTypeCache {
   void AddGlobalDtor(llvm::Function *Dtor, int Priority = 65535,
                      bool IsDtorAttrFunc = false);
 
+  // Return whether structured convergence intrinsics should be generated for
+  // this target.
+  bool shouldEmitConvergenceTokens() const {
+    return getTriple().isSPIRVLogical();
+  }
+
 private:
   llvm::Constant *GetOrCreateLLVMFunction(
       StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,

>From c9aa6a623b1d0584dda78c3798213d6b974ddaca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 7 May 2024 10:33:10 +0200
Subject: [PATCH 4/4] review feedback: spell out types, insert loop token at first insertion point

---
 clang/lib/CodeGen/CGStmt.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index c2f1ee505912d03..c199626e55fcd2f 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -3161,26 +3161,25 @@ llvm::IntrinsicInst *
 CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
                                           llvm::Value *ParentToken) {
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
-
   if (BB->empty())
     Builder.SetInsertPoint(BB);
   else
-    Builder.SetInsertPoint(&BB->front());
+    Builder.SetInsertPoint(BB->getFirstInsertionPt());
 
-  auto CB = Builder.CreateIntrinsic(
+  llvm::CallBase *CB = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_loop, {}, {});
   Builder.restoreIP(IP);
 
-  auto I = addConvergenceControlToken(CB, ParentToken);
+  llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
   return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
 CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
+  llvm::BasicBlock *BB = &F->getEntryBlock();
+  llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+  if (Token)
+    return Token;
 
   // Adding a convergence token requires the function to be marked as
   // convergent.
@@ -3188,7 +3187,7 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
 
   CGBuilderTy::InsertPoint IP = Builder.saveIP();
   Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
+  llvm::CallBase *I = Builder.CreateIntrinsic(
       llvm::Intrinsic::experimental_convergence_entry, {}, {});
   assert(isa<llvm::IntrinsicInst>(I));
   Builder.restoreIP(IP);



More information about the cfe-commits mailing list