[llvm] [clang] [clang][HLSL][SPRI-V] Add convergence intrinsics (PR #80680)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 5 05:37:43 PST 2024
https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/80680
HLSL has wave operations and other kinds of functions which require the control flow to either be converged, or to respect certain constraints as to where and how to re-converge.
At the HLSL level, the convergence is mostly obvious: the control flow is expected to re-converge at the end of a scope.
Once translated to IR, HLSL scopes disappear. This means we need a way to communicate convergence restrictions down to the backend.
For this, the SPIR-V backend uses convergence intrinsics. So this commit adds some code to generate convergence intrinsics when required.
This commit is not to be submitted as-is (lacks testing), but should serve as a basis for an upcoming RFC.
>From 8d653d1af6f624f341e88997682fc271195d8a45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Fri, 2 Feb 2024 16:38:46 +0100
Subject: [PATCH] [clang][HLSL][SPRI-V] Add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
HLSL has wave operations and other kinds of functions which require the
control flow to either be converged, or to respect certain constraints as
to where and how to re-converge.
At the HLSL level, the convergence is mostly obvious: the control flow
is expected to re-converge at the end of a scope.
Once translated to IR, HLSL scopes disappear. This means we need a way to
communicate convergence restrictions down to the backend.
For this, the SPIR-V backend uses convergence intrinsics. So this commit
adds some code to generate convergence intrinsics when required.
This commit is not to be submitted as-is (lacks testing), but
should serve as a basis for an upcoming RFC.
Signed-off-by: Nathan Gauër <brioche at google.com>
---
clang/lib/CodeGen/CGBuiltin.cpp | 102 +++++++++++++++++++++++++++
clang/lib/CodeGen/CGCall.cpp | 4 ++
clang/lib/CodeGen/CGLoopInfo.h | 8 ++-
clang/lib/CodeGen/CodeGenFunction.h | 19 +++++
llvm/include/llvm/IR/IntrinsicInst.h | 13 ++++
5 files changed, 145 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f17e4a83305bf..0de350dc65485 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1129,8 +1129,97 @@ struct BitTest {
static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
};
+
+// Scans |BB| from top to bottom and returns the first instruction that is a
+// convergence control intrinsic (entry, loop or anchor). Yields std::nullopt
+// when the block holds no such instruction.
+std::optional<llvm::IntrinsicInst *> getConvergenceToken(llvm::BasicBlock *BB) {
+  for (llvm::Instruction &Inst : *BB)
+    if (auto *Intrinsic = dyn_cast<llvm::IntrinsicInst>(&Inst))
+      if (isConvergenceControlIntrinsic(Intrinsic->getIntrinsicID()))
+        return Intrinsic;
+
+  return std::nullopt;
+}
+
} // namespace
+// Attaches a "convergencectrl" operand bundle referencing |ParentToken| to the
+// call |Input|. A copy of the call carrying the extra bundle is created right
+// before |Input|, every use of |Input| is redirected to the copy and |Input|
+// is erased. Returns the call that carries the bundle; |Input| must not be
+// used afterwards unless it is the returned value.
+llvm::CallBase *
+CodeGenFunction::AddConvergenceControlAttr(llvm::CallBase *Input,
+                                           llvm::Value *ParentToken) {
+  llvm::Value *BundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", BundleArgs);
+  llvm::CallBase *Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+
+  // addOperandBundle() hands back |Input| itself when the call already has a
+  // convergencectrl bundle. Replacing and erasing in that case would RAUW the
+  // call with itself and then delete the instruction being returned.
+  if (Output == Input)
+    return Input;
+
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+// Emits a call to the experimental_convergence_loop intrinsic in front of the
+// first instruction of |BB|, tied to |ParentToken| through a convergencectrl
+// operand bundle, and returns the resulting intrinsic. |BB| must already
+// contain at least one instruction. The builder's insertion point is restored
+// before returning.
+llvm::IntrinsicInst *
+CodeGenFunction::EmitConvergenceLoop(llvm::BasicBlock *BB,
+                                     llvm::Value *ParentToken) {
+  // Temporarily move the builder to the first instruction of |BB|.
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  llvm::CallInst *CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  // The returned call is the one carrying the bundle; |CB| may have been
+  // replaced and must not be used past this point.
+  llvm::CallBase *I = AddConvergenceControlAttr(CB, ParentToken);
+  // Controlled convergence is incompatible with uncontrolled convergence.
+  // Removing any old attributes.
+  I->setNotConvergent();
+
+  // cast<> checks the type itself and, unlike assert + dyn_cast<>, cannot
+  // silently yield nullptr when assertions are compiled out.
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+// Returns the first convergence control intrinsic found in the entry block of
+// |F|. When there is none, marks |F| convergent, emits a call to the
+// experimental_convergence_entry intrinsic in front of the entry block's
+// first instruction and returns it. The builder's insertion point is restored
+// before returning.
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  llvm::BasicBlock *BB = &F->getEntryBlock();
+  if (auto Token = getConvergenceToken(BB))
+    return *Token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  llvm::CallInst *I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  Builder.restoreIP(IP);
+
+  // cast<> checks the type itself and, unlike assert + dyn_cast<>, cannot
+  // silently yield nullptr when assertions are compiled out.
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+// Returns the convergence token of the loop described by |LI|. The first
+// convergence control intrinsic found in the loop header is reused; otherwise
+// a convergence loop intrinsic is emitted at the top of the header. Its parent
+// token is the token of the enclosing loop (emitted on demand) or, for a loop
+// without a parent, the entry token of the function owning the header.
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
+  assert(LI != nullptr);
+
+  llvm::BasicBlock *Header = LI->getHeader();
+  if (auto Token = getConvergenceToken(Header))
+    return *Token;
+
+  // Exactly one loop intrinsic is emitted per header. Emitting one to serve as
+  // the parent token and a second one on top of it would place the second
+  // intrinsic before the definition of the token it consumes, since both are
+  // inserted at the front of the same block.
+  llvm::IntrinsicInst *ParentToken =
+      LI->getParent()
+          ? getOrEmitConvergenceLoopToken(LI->getParent())
+          : getOrEmitConvergenceEntryToken(Header->getParent());
+
+  return EmitConvergenceLoop(Header, ParentToken);
+}
+
+// Tags the call |Input| with a convergencectrl operand bundle. The token it
+// refers to is the loop token of the current loop on LoopStack, or the entry
+// token of the function containing |Input| when no loop is active; either one
+// is emitted on demand. Returns the call that carries the bundle.
+llvm::CallBase *
+CodeGenFunction::AddControlledConvergenceAttr(llvm::CallBase *Input) {
+  if (LoopStack.hasInfo())
+    return AddConvergenceControlAttr(
+        Input, getOrEmitConvergenceLoopToken(&LoopStack.getInfo()));
+
+  return AddConvergenceControlAttr(
+      Input, getOrEmitConvergenceEntryToken(Input->getFunction()));
+}
+
BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
switch (BuiltinID) {
// Main portable variants.
@@ -5692,6 +5781,19 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
{NDRange, Kernel, Block}));
}
+  case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
+    // Lowered to a call to the runtime function
+    // "__hlsl_wave_active_count_bits", declared as taking an i1 and returning
+    // IntTy.
+    llvm::Type *BoolTy = llvm::IntegerType::get(getLLVMContext(), 1);
+    llvm::Value *Src0 = EmitScalarExpr(E->getArg(0));
+    auto *CI =
+        EmitRuntimeCall(CGM.CreateRuntimeFunction(
+                            llvm::FunctionType::get(IntTy, {BoolTy}, false),
+                            "__hlsl_wave_active_count_bits", {}),
+                        {Src0});
+    // For logical SPIR-V the call is tied to a convergence token. cast<> is
+    // used instead of dyn_cast<>: the result is returned unchecked, so a
+    // failed conversion must not silently become a null RValue.
+    if (getTarget().getTriple().isSPIRVLogical())
+      CI = cast<CallInst>(AddControlledConvergenceAttr(CI));
+    return RValue::get(CI);
+  }
+
case Builtin::BI__builtin_store_half:
case Builtin::BI__builtin_store_halff: {
Value *Val = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index bb61cf08bbfb4..27616db3d9ba0 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5684,6 +5684,10 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
if (!CI->getType()->isVoidTy())
CI->setName("call");
+ if (getTarget().getTriple().isSPIRVLogical() &&
+ CI->getCalledFunction()->isConvergent())
+ CI = AddControlledConvergenceAttr(CI);
+
// Update largest vector width from the return type.
LargestVectorWidth =
std::max(LargestVectorWidth, getMaxVectorWidth(CI->getType()));
diff --git a/clang/lib/CodeGen/CGLoopInfo.h b/clang/lib/CodeGen/CGLoopInfo.h
index a1c8c7e5307fd..7c2f7443bd3c9 100644
--- a/clang/lib/CodeGen/CGLoopInfo.h
+++ b/clang/lib/CodeGen/CGLoopInfo.h
@@ -110,6 +110,10 @@ class LoopInfo {
/// been processed.
void finish();
+ /// Returns the LoopInfo of the first outer loop containing this loop, or
+ /// nullptr when there is no such loop.
+ const LoopInfo *getParent() const { return Parent; }
+
private:
/// Loop ID metadata.
llvm::TempMDTuple TempLoopID;
@@ -291,12 +295,14 @@ class LoopInfoStack {
/// Set no progress for the next loop pushed.
void setMustProgress(bool P) { StagedAttrs.MustProgress = P; }
-private:
/// Returns true if there is LoopInfo on the stack.
bool hasInfo() const { return !Active.empty(); }
+
/// Return the LoopInfo for the current loop. HasInfo should be called
/// first to ensure LoopInfo is present.
const LoopInfo &getInfo() const { return *Active.back(); }
+
+private:
/// The set of attributes that will be applied to the next pushed loop.
LoopAttributes StagedAttrs;
/// Stack of active loops.
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 143ad64e8816b..5299090ceada7 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4141,6 +4141,25 @@ class CodeGenFunction : public CodeGenTypeCache {
void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
void checkTargetFeatures(SourceLocation Loc, const FunctionDecl *TargetDecl);
+ // Adds a convergencectrl operand bundle to the call |Input| and emits the
+ // parent convergence instructions it requires. Returns the updated call.
+ llvm::CallBase *AddControlledConvergenceAttr(llvm::CallBase *Input);
+
+ // Emits a convergence_loop instruction at the start of |BB|, with
+ // |ParentToken| as its parent convergence instr.
+ llvm::IntrinsicInst *EmitConvergenceLoop(llvm::BasicBlock *BB,
+ llvm::Value *ParentToken);
+ // Adds a convergencectrl operand bundle with |ParentToken| as parent
+ // convergence instr to the call |Input|. Returns the call carrying the
+ // bundle.
+ llvm::CallBase *AddConvergenceControlAttr(llvm::CallBase *Input,
+ llvm::Value *ParentToken);
+ // Finds the convergence_entry instruction of |F|, or emits one if none
+ // exists. Returns the convergence instruction.
+ llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
+ // Finds the convergence_loop instruction for the loop defined by |LI|, or
+ // emits one if none exists. Returns the convergence instruction.
+ llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
+
llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
const Twine &name = "");
llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index b8d578d0fee08..c0c19203b160a 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1746,6 +1746,19 @@ class ConvergenceControlInst : public IntrinsicInst {
static bool classof(const Value *V) {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
+
+  // Returns the convergence intrinsic referenced by |I|'s convergencectrl
+  // operand bundle. Returns nullptr when |I| is not a call instruction, when
+  // it carries no convergencectrl bundle, or when the bundle's token is not
+  // an intrinsic call.
+  static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
+    auto *CI = dyn_cast<llvm::CallInst>(I);
+    if (!CI)
+      return nullptr;
+
+    auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
+    // getOperandBundle() yields an empty optional for calls without the
+    // bundle; dereferencing it unconditionally would be undefined behavior.
+    if (!Bundle)
+      return nullptr;
+
+    assert(Bundle->Inputs.size() == 1 &&
+           Bundle->Inputs[0]->getType()->isTokenTy());
+    return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
+  }
};
} // end namespace llvm
More information about the llvm-commits
mailing list