[llvm] [SPIR-V] Add WaveGetLaneIndex() intrinsic support (PR #85979)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 10:55:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

Add support to generate valid SPIR-V for the WaveGetLaneIndex() HLSL builtin.

To implement this, I had to fix a few small issues in the backend, such as the i8* pointer type being emitted even when the type information is available elsewhere.

---
Full diff: https://github.com/llvm/llvm-project/pull/85979.diff


10 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+25-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+8-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+4-5) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+68) 
- (modified) llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll (+4-5) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll (-2) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 07be0b34b18271..804c264e21e5ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -368,12 +368,10 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
 
 /// Helper function for building a load instruction for loading a builtin global
 /// variable of \p BuiltinValue value.
-static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
-                                         SPIRVType *VariableType,
-                                         SPIRVGlobalRegistry *GR,
-                                         SPIRV::BuiltIn::BuiltIn BuiltinValue,
-                                         LLT LLType,
-                                         Register Reg = Register(0)) {
+static Register buildBuiltinVariableLoad(
+    MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
+    SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
+    Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
   Register NewRegister =
       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
   MIRBuilder.getMRI()->setType(NewRegister,
@@ -385,8 +383,9 @@ static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
   // Set up the global OpVariable with the necessary builtin decorations.
   Register Variable = GR->buildGlobalVariable(
       NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
-      SPIRV::StorageClass::Input, nullptr, true, true,
-      SPIRV::LinkageType::Import, MIRBuilder, false);
+      SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
+      /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
+      false);
 
   // Load the value from the global variable.
   Register LoadedRegister =
@@ -1300,6 +1299,22 @@ static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
+static bool generateWaveInst(const SPIRV::IncomingCall *Call,
+                             MachineIRBuilder &MIRBuilder,
+                             SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  SPIRV::BuiltIn::BuiltIn Value =
+      SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
+
+  // For now, we only support a single Wave intrinsic with a single return type.
+  assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
+  LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
+
+  return buildBuiltinVariableLoad(
+      MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
+      /* isConst= */ false, /* hasLinkageTy= */ false);
+}
+
 static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
@@ -2187,6 +2202,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateBarrierInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Dot:
     return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::Wave:
+    return generateWaveInst(Call.get(), MIRBuilder, GR);
   case SPIRV::GetQuery:
     return generateGetQueryInst(Call.get(), MIRBuilder, GR);
   case SPIRV::ImageSizeQuery:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index eb26f70b1861f2..3fdfde625fbe9d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -41,6 +41,7 @@ def Variable : BuiltinGroup;
 def Atomic : BuiltinGroup;
 def Barrier : BuiltinGroup;
 def Dot : BuiltinGroup;
+def Wave : BuiltinGroup;
 def GetQuery : BuiltinGroup;
 def ImageSizeQuery : BuiltinGroup;
 def ImageMiscQuery : BuiltinGroup;
@@ -1119,6 +1120,7 @@ defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
 defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
 defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
 defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
+defm : DemangledGetBuiltin<"__hlsl_wave_get_lane_index", GLSL_std_450, Wave, SubgroupLocalInvocationId>;
 
 //===----------------------------------------------------------------------===//
 // Class defining an image query builtin record used for lowering the OpenCL
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 6f23f055b8c2ab..afdca01561b0bc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -493,9 +493,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
-  // TODO: check that it's OCL builtin, then apply OpenCL_std.
-  if (!DemangledName.empty() && CF && CF->isDeclaration() &&
-      ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+
+  bool isFunctionDecl = CF && CF->isDeclaration();
+  bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
+  bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
+  assert(canUseGLSL != canUseOpenCL &&
+         "Scenario where both sets are enabled is not supported.");
+
+  if (isFunctionDecl && !DemangledName.empty() &&
+      (canUseGLSL || canUseOpenCL)) {
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
@@ -504,12 +510,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
         GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
     }
-    if (auto Res = SPIRV::lowerBuiltin(
-            DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
-            ResVReg, OrigRetTy, ArgVRegs, GR))
+    auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
+                                       : SPIRV::InstructionSet::GLSL_std_450;
+    if (auto Res =
+            SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
+                                ResVReg, OrigRetTy, ArgVRegs, GR))
       return *Res;
   }
-  if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
+
+  if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
     // Emit the type info and forward function declaration to the first MBB
     // to ensure VReg definition dependencies are valid across all MBBs.
     MachineIRBuilder FirstBlockBuilder;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..f865853776a1b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -721,11 +721,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     AddrSpace = PType->getAddressSpace();
   else
     report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
-  SPIRVType *SpvElementType;
-  // At the moment, all opaque pointers correspond to i8 element type.
-  // TODO: change the implementation once opaque pointers are supported
-  // in the SPIR-V specification.
-  SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
+  SPIRVType *SpvElementType = nullptr;
+  if (auto PType = dyn_cast<TypedPointerType>(Ty))
+    SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
+                                          AccQual, EmitIR);
+  else
+    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+
   // Get access to information about available extensions
   const SPIRVSubtarget *ST =
       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 00d0cbd763736d..40c3e5f9c6bdab 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -658,7 +658,7 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(
 
   // Provided by all supported Vulkan versions.
   addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
-                    Capability::Float64});
+                    Capability::Float64, Capability::GroupNonUniform});
 }
 
 } // namespace SPIRV
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91ba4a565..ea53f937d31982 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,7 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
       }
       case TargetOpcode::G_GLOBAL_VALUE: {
         MIB.setInsertPt(*MI->getParent(), MI);
-        Type *Ty = MI->getOperand(1).getGlobal()->getType();
+        const auto *Global = MI->getOperand(1).getGlobal();
+        auto *Ty = TypedPointerType::get(Global->getValueType(),
+                                         Global->getType()->getAddressSpace());
         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
         break;
       }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index fc7502479fdcdd..c87c1293c622fc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -306,10 +306,12 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
   bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
   bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
+  bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
   bool IsMangled = Name.starts_with("_Z");
 
-  if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
-    return std::string();
+  // Non-mangled OCL/SPIR-V/HLSL builtins, and any name that is not Itanium-mangled, are returned unchanged.
+  if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
+    return Name.str();
 
   // Try to use the itanium demangler.
   if (char *DemangledName = itaniumDemangle(Name.data())) {
@@ -317,9 +319,6 @@ std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
     free(DemangledName);
     return Result;
   }
-  // Otherwise use simple demangling to return the function name.
-  if (IsNonMangledOCL || IsNonMangledSPIRV)
-    return Name.str();
 
   // Autocheck C++, maybe need to do explicit check of the source language.
   // OpenCL C++ built-ins are declared in cl namespace.
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
new file mode 100644
index 00000000000000..ec35690ac1547c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll
@@ -0,0 +1,68 @@
+; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file generated from the following command:
+; clang -cc1 -triple spirv-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -o - - <<EOF
+; [numthreads(1, 1, 1)]
+; void main() {
+;   int idx = WaveGetLaneIndex();
+; }
+; EOF
+
+; CHECK-DAG:                         OpCapability Shader
+; CHECK-DAG:                         OpCapability GroupNonUniform
+; CHECK-DAG:                         OpDecorate %[[#var:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:            %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG:           %[[#ptri:]] = OpTypePointer Input %[[#int]]
+; CHECK-DAG:           %[[#ptrf:]] = OpTypePointer Function %[[#int]]
+; CHECK-DAG:             %[[#var]] = OpVariable %[[#ptri]] Input
+
+; CHECK-NOT:                         OpDecorate %[[#var]] LinkageAttributes
+
+
+; ModuleID = '-'
+source_filename = "-"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-compute"
+
+; Function Attrs: convergent noinline norecurse nounwind optnone
+define internal spir_func void @main() #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+; CHECK:     %[[#idx:]] = OpVariable %[[#ptrf]] Function
+
+  %1 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %0) ]
+; CHECK:    %[[#tmp:]] = OpLoad %[[#int]] %[[#var]]
+
+  store i32 %1, ptr %idx, align 4
+; CHECK:                 OpStore %[[#idx]] %[[#tmp]]
+
+  ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+  call void @main()
+  ret void
+}
+
+; Function Attrs: convergent
+declare i32 @__hlsl_wave_get_lane_index() #2
+
+; Function Attrs: convergent nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.entry() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent }
+attributes #3 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (/usr/local/google/home/nathangauer/projects/llvm-project/clang bc6fd04b73a195981ee77823cf1382d04ab96c44)"}
+
diff --git a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
index 329399bab3e5b9..2ea5c767730e19 100644
--- a/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
+++ b/llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll
@@ -1,5 +1,6 @@
 ; RUN: llc -mtriple=spirv-unknown-unknown -O0 %s -o - | FileCheck %s
 
+; CHECK-DAG:                  OpDecorate %[[#SubgroupLocalInvocationId:]] BuiltIn SubgroupLocalInvocationId
 ; CHECK-DAG:    %[[#bool:]] = OpTypeBool
 ; CHECK-DAG:    %[[#uint:]] = OpTypeInt 32 0
 ; CHECK-DAG:  %[[#uint_0:]] = OpConstant %[[#uint]] 0
@@ -37,10 +38,10 @@ l1_continue:
 ; CHECK-NEXT:                      OpBranch %[[#l1_header]]
 
 l1_end:
-  %call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
+  %call = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %tl1) ]
   br label %end
 ; CHECK-DAG:   %[[#l1_end]] = OpLabel
-; CHECK-DAG:         %[[#]] = OpFunctionCall
+; CHECK-DAG:         %[[#]] = OpLoad %[[#]] %[[#SubgroupLocalInvocationId]]
 ; CHECK-NEXT:                 OpBranch %[[#end:]]
 
 l2:
@@ -76,6 +77,4 @@ declare token @llvm.experimental.convergence.entry()
 declare token @llvm.experimental.convergence.control()
 declare token @llvm.experimental.convergence.loop()
 
-; This intrinsic is not convergent. This is only because the backend doesn't
-; support convergent operations yet.
-declare spir_func i32 @_Z3absi(i32) convergent
+declare i32 @__hlsl_wave_get_lane_index() convergent
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
index 3551030843d062..e0172ec3c1bdb7 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll
@@ -1,6 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 ;
-; CHECK-SPIRV-DAG: %[[#i8:]] = OpTypeInt 8 0
 ; CHECK-SPIRV-DAG: %[[#i32:]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[#one:]] = OpConstant %[[#i32]] 1
 ; CHECK-SPIRV-DAG: %[[#two:]] = OpConstant %[[#i32]] 2
@@ -13,7 +12,6 @@
 ; CHECK-SPIRV:     %[[#test_arr2:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
 ; CHECK-SPIRV:     %[[#test_arr:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
 
-; CHECK-SPIRV-DAG: %[[#const_i8_ptr:]] = OpTypePointer UniformConstant %[[#i8]]
 ; CHECK-SPIRV-DAG: %[[#i32x3_ptr:]] = OpTypePointer Function %[[#i32x3]]
 
 ; CHECK-SPIRV:     %[[#arr:]] = OpVariable %[[#i32x3_ptr]] Function

``````````

</details>


https://github.com/llvm/llvm-project/pull/85979


More information about the llvm-commits mailing list