[llvm] [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (PR #94219)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 06:58:11 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

The SPIR-V Backend creates arbitrarily sized integer types (up to 64 bits) only when the SPV_INTEL_arbitrary_precision_integers extension is enabled. Without that extension, and in accordance with the SPIR-V specification, `SPIRVGlobalRegistry::getOpTypeInt()` rounds integer widths other than 8, 16, 32 and 64 up to the next width defined by the specification. For the `DuplicateTracker` class this means that several distinct LLVM types (e.g., i2 and i4) map to the same "OpTypeInt 8" instruction. This breaks `DuplicateTracker`'s logic and eventually leads to the generation of invalid SPIR-V code.

For example,

```
define spir_func void @foo(i2 %a, i4 %b) {
entry:
  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
  ret void
}

declare i2 @llvm.bitreverse.i2(i2)
declare i4 @llvm.bitreverse.i4(i4)
```

fails validation (`spirv-val`) after translation to SPIR-V, because the emitted module contains two `OpTypeInt 8 0` instructions.

This PR fixes the issue by adjusting the source LLVM type to match the SPIR-V type that will actually be used in the emitted code.

---
Full diff: https://github.com/llvm/llvm-project/pull/94219.diff


5 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+53-14) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+5-1) 
- (added) llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll (+17) 
- (added) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll (+24) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll (+4-1) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
       .addDef(createTypeVReg(MIRBuilder));
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
-                                             MachineIRBuilder &MIRBuilder,
-                                             bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
   assert(Width <= 64 && "Unsupported integer width!");
-  const SPIRVSubtarget &ST =
-      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
-    MIRBuilder.buildInstr(SPIRV::OpExtension)
-        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
-    MIRBuilder.buildInstr(SPIRV::OpCapability)
-        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
-  } else if (Width <= 8)
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+    return Width;
+  if (Width <= 8)
     Width = 8;
   else if (Width <= 16)
     Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
     Width = 32;
   else if (Width <= 64)
     Width = 64;
+  return Width;
+}
 
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+                                             MachineIRBuilder &MIRBuilder,
+                                             bool IsSigned) {
+  Width = adjustOpTypeIntWidth(Width);
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (ST.canUseExtension(
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    MIRBuilder.buildInstr(SPIRV::OpExtension)
+        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+    MIRBuilder.buildInstr(SPIRV::OpCapability)
+        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+  }
   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
                  .addDef(createTypeVReg(MIRBuilder))
                  .addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+  Ty = adjustIntTypeByWidth(Ty);
   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
   return SpirvType->defs().begin()->getReg();
 }
 
+// We need to use a new LLVM integer type if there is a mismatch between
+// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
+// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
+// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
+// same "OpTypeInt 8" type for a series of LLVM integer types with number of
+// bits less than 8. This would lead to duplicate type definitions
+// eventually due to the method that DuplicateTracker utilizes to reason
+// about uniqueness of type records.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+  if (auto IType = dyn_cast<IntegerType>(Ty)) {
+    unsigned SrcBitWidth = IType->getBitWidth();
+    if (SrcBitWidth > 1) {
+      unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+      // Maybe change source LLVM type to keep DuplicateTracker consistent.
+      if (SrcBitWidth != BitWidth)
+        Ty = IntegerType::get(Ty->getContext(), BitWidth);
+    }
+  }
+  return Ty;
+}
+
 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
   Register Reg;
-  if (!isPointerTy(Ty))
+  if (!isPointerTy(Ty)) {
+    Ty = adjustIntTypeByWidth(Ty);
     Reg = DT.find(Ty, &MIRBuilder.getMF());
-  else if (isTypedPointerTy(Ty))
+  } else if (isTypedPointerTy(Ty)) {
     Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                   getPointerAddressSpace(Ty), &MIRBuilder.getMF());
-  else
+  } else {
     Reg =
         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                 getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+  }
 
   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
     return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
+  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
+  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
+  // with number of bits less than 8, causing duplicate type definitions.
+  BitWidth = adjustOpTypeIntWidth(BitWidth);
   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
 }
+
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
   LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
 #include "llvm/IR/TypedPointerType.h"
 
 namespace llvm {
+class SPIRVSubtarget;
 using SPIRVType = const MachineInstr;
 
 class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
-  SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+  const Type *adjustIntTypeByWidth(const Type *Ty) const;
+  unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+  SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emission
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+  ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
 ; CHECK-SPIRV: OpCapability BitInstructions
 ; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
 ; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
 
 define spir_func signext i2 @foo(i2 noundef signext %a) {
 entry:

``````````

</details>


https://github.com/llvm/llvm-project/pull/94219


More information about the llvm-commits mailing list