[llvm] [SPIR-V] Add WaveGetLaneIndex() intrinsic support (PR #85979)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 20 10:55:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Nathan Gauër (Keenuts)
<details>
<summary>Changes</summary>
Add support to generate valid SPIR-V for the WaveGetLaneIndex() HLSL builtin.
To implement this, I had to fix a few small issues in the backend, such as an i8* pointer type being emitted even when the actual pointee type information is available elsewhere.
---
Full diff: https://github.com/llvm/llvm-project/pull/85979.diff
10 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+25-8)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-7)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+8-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+4-5)
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+68)
- (modified) llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll (+4-5)
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll (-2)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 07be0b34b18271..804c264e21e5ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -368,12 +368,10 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
/// Helper function for building a load instruction for loading a builtin global
/// variable of \p BuiltinValue value.
-static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
- SPIRVType *VariableType,
- SPIRVGlobalRegistry *GR,
- SPIRV::BuiltIn::BuiltIn BuiltinValue,
- LLT LLType,
- Register Reg = Register(0)) {
+static Register buildBuiltinVariableLoad(
+ MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
+ SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
+ Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
MIRBuilder.getMRI()->setType(NewRegister,
@@ -385,8 +383,9 @@ static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
// Set up the global OpVariable with the necessary builtin decorations.
Register Variable = GR->buildGlobalVariable(
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
- SPIRV::StorageClass::Input, nullptr, true, true,
- SPIRV::LinkageType::Import, MIRBuilder, false);
+ SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
+ /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
+ false);
// Load the value from the global variable.
Register LoadedRegister =
@@ -1300,6 +1299,22 @@ static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
return true;
}
+static bool generateWaveInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ SPIRV::BuiltIn::BuiltIn Value =
+ SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
+
+ // For now, we only support a single Wave intrinsic with a single return type.
+ assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
+ LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
+
+ return buildBuiltinVariableLoad(
+ MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
+ /* isConst= */ false, /* hasLinkageTy= */ false);
+}
+
static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -2187,6 +2202,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateBarrierInst(Call.get(), MIRBuilder, GR);
case SPIRV::Dot:
return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::Wave:
+ return generateWaveInst(Call.get(), MIRBuilder, GR);
case SPIRV::GetQuery:
return generateGetQueryInst(Call.get(), MIRBuilder, GR);
case SPIRV::ImageSizeQuery:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index eb26f70b1861f2..3fdfde625fbe9d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -41,6 +41,7 @@ def Variable : BuiltinGroup;
def Atomic : BuiltinGroup;
def Barrier : BuiltinGroup;
def Dot : BuiltinGroup;
+def Wave : BuiltinGroup;
def GetQuery : BuiltinGroup;
def ImageSizeQuery : BuiltinGroup;
def ImageMiscQuery : BuiltinGroup;
@@ -1119,6 +1120,7 @@ defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
+defm : DemangledGetBuiltin<"__hlsl_wave_get_lane_index", GLSL_std_450, Wave, SubgroupLocalInvocationId>;
//===----------------------------------------------------------------------===//
// Class defining an image query builtin record used for lowering the OpenCL
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 6f23f055b8c2ab..afdca01561b0bc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -493,9 +493,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
- // TODO: check that it's OCL builtin, then apply OpenCL_std.
- if (!DemangledName.empty() && CF && CF->isDeclaration() &&
- ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+
+ bool isFunctionDecl = CF && CF->isDeclaration();
+ bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
+ bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
+ assert(canUseGLSL != canUseOpenCL &&
+ "Scenario where both sets are enabled is not supported.");
+
+ if (isFunctionDecl && !DemangledName.empty() &&
+ (canUseGLSL || canUseOpenCL)) {
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
@@ -504,12 +510,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
}
- if (auto Res = SPIRV::lowerBuiltin(
- DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
- ResVReg, OrigRetTy, ArgVRegs, GR))
+ auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
+ : SPIRV::InstructionSet::GLSL_std_450;
+ if (auto Res =
+ SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
+ ResVReg, OrigRetTy, ArgVRegs, GR))
return *Res;
}
- if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
+
+ if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineIRBuilder FirstBlockBuilder;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..f865853776a1b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -721,11 +721,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
AddrSpace = PType->getAddressSpace();
else
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
- SPIRVType *SpvElementType;
- // At the moment, all opaque pointers correspond to i8 element type.
- // TODO: change the implementation once opaque pointers are supported
- // in the SPIR-V specification.
- SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
+ SPIRVType *SpvElementType = nullptr;
+ if (auto PType = dyn_cast<TypedPointerType>(Ty))
+ SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
+ AccQual, EmitIR);
+ else
+ SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 00d0cbd763736d..40c3e5f9c6bdab 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -658,7 +658,7 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(
// Provided by all supported Vulkan versions.
addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
- Capability::Float64});
+ Capability::Float64, Capability::GroupNonUniform});
}
} // namespace SPIRV
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91ba4a565..ea53f937d31982 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,7 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
}
case TargetOpcode::G_GLOBAL_VALUE: {
MIB.setInsertPt(*MI->getParent(), MI);
- Type *Ty = MI->getOperand(1).getGlobal()->getType();
+ const auto *Global = MI->getOperand(1).getGlobal();
+ auto *Ty = TypedPointerType::get(Global->getValueType(),
+ Global->getType()->getAddressSpace());
SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index fc7502479fdcdd..c87c1293c622fc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -306,10 +306,12 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
+ bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
bool IsMangled = Name.starts_with("_Z");
- if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
- return std::string();
+ // Otherwise use simple demangling to return the function name.
+ if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
+ return Name.str();
// Try to use the itanium demangler.
if (char *DemangledName = itaniumDemangle(Name.data())) {
@@ -317,9 +319,6 @@ std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
free(DemangledName);
return Result;
}
- // Otherwise use simple demangling to return the function name.
- if (IsNonMangledOCL || IsNonMangledSPIRV)
- return Name.str();
// Autocheck C++, maybe need to do explicit check of the source language.
// OpenCL C++ built-ins are declared in cl namespace.
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
new file mode 100644
index 00000000000000..ec35690ac1547c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
@@ -0,0 +1,68 @@
+; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file generated from the following command:
+; clang -cc1 -triple spirv-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -o - - <<EOF
+; [numthreads(1, 1, 1)]
+; void main() {
+; int idx = WaveGetLaneIndex();
+; }
+; EOF
+
+; CHECK-DAG: OpCapability Shader
+; CHECK-DAG: OpCapability GroupNonUniform
+; CHECK-DAG: OpDecorate %[[#var:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#ptri:]] = OpTypePointer Input %[[#int]]
+; CHECK-DAG: %[[#ptrf:]] = OpTypePointer Function %[[#int]]
+; CHECK-DAG: %[[#var]] = OpVariable %[[#ptri]] Input
+
+; CHECK-NOT: OpDecorate %[[#var]] LinkageAttributes
+
+
+; ModuleID = '-'
+source_filename = "-"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-compute"
+
+; Function Attrs: convergent noinline norecurse nounwind optnone
+define internal spir_func void @main() #0 {
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+ %idx = alloca i32, align 4
+; CHECK: %[[#idx:]] = OpVariable %[[#ptrf]] Function
+
+ %1 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %0) ]
+; CHECK: %[[#tmp:]] = OpLoad %[[#int]] %[[#var]]
+
+ store i32 %1, ptr %idx, align 4
+; CHECK: OpStore %[[#idx]] %[[#tmp]]
+
+ ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+ call void @main()
+ ret void
+}
+
+; Function Attrs: convergent
+declare i32 @__hlsl_wave_get_lane_index() #2
+
+; Function Attrs: convergent nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.entry() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent }
+attributes #3 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (/usr/local/google/home/nathangauer/projects/llvm-project/clang bc6fd04b73a195981ee77823cf1382d04ab96c44)"}
+
diff --git a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
index 329399bab3e5b9..2ea5c767730e19 100644
--- a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
+++ b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
@@ -1,5 +1,6 @@
; RUN: llc -mtriple=spirv-unknown-unknown -O0 %s -o - | FileCheck %s
+; CHECK-DAG: OpDecorate %[[#SubgroupLocalInvocationId:]] BuiltIn SubgroupLocalInvocationId
; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
@@ -37,10 +38,10 @@ l1_continue:
; CHECK-NEXT: OpBranch %[[#l1_header]]
l1_end:
- %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+ %call = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %tl1) ]
br label %end
; CHECK-DAG: %[[#l1_end]] = OpLabel
-; CHECK-DAG: %[[#]] = OpFunctionCall
+; CHECK-DAG: %[[#]] = OpLoad %[[#]] %[[#SubgroupLocalInvocationId]]
; CHECK-NEXT: OpBranch %[[#end:]]
l2:
@@ -76,6 +77,4 @@ declare token @llvm.experimental.convergence.entry()
declare token @llvm.experimental.convergence.control()
declare token @llvm.experimental.convergence.loop()
-; This intrinsic is not convergent. This is only because the backend doesn't
-; support convergent operations yet.
-declare spir_func i32 @_Z3absi(i32) convergent
+declare i32 @__hlsl_wave_get_lane_index() convergent
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
index 3551030843d062..e0172ec3c1bdb7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
@@ -1,6 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
;
-; CHECK-SPIRV-DAG: %[[#i8:]] = OpTypeInt 8 0
; CHECK-SPIRV-DAG: %[[#i32:]] = OpTypeInt 32 0
; CHECK-SPIRV-DAG: %[[#one:]] = OpConstant %[[#i32]] 1
; CHECK-SPIRV-DAG: %[[#two:]] = OpConstant %[[#i32]] 2
@@ -13,7 +12,6 @@
; CHECK-SPIRV: %[[#test_arr2:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
; CHECK-SPIRV: %[[#test_arr:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
-; CHECK-SPIRV-DAG: %[[#const_i8_ptr:]] = OpTypePointer UniformConstant %[[#i8]]
; CHECK-SPIRV-DAG: %[[#i32x3_ptr:]] = OpTypePointer Function %[[#i32x3]]
; CHECK-SPIRV: %[[#arr:]] = OpVariable %[[#i32x3_ptr]] Function
``````````
</details>
https://github.com/llvm/llvm-project/pull/85979
More information about the llvm-commits mailing list