[llvm] [SPIR-V] Validate and fix bit width of scalar registers (PR #95147)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 11 10:06:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR improves the legalization process for SPIR-V instructions. Namely, it introduces validation and fixing of the bit width of scalar registers as part of the pre-legalizer. A test case is added that demonstrates the ability to legalize instructions with bit widths other than 8/16/32/64, both with and without the vendor-specific SPIR-V extension (SPV_INTEL_arbitrary_precision_integers). When the extension is absent, the generated SPIR-V code will fall back to 8/16/32/64 bit widths in OpTypeInt, but the SPIR-V Backend is still able to legalize operations with the original integer sizes.

---
Full diff: https://github.com/llvm/llvm-project/pull/95147.diff


2 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+20-5) 
- (added) llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll (+56) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index adc5b36af6f18..aaba6e873e2c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -271,6 +271,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
   return SpirvTy;
 }
 
+// To support current approach and limitations wrt. bit width here we widen a
+// scalar register with a bit width greater than 1 to valid sizes and cap it to
+// 64 width.
+static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
+  LLT RegType = MRI.getType(Reg);
+  if (!RegType.isScalar())
+    return;
+  unsigned Sz = RegType.getScalarSizeInBits();
+  if (Sz == 1)
+    return;
+  unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
+  if (NewSz != Sz)
+    MRI.setType(Reg, LLT::scalar(NewSz));
+}
+
 static std::pair<Register, unsigned>
 createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
                const SPIRVGlobalRegistry &GR) {
@@ -406,6 +421,11 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       MachineInstr &MI = *MII;
       unsigned MIOp = MI.getOpcode();
 
+      // validate bit width of scalar registers
+      for (const auto& MOP : MI.operands())
+        if (MOP.isReg())
+          widenScalarLLTNextPow2(MOP.getReg(), MRI);
+
       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
         Register Reg = MI.getOperand(1).getReg();
         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
@@ -475,11 +495,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
       } else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
         propagateSPIRVType(&MI, GR, MRI, MIB);
-      } else if (MIOp == TargetOpcode::G_BITREVERSE) {
-        Register Reg = MI.getOperand(0).getReg();
-        LLT RegType = MRI.getType(Reg);
-        if (RegType.getSizeInBits() < 32)
-          MRI.setType(Reg, LLT::scalar(32));
       }
 
       if (MII == Begin)
diff --git a/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
new file mode 100644
index 0000000000000..437e161864eca
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
@@ -0,0 +1,56 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; CHECK-DAG: OpName %[[#Struct:]] "struct"
+; CHECK-DAG: OpName %[[#Arg:]] "arg"
+; CHECK-DAG: OpName %[[#QArg:]] "qarg"
+; CHECK-DAG: OpName %[[#R:]] "r"
+; CHECK-DAG: OpName %[[#Q:]] "q"
+; CHECK-DAG: OpName %[[#Tr:]] "tr"
+; CHECK-DAG: OpName %[[#Tq:]] "tq"
+; CHECK-DAG: %[[#Struct]] = OpTypeStruct %[[#]] %[[#]] %[[#]]
+; CHECK-DAG: %[[#PtrStruct:]] = OpTypePointer CrossWorkgroup %[[#Struct]]
+; CHECK-EXT-DAG: %[[#Int40:]] = OpTypeInt 40 0
+; CHECK-EXT-DAG: %[[#Int50:]] = OpTypeInt 50 0
+; CHECK-NOEXT-DAG: %[[#Int40:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#PtrInt40:]] = OpTypePointer CrossWorkgroup %[[#Int40]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tr]] = OpUConvert %[[#Int40]] %[[#R]]
+; CHECK-EXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-EXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-EXT: OpStore %[[#StoreAsInt40]] %[[#Tr]]
+
+; CHECK-NOEXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-NOEXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-NOEXT: OpStore %[[#StoreAsInt40]] %[[#R]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tq]] = OpUConvert %[[#Int40]] %[[#Q]]
+; CHECK-EXT: OpStore %[[#QArg]] %[[#Tq]]
+
+; CHECK-NOEXT: OpStore %[[#QArg]] %[[#Q]]
+
+%struct = type <{ i32, i8, [3 x i8] }>
+
+define spir_kernel void @foo(ptr addrspace(1) %arg, i64 %r) {
+  %tr = trunc i64 %r to i40
+  %addr = getelementptr inbounds %struct, ptr addrspace(1) %arg, i64 0
+  store i40 %tr, ptr addrspace(1) %addr
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %qarg, i50 %q) {
+  %tq = trunc i50 %q to i40
+  store i40 %tq, ptr addrspace(1) %qarg
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/95147


More information about the llvm-commits mailing list