[llvm] [HLSL][SPIR-V] Add SV_DispatchThreadID semantic support (PR #82536)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 13:22:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Natalie Chouinard (sudonatalie)

<details>
<summary>Changes</summary>

Add SPIR-V backend support for the HLSL SV_DispatchThreadID semantic attribute, which is lowered to a `@llvm.dx.thread.id` intrinsic in LLVM IR. In the SPIR-V backend, this is now correctly translated to a `GlobalInvocationId` builtin variable.

Fixes #82534

---
Full diff: https://github.com/llvm/llvm-project/pull/82536.diff


3 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+69) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll (+76) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18a..91562364383ab3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -525,7 +525,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
 
   // Output decorations for the GV.
   // TODO: maybe move to GenerateDecorations pass.
-  if (IsConst)
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (IsConst && ST.isOpenCLEnv())
     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
 
   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 52eeb8a523e6f6..751ecf9e9840cf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 
@@ -182,6 +183,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectDXThreadId(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
                             const SPIRVType *ResType = nullptr) const;
 
@@ -284,6 +288,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_IMPLICIT_DEF:
     return selectOpUndef(ResVReg, ResType, I);
 
+  case TargetOpcode::G_INTRINSIC:
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
   case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
     return selectIntrinsic(ResVReg, ResType, I);
@@ -1427,6 +1432,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
           .addUse(I.getOperand(2).getReg())
           .addUse(I.getOperand(3).getReg());
     break;
+  case Intrinsic::dx_thread_id:
+    return selectDXThreadId(ResVReg, ResType, I);
   default:
     llvm_unreachable("Intrinsic selection not implemented");
   }
@@ -1660,6 +1667,68 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectDXThreadId(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  // DX intrinsic: @llvm.dx.thread.id(i32)
+  // ID  Name      Description
+  // 93  ThreadId  reads the thread ID
+
+  MachineIRBuilder MIRBuilder(I);
+  const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
+  const SPIRVType *Vec3Ty =
+      GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
+  const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
+      Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
+
+  // Create new register for GlobalInvocationID builtin variable.
+  Register NewRegister =
+      MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 32));
+  GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
+
+  // Build GlobalInvocationID global variable with the necessary decorations.
+  Register Variable = GR.buildGlobalVariable(
+      NewRegister, PtrType,
+      getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
+      SPIRV::StorageClass::Input, nullptr, true, true,
+      SPIRV::LinkageType::Import, MIRBuilder, false);
+
+  // Create new register for loading value.
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register LoadedRegister = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 32));
+  GR.assignSPIRVTypeToVReg(Vec3Ty, LoadedRegister, MIRBuilder.getMF());
+
+  // Load v3uint value from the global variable.
+  BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
+      .addDef(LoadedRegister)
+      .addUse(GR.getSPIRVTypeID(Vec3Ty))
+      .addUse(Variable);
+
+  // Get Thread ID index. Expecting operand is a constant immediate value,
+  // wrapped in a type assignment.
+  assert(I.getOperand(2).isReg());
+  Register ThreadIdReg = I.getOperand(2).getReg();
+  SPIRVType *ConstTy = this->MRI->getVRegDef(ThreadIdReg);
+  assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
+         ConstTy->getOperand(1).isReg());
+  Register ConstReg = ConstTy->getOperand(1).getReg();
+  const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
+  assert(Const && Const->getOpcode() == TargetOpcode::G_CONSTANT);
+  const llvm::APInt &Val = Const->getOperand(1).getCImm()->getValue();
+  const uint32_t ThreadId = Val.getZExtValue();
+
+  // Extract the thread ID from the loaded vector value.
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addUse(LoadedRegister)
+                 .addImm(ThreadId);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
 namespace llvm {
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll
new file mode 100644
index 00000000000000..4915c0d3277075
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll
@@ -0,0 +1,76 @@
+; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; This file generated from the following HLSL:
+; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header -o - DispatchThreadID.hlsl
+;
+; [shader("compute")]
+; [numthreads(1,1,1)]
+; void main(uint3 ID : SV_DispatchThreadID) {}
+
+; CHECK-DAG:        %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG:        %[[#v3int:]] = OpTypeVector %[[#int]] 3
+; CHECK-DAG:        %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
+; CHECK-DAG:        %[[#tempvar:]] = OpUndef %[[#v3int]]
+; CHECK-DAG:        %[[#GlobalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input
+
+; CHECK-DAG:        OpEntryPoint GLCompute {{.*}} %[[#GlobalInvocationId]]
+; CHECK-DAG:        OpName %[[#GlobalInvocationId]] "__spirv_BuiltInGlobalInvocationId"
+; CHECK-DAG:        OpDecorate %[[#GlobalInvocationId]] LinkageAttributes "__spirv_BuiltInGlobalInvocationId" Import
+; CHECK-DAG:        OpDecorate %[[#GlobalInvocationId]] BuiltIn GlobalInvocationId
+
+; ModuleID = 'DispatchThreadID.hlsl'
+source_filename = "DispatchThreadID.hlsl"
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-library"
+
+; Function Attrs: noinline norecurse nounwind optnone
+define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
+entry:
+  %ID.addr = alloca <3 x i32>, align 16
+  store <3 x i32> %ID, ptr %ID.addr, align 16
+  ret void
+}
+
+; Function Attrs: norecurse
+define void @main.1() #1 {
+entry:
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
+  %0 = call i32 @llvm.dx.thread.id(i32 0)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
+  %1 = insertelement <3 x i32> poison, i32 %0, i64 0
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
+  %2 = call i32 @llvm.dx.thread.id(i32 1)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
+  %3 = insertelement <3 x i32> %1, i32 %2, i64 1
+
+; CHECK:        %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
+; CHECK:        %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
+  %4 = call i32 @llvm.dx.thread.id(i32 2)
+
+; CHECK:        %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
+  %5 = insertelement <3 x i32> %3, i32 %4, i64 2
+
+  call void @main(<3 x i32> %5)
+  ret void
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare i32 @llvm.dx.thread.id(i32) #2
+
+attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
+!2 = !{!"clang version 19.0.0git (git at github.com:llvm/llvm-project.git c9afeaa6434a61b3b3a57c8eda6d2cfb25ab675b)"}

``````````

</details>


https://github.com/llvm/llvm-project/pull/82536


More information about the llvm-commits mailing list