[llvm] [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (PR #94219)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 09:11:22 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/94219
>From b81b1f8b691a0a1c4a472c7b6056d03370fe0985 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 06:35:24 -0700
Subject: [PATCH 1/3] emit valid SPIR-V code for integer sizes other than
8,16,32,64
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 67 +++++++++++++++----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 6 +-
.../CodeGen/SPIRV/no-i8-type-duplication.ll | 17 +++++
.../SPIRV/transcoding/OpBitReverse-subbyte.ll | 24 +++++++
4 files changed, 99 insertions(+), 15 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
.addDef(createTypeVReg(MIRBuilder));
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
- MachineIRBuilder &MIRBuilder,
- bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
assert(Width <= 64 && "Unsupported integer width!");
- const SPIRVSubtarget &ST =
- cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
- MIRBuilder.buildInstr(SPIRV::OpExtension)
- .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
- MIRBuilder.buildInstr(SPIRV::OpCapability)
- .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
- } else if (Width <= 8)
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+ return Width;
+ if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
Width = 32;
else if (Width <= 64)
Width = 64;
+ return Width;
+}
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+ MachineIRBuilder &MIRBuilder,
+ bool IsSigned) {
+ Width = adjustOpTypeIntWidth(Width);
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+ MIRBuilder.buildInstr(SPIRV::OpExtension)
+ .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+ }
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ Ty = adjustIntTypeByWidth(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
return SpirvType->defs().begin()->getReg();
}
+// We need to use a new LLVM integer type if there is a mismatch between
+// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
+// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
+// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
+// same "OpTypeInt 8" type for a series of LLVM integer types with number of
+// bits less than 8. This would lead to duplicate type definitions
+// eventually due to the method that DuplicateTracker utilizes to reason
+// about uniqueness of type records.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+ if (auto IType = dyn_cast<IntegerType>(Ty)) {
+ unsigned SrcBitWidth = IType->getBitWidth();
+ if (SrcBitWidth > 1) {
+ unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+ // Maybe change source LLVM type to keep DuplicateTracker consistent.
+ if (SrcBitWidth != BitWidth)
+ Ty = IntegerType::get(Ty->getContext(), BitWidth);
+ }
+ }
+ return Ty;
+}
+
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
- if (!isPointerTy(Ty))
+ if (!isPointerTy(Ty)) {
+ Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
- else if (isTypedPointerTy(Ty))
+ } else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
- else
+ } else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+ }
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+ // Maybe adjust bit width to keep DuplicateTracker consistent. Without
+ // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
+ // example, the same "OpTypeInt 8" type for a series of LLVM integer types
+ // with number of bits less than 8, causing duplicate type definitions.
+ BitWidth = adjustOpTypeIntWidth(BitWidth);
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
#include "llvm/IR/TypedPointerType.h"
namespace llvm {
+class SPIRVSubtarget;
using SPIRVType = const MachineInstr;
class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ const Type *adjustIntTypeByWidth(const Type *Ty) const;
+ unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+ SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emission
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+ %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+ ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
>From 098728d45f99e604e4c688f152b309fc75d2077a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 06:48:40 -0700
Subject: [PATCH 2/3] harden the test case
---
llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NOT: OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
>From 27bdca9f3ca052613311075ffe438bc29ccb0904 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 09:11:06 -0700
Subject: [PATCH 3/3] apply suggestions after code review
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 5 +++--
llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll | 3 +++
llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll | 3 +++
3 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 2f2d5efc5e3ba..f96c3a2b0a770 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -91,7 +91,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
}
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
- assert(Width <= 64 && "Unsupported integer width!");
+ if (Width > 64)
+ report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
@@ -102,7 +103,7 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
Width = 16;
else if (Width <= 32)
Width = 32;
- else if (Width <= 64)
+ else
Width = 64;
return Width;
}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
index 92045cc6d7619..fe71ce862dfc3 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -13,6 +13,9 @@
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+; between the LLVM-IR and SPIR-V for i2 and i4
+
define spir_func void @foo(i2 %a, i4 %b) {
entry:
%res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index 1ffb762aafaa6..1840ad5411f47 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -15,6 +15,9 @@
; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
; CHECK-SPIRV: OpReturnValue %[[#Res]]
+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+; between the LLVM-IR and SPIR-V for i2
+
define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
%b = tail call i2 @llvm.bitreverse.i2(i2 %a)
More information about the llvm-commits
mailing list