[llvm] [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (PR #94219)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 06:58:11 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
The SPIR-V Backend creates arbitrarily sized integer types (<= 64 bits) only when the SPV_INTEL_arbitrary_precision_integers extension is available. Without this extension, and in accordance with the SPIR-V specification, `SPIRVGlobalRegistry::getOpTypeInt()` rounds integer sizes other than 8, 16, 32, and 64 up to the next size defined by the specification. For the `DuplicateTracker` class this means that several original LLVM types (e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaks `DuplicateTracker`'s logic and eventually leads to the generation of invalid SPIR-V code.
For example,
```
define spir_func void @foo(i2 %a, i4 %b) {
entry:
  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
  ret void
}
declare i2 @llvm.bitreverse.i2(i2)
declare i4 @llvm.bitreverse.i4(i4)
```
after translation to SPIR-V would fail during validation (`spirv-val`) due to two `OpTypeInt 8 0` instructions.
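This PR fixes the issue by replacing the source LLVM type with the LLVM integer type whose width matches the SPIR-V type that will actually be used in the emitted code, so that `DuplicateTracker` sees a single key per emitted SPIR-V type.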
---
Full diff: https://github.com/llvm/llvm-project/pull/94219.diff
5 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+53-14)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+5-1)
- (added) llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll (+17)
- (added) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll (+24)
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll (+4-1)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
.addDef(createTypeVReg(MIRBuilder));
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
- MachineIRBuilder &MIRBuilder,
- bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
assert(Width <= 64 && "Unsupported integer width!");
- const SPIRVSubtarget &ST =
- cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
- MIRBuilder.buildInstr(SPIRV::OpExtension)
- .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
- MIRBuilder.buildInstr(SPIRV::OpCapability)
- .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
- } else if (Width <= 8)
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+ return Width;
+ if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
Width = 32;
else if (Width <= 64)
Width = 64;
+ return Width;
+}
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+ MachineIRBuilder &MIRBuilder,
+ bool IsSigned) {
+ Width = adjustOpTypeIntWidth(Width);
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+ MIRBuilder.buildInstr(SPIRV::OpExtension)
+ .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+ }
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ Ty = adjustIntTypeByWidth(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
return SpirvType->defs().begin()->getReg();
}
+// Return the LLVM integer type whose bit width matches the SPIR-V OpTypeInt
+// that will actually be emitted for Ty; non-integer types and i1 are returned
+// unchanged. DuplicateTracker keys SPIR-V types by their LLVM type, so without
+// this substitution several LLVM types (e.g. i2 and i4) would each get their
+// own "OpTypeInt 8" record, producing duplicate type definitions that fail
+// SPIR-V validation. The rounding itself is delegated to
+// adjustOpTypeIntWidth() and is a no-op when the
+// SPV_INTEL_arbitrary_precision_integers extension can be used.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+ if (auto IType = dyn_cast<IntegerType>(Ty)) {
+ unsigned SrcBitWidth = IType->getBitWidth();
+ if (SrcBitWidth > 1) {
+ unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+ // Only create a new LLVM type when rounding actually changed the width.
+ if (SrcBitWidth != BitWidth)
+ Ty = IntegerType::get(Ty->getContext(), BitWidth);
+ }
+ }
+ return Ty;
+}
+
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
- if (!isPointerTy(Ty))
+ if (!isPointerTy(Ty)) {
+ Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
- else if (isTypedPointerTy(Ty))
+ } else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
- else
+ } else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+ }
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+ // Round BitWidth to the width that will actually be emitted, so that the
+ // LLVM type used as the DuplicateTracker key matches the SPIR-V type.
+ // Otherwise, e.g., i2 and i4 would both be emitted as "OpTypeInt 8" under
+ // different keys, causing duplicate type definitions.
+ BitWidth = adjustOpTypeIntWidth(BitWidth);
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
#include "llvm/IR/TypedPointerType.h"
namespace llvm {
+class SPIRVSubtarget;
using SPIRVType = const MachineInstr;
class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ const Type *adjustIntTypeByWidth(const Type *Ty) const;
+ unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+ SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with fewer than 8 bits
+; (here i2 and i4).
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emission
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+ %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+ ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
``````````
</details>
https://github.com/llvm/llvm-project/pull/94219
More information about the llvm-commits
mailing list