[clang] df67e37 - [clang][NFC] clean up the handling of convergence control tokens (#121738)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Jan 6 08:04:15 PST 2025
Author: Sameer Sahasrabuddhe
Date: 2025-01-06T21:34:11+05:30
New Revision: df67e37e37a7862e1e67f52e01f0c9a019477930
URL: https://github.com/llvm/llvm-project/commit/df67e37e37a7862e1e67f52e01f0c9a019477930
DIFF: https://github.com/llvm/llvm-project/commit/df67e37e37a7862e1e67f52e01f0c9a019477930.diff
LOG: [clang][NFC] clean up the handling of convergence control tokens (#121738)
Added:
Modified:
clang/lib/CodeGen/CGCall.cpp
clang/lib/CodeGen/CGStmt.cpp
clang/lib/CodeGen/CodeGenFunction.h
Removed:
################################################################################
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f139c30f3dfd44..89e2eace9120bf 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4871,7 +4871,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
call->setCallingConv(getRuntimeCC());
if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
- return addControlledConvergenceToken(call);
+ return cast<llvm::CallInst>(addConvergenceControlToken(call));
return call;
}
@@ -5787,7 +5787,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
CI->setName("call");
if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
- CI = addControlledConvergenceToken(CI);
+ CI = addConvergenceControlToken(CI);
// Update largest vector width from the return type.
LargestVectorWidth =
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 3974739d2abb47..7904e17dbebb81 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -1024,8 +1024,8 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
EmitBlock(LoopHeader.getBlock());
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
- LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(
+ emitConvergenceLoopToken(LoopHeader.getBlock()));
// Create an exit block for when the condition fails, which will
// also become the break target.
@@ -1152,8 +1152,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
EmitBlockWithFallThrough(LoopBody, &S);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(LoopBody));
{
RunCleanupsScope BodyScope(*this);
@@ -1231,8 +1230,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
EmitBlock(CondBlock);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -1369,8 +1367,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
EmitBlock(CondBlock);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -3245,35 +3242,32 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
return F;
}
-namespace {
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
// Returns nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+static llvm::ConvergenceControlInst *getConvergenceToken(llvm::BasicBlock *BB) {
for (auto &I : *BB) {
- auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
- if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
- return II;
+ if (auto *CI = dyn_cast<llvm::ConvergenceControlInst>(&I))
+ return CI;
}
return nullptr;
}
-} // namespace
-
llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
- llvm::Value *ParentToken) {
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input) {
+ llvm::ConvergenceControlInst *ParentToken = ConvergenceTokenStack.back();
+ assert(ParentToken);
+
llvm::Value *bundleArgs[] = {ParentToken};
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
- auto Output = llvm::CallBase::addOperandBundle(
+ auto *Output = llvm::CallBase::addOperandBundle(
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input->getIterator());
Input->replaceAllUsesWith(Output);
Input->eraseFromParent();
return Output;
}
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
- llvm::Value *ParentToken) {
+llvm::ConvergenceControlInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB) {
CGBuilderTy::InsertPoint IP = Builder.saveIP();
if (BB->empty())
Builder.SetInsertPoint(BB);
@@ -3284,14 +3278,14 @@ CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
llvm::Intrinsic::experimental_convergence_loop, {}, {});
Builder.restoreIP(IP);
- llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
- return cast<llvm::IntrinsicInst>(I);
+ CB = addConvergenceControlToken(CB);
+ return cast<llvm::ConvergenceControlInst>(CB);
}
-llvm::IntrinsicInst *
+llvm::ConvergenceControlInst *
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
llvm::BasicBlock *BB = &F->getEntryBlock();
- llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+ llvm::ConvergenceControlInst *Token = getConvergenceToken(BB);
if (Token)
return Token;
@@ -3306,5 +3300,5 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
assert(isa<llvm::IntrinsicInst>(I));
Builder.restoreIP(IP);
- return cast<llvm::IntrinsicInst>(I);
+ return cast<llvm::ConvergenceControlInst>(I);
}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1a5c42f8f974d0..46f267983edad1 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -315,7 +315,7 @@ class CodeGenFunction : public CodeGenTypeCache {
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
/// Stack to track the controlled convergence tokens.
- SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+ SmallVector<llvm::ConvergenceControlInst *, 4> ConvergenceTokenStack;
/// Number of nested loop to be consumed by the last surrounding
/// loop-associated directive.
@@ -5234,29 +5234,20 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *emitBoolVecConversion(llvm::Value *SrcVec,
unsigned NumElementsDst,
const llvm::Twine &Name = "");
- // Adds a convergence_ctrl token to |Input| and emits the required parent
- // convergence instructions.
- template <typename CallType>
- CallType *addControlledConvergenceToken(CallType *Input) {
- return cast<CallType>(
- addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
- }
private:
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
// as its parent convergence instr.
- llvm::IntrinsicInst *emitConvergenceLoopToken(llvm::BasicBlock *BB,
- llvm::Value *ParentToken);
+ llvm::ConvergenceControlInst *emitConvergenceLoopToken(llvm::BasicBlock *BB);
+
// Adds a convergence_ctrl token with |ParentToken| as parent convergence
// instr to the call |Input|.
- llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input,
- llvm::Value *ParentToken);
+ llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input);
+
// Finds the convergence_entry instruction in |F|, or emits one if none exists.
// Returns the convergence instruction.
- llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
- // Find the convergence_loop instruction for the loop defined by |LI|, or
- // emits one if none exists. Returns the convergence instruction.
- llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
+ llvm::ConvergenceControlInst *
+ getOrEmitConvergenceEntryToken(llvm::Function *F);
private:
llvm::MDNode *getRangeForLoadFromType(QualType Ty);
More information about the cfe-commits
mailing list