[clang] [clang][SPIR-V] Always add convergence intrinsics (PR #88918)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue Apr 16 09:40:47 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/88918

>From 94d76dcdfac88d1d50fe705406c0280c33766e15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 15 Apr 2024 17:05:40 +0200
Subject: [PATCH 1/2] [clang][SPIR-V] Always add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR #80680 added bits in the codegen to lazily add convergence intrinsics
when required. This logic relied on the LoopStack. The issue is that,
when parsing the loop condition, the LoopStack doesn't yet reflect the
correct values, which is expected since we are not yet in the loop.

However, convergence tokens are sometimes needed at that point already
(for example, for calls emitted in the loop condition). The solution
which seemed the simplest is to eagerly generate the tokens whenever we
generate SPIR-V.

Fixes #88144

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  88 +------------
 clang/lib/CodeGen/CGCall.cpp                  |   3 +
 clang/lib/CodeGen/CGStmt.cpp                  |  94 ++++++++++++++
 clang/lib/CodeGen/CodeGenFunction.cpp         |   9 ++
 clang/lib/CodeGen/CodeGenFunction.h           |   9 +-
 .../builtins/RWBuffer-constructor.hlsl        |   1 -
 .../CodeGenHLSL/convergence/do.while.hlsl     |  90 +++++++++++++
 clang/test/CodeGenHLSL/convergence/for.hlsl   | 121 ++++++++++++++++++
 clang/test/CodeGenHLSL/convergence/while.hlsl | 119 +++++++++++++++++
 9 files changed, 445 insertions(+), 89 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/do.while.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/for.hlsl
 create mode 100644 clang/test/CodeGenHLSL/convergence/while.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index df7502b8def531..f5d40a1555fcb5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1133,91 +1133,8 @@ struct BitTest {
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
 
-// Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
-  for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
-  }
-  return nullptr;
-}
-
 } // namespace
 
-llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
-  llvm::Value *bundleArgs[] = {ParentToken};
-  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
-      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
-  Input->replaceAllUsesWith(Output);
-  Input->eraseFromParent();
-  return Output;
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto CB = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_loop, {}, {});
-  Builder.restoreIP(IP);
-
-  auto I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
-
-  // Adding a convergence token requires the function to be marked as
-  // convergent.
-  F->setConvergent();
-
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_entry, {}, {});
-  assert(isa<llvm::IntrinsicInst>(I));
-  Builder.restoreIP(IP);
-
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
-  assert(LI != nullptr);
-
-  auto *token = getConvergenceToken(LI->getHeader());
-  if (token)
-    return token;
-
-  llvm::IntrinsicInst *PII =
-      LI->getParent()
-          ? emitConvergenceLoopToken(
-                LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
-          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
-
-  return emitConvergenceLoopToken(LI->getHeader(), PII);
-}
-
-llvm::CallBase *
-CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
-  llvm::Value *ParentToken =
-      LoopStack.hasInfo()
-          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
-          : getOrEmitConvergenceEntryToken(Input->getFunction());
-  return addConvergenceControlToken(Input, ParentToken);
-}
-
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -18306,12 +18223,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
-    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
-    if (getTarget().getTriple().isSPIRVLogical())
-      CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
-    return CI;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f12765b826935b..06d4bceacfd34b 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4824,6 +4824,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   llvm::CallInst *call = Builder.CreateCall(
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
+
+  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 576fe2f7a2d46f..f8287e100f4bd5 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -915,6 +915,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
+        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+
   // Create an exit block for when the condition fails, which will
   // also become the break target.
   JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1017,6 +1021,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1036,6 +1043,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, S.getBody());
   else
     EmitBlockWithFallThrough(LoopBody, &S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+
   {
     RunCleanupsScope BodyScope(*this);
     EmitStmt(S.getBody());
@@ -1090,6 +1102,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1109,6 +1124,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   Expr::EvalResult Result;
   bool CondIsConstInt =
       !S.getCond() || S.getCond()->EvaluateAsInt(Result, getContext());
@@ -1222,6 +1241,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void
@@ -1244,6 +1266,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1312,6 +1338,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3101,3 +3130,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
 
   return F;
 }
+
+namespace {
+// Returns the first convergence entry/loop/anchor instruction found in |BB|.
+// Returns nullptr otherwise.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return nullptr;
+}
+
+} // namespace
+
+llvm::CallBase *
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
+                                          llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+
+  if (BB->empty())
+    Builder.SetInsertPoint(BB);
+  else
+    Builder.SetInsertPoint(&BB->front());
+
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = addConvergenceControlToken(CB, ParentToken);
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return cast<llvm::IntrinsicInst>(I);
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 6474d6c8c1d1e4..8f3327bf12a4b3 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -347,6 +347,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(BreakContinueStack.empty() &&
          "mismatched push/pop in break/continue stack!");
 
+  if (getTarget().getTriple().isSPIRVLogical()) {
+    ConvergenceTokenStack.pop_back();
+    assert(ConvergenceTokenStack.empty() &&
+           "mismatched push/pop in convergence stack!");
+  }
+
   bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
     && NumSimpleReturnExprs == NumReturnExprs
     && ReturnBlock.getBlock()->use_empty();
@@ -1271,6 +1277,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (CurFuncDecl)
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
 void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e2a7e28c8211ea..12c5e71bf6af60 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -314,6 +314,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Stack to track the Logical Operator recursion nest for MC/DC.
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
+  /// Stack to track the controlled convergence tokens.
+  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
   int ExpectedOMPLoopDepth = 0;
@@ -4987,7 +4990,11 @@ class CodeGenFunction : public CodeGenTypeCache {
                                      const llvm::Twine &Name = "");
   // Adds a convergence_ctrl token to |Input| and emits the required parent
   // convergence instructions.
-  llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
+  template <typename CallType>
+  CallType *addControlledConvergenceToken(CallType *Input) {
+    return dyn_cast<CallType>(
+        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
+  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 74b3f59bf7600f..e51eac7f57c2d3 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,3 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
 
 RWBuffer<float> Buf;
diff --git a/clang/test/CodeGenHLSL/convergence/do.while.hlsl b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
new file mode 100644
index 00000000000000..ea5a45ba8fd780
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  do {
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  do {
+    foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  do {
+    if (cond())
+      foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  do {
+    if (cond()) {
+      foo();
+      break;
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  do {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: while.cond:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/for.hlsl b/clang/test/CodeGenHLSL/convergence/for.hlsl
new file mode 100644
index 00000000000000..180fae74ba7514
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/for.hlsl
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+bool cond2();
+void foo();
+
+void test1() {
+  for (;;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  for (;cond();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  for (cond();;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  for (cond();cond2();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  for (cond();cond2();foo()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test6() {
+  for (cond();cond2();foo()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:   [[C1:%[a-zA-Z0-9]+]] = call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br i1 [[C1]], label %if.then, label %if.end
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %for.end
+// CHECK: if.end:
+// CHECK:   br label %for.inc
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test7() {
+  for (cond();;) {
+    for (cond();;) {
+      foo();
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test7v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.cond3:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T2]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/while.hlsl b/clang/test/CodeGenHLSL/convergence/while.hlsl
new file mode 100644
index 00000000000000..92777000190d22
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/while.hlsl
@@ -0,0 +1,119 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  while (cond()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  while (cond()) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.body:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  while (cond()) {
+    if (cond())
+      break;
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.cond
+
+void test4() {
+  while (cond()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+// CHECK: if.end:
+// CHECK:   br label %while.cond
+
+void test5() {
+  while (cond()) {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK:   br label %while.end
+
+void test6() {
+  while (cond()) {
+    while (cond()) {
+    }
+
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: define spir_func void @_Z5test6v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: while.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: while.cond2:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: if.then:
+// CHECK:   call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK:   br label %while.end
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }

>From 6f3b0ec5353d819380d8e454c528232e82e1f22a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 16 Apr 2024 18:40:09 +0200
Subject: [PATCH 2/2] feedback review

---
 clang/lib/CodeGen/CGCall.cpp        | 2 +-
 clang/lib/CodeGen/CodeGenFunction.h | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 06d4bceacfd34b..95cf25f4a0b3c0 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4826,7 +4826,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   call->setCallingConv(getRuntimeCC());
 
   if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
-    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
+    return addControlledConvergenceToken(call);
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 12c5e71bf6af60..f348a1e866b6f6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4992,7 +4992,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   // convergence instructions.
   template <typename CallType>
   CallType *addControlledConvergenceToken(CallType *Input) {
-    return dyn_cast<CallType>(
+    return cast<CallType>(
         addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
   }
 



More information about the cfe-commits mailing list