[llvm] 6d4fb3d - [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (#94219)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 5 00:56:56 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-06-05T09:56:52+02:00
New Revision: 6d4fb3d3bbecbdfa1c98da3f7e09322abaec5f97
URL: https://github.com/llvm/llvm-project/commit/6d4fb3d3bbecbdfa1c98da3f7e09322abaec5f97
DIFF: https://github.com/llvm/llvm-project/commit/6d4fb3d3bbecbdfa1c98da3f7e09322abaec5f97.diff
LOG: [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (#94219)
The SPIR-V Backend creates arbitrary-sized integer types (<= 64 bits) only
when the SPV_INTEL_arbitrary_precision_integers extension is enabled. Without
that extension, and in accordance with the SPIR-V specification,
`SPIRVGlobalRegistry::getOpTypeInt()` rounds integer sizes other than
8, 16, 32 and 64 up to the next size defined by the specification. For the
`DuplicateTracker` class this means that several original LLVM types
(e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaks
`DuplicateTracker`'s logic and eventually leads to generation of invalid
SPIR-V code.
For example,
```
define spir_func void @foo(i2 %a, i4 %b) {
entry:
%res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
%res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
ret void
}
declare i2 @llvm.bitreverse.i2(i2)
declare i4 @llvm.bitreverse.i4(i4)
```
after translation to SPIR-V would fail during validation (`spirv-val`)
due to two `OpTypeInt 8 0` instructions.
This PR fixes the issue by adjusting the source LLVM type to match the
SPIR-V type that will be used in the emitted code.
Added:
llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..f96c3a2b0a770 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,10 +90,28 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
.addDef(createTypeVReg(MIRBuilder));
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
+ if (Width > 64)
+ report_fatal_error("Unsupported integer width!");
+ const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+ return Width;
+ if (Width <= 8)
+ Width = 8;
+ else if (Width <= 16)
+ Width = 16;
+ else if (Width <= 32)
+ Width = 32;
+ else
+ Width = 64;
+ return Width;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
MachineIRBuilder &MIRBuilder,
bool IsSigned) {
- assert(Width <= 64 && "Unsupported integer width!");
+ Width = adjustOpTypeIntWidth(Width);
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (ST.canUseExtension(
@@ -102,15 +120,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
- } else if (Width <= 8)
- Width = 8;
- else if (Width <= 16)
- Width = 16;
- else if (Width <= 32)
- Width = 32;
- else if (Width <= 64)
- Width = 64;
-
+ }
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
@@ -800,6 +810,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ Ty = adjustIntTypeByWidth(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
@@ -815,6 +826,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
return SpirvType->defs().begin()->getReg();
}
+// We need to use a new LLVM integer type if there is a mismatch between
+// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
+// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
+// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
+// same "OpTypeInt 8" type for a series of LLVM integer types with number of
+// bits less than 8. This would lead to duplicate type definitions
+// eventually due to the method that DuplicateTracker utilizes to reason
+// about uniqueness of type records.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+ if (auto IType = dyn_cast<IntegerType>(Ty)) {
+ unsigned SrcBitWidth = IType->getBitWidth();
+ if (SrcBitWidth > 1) {
+ unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+ // Maybe change source LLVM type to keep DuplicateTracker consistent.
+ if (SrcBitWidth != BitWidth)
+ Ty = IntegerType::get(Ty->getContext(), BitWidth);
+ }
+ }
+ return Ty;
+}
+
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +974,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
- if (!isPointerTy(Ty))
+ if (!isPointerTy(Ty)) {
+ Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
- else if (isTypedPointerTy(Ty))
+ } else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
- else
+ } else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+ }
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1292,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+ // Maybe adjust bit width to keep DuplicateTracker consistent. Without
+ // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
+ // example, the same "OpTypeInt 8" type for a series of LLVM integer types
+ // with number of bits less than 8, causing duplicate type definitions.
+ BitWidth = adjustOpTypeIntWidth(BitWidth);
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
#include "llvm/IR/TypedPointerType.h"
namespace llvm {
+class SPIRVSubtarget;
using SPIRVType = const MachineInstr;
class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ const Type *adjustIntTypeByWidth(const Type *Ty) const;
+ unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+ SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..fe71ce862dfc3
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,27 @@
+; The goal of the test case is to ensure valid SPIR-V code emission
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+; between the LLVM-IR and SPIR-V for i2 and i4
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+ %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+ ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1840ad5411f47 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,13 @@
; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
+
+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+; between the LLVM-IR and SPIR-V for i2
define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
More information about the llvm-commits
mailing list