[llvm] [SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (PR #94219)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 09:11:22 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/94219

>From b81b1f8b691a0a1c4a472c7b6056d03370fe0985 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 06:35:24 -0700
Subject: [PATCH 1/3] emit valid SPIR-V code for integer sizes other than
 8,16,32,64

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 67 +++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  6 +-
 .../CodeGen/SPIRV/no-i8-type-duplication.ll   | 17 +++++
 .../SPIRV/transcoding/OpBitReverse-subbyte.ll | 24 +++++++
 4 files changed, 99 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
       .addDef(createTypeVReg(MIRBuilder));
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
-                                             MachineIRBuilder &MIRBuilder,
-                                             bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
   assert(Width <= 64 && "Unsupported integer width!");
-  const SPIRVSubtarget &ST =
-      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
-    MIRBuilder.buildInstr(SPIRV::OpExtension)
-        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
-    MIRBuilder.buildInstr(SPIRV::OpCapability)
-        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
-  } else if (Width <= 8)
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+    return Width;
+  if (Width <= 8)
     Width = 8;
   else if (Width <= 16)
     Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
     Width = 32;
   else if (Width <= 64)
     Width = 64;
+  return Width;
+}
 
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+                                             MachineIRBuilder &MIRBuilder,
+                                             bool IsSigned) {
+  Width = adjustOpTypeIntWidth(Width);
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (ST.canUseExtension(
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    MIRBuilder.buildInstr(SPIRV::OpExtension)
+        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+    MIRBuilder.buildInstr(SPIRV::OpCapability)
+        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+  }
   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
                  .addDef(createTypeVReg(MIRBuilder))
                  .addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+  Ty = adjustIntTypeByWidth(Ty);
   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
   return SpirvType->defs().begin()->getReg();
 }
 
+// Map an LLVM integer type to the integer type whose width matches what
+// getOpTypeInt() emits, i.e. the width after adjustOpTypeIntWidth(). Without
+// SPV_INTEL_arbitrary_precision_integers widths are rounded up to 8/16/32/64,
+// so i2 and i4 both become "OpTypeInt 8". DuplicateTracker keys SPIR-V types
+// by their LLVM type, so looking up i2/i4 unadjusted would miss the existing
+// i8 entry and emit a duplicate definition. i1 is left untouched (presumably
+// lowered to OpTypeBool elsewhere); non-integer types are returned as is.
+// Returns Ty itself whenever no adjustment is needed.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+  if (auto *IType = dyn_cast<IntegerType>(Ty)) {
+    unsigned SrcBitWidth = IType->getBitWidth();
+    if (SrcBitWidth > 1) {
+      unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+      // Re-key on the rounded type so DuplicateTracker sees a single entry.
+      if (SrcBitWidth != BitWidth)
+        Ty = IntegerType::get(Ty->getContext(), BitWidth);
+    }
+  }
+  return Ty;
+}
+
 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
   Register Reg;
-  if (!isPointerTy(Ty))
+  if (!isPointerTy(Ty)) {
+    Ty = adjustIntTypeByWidth(Ty);
     Reg = DT.find(Ty, &MIRBuilder.getMF());
-  else if (isTypedPointerTy(Ty))
+  } else if (isTypedPointerTy(Ty)) {
     Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                   getPointerAddressSpace(Ty), &MIRBuilder.getMF());
-  else
+  } else {
     Reg =
         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                 getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+  }
 
   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
     return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  // Round BitWidth by the same rule as getOpTypeInt()/adjustIntTypeByWidth()
+  // so the LLVM integer type used as the DuplicateTracker key agrees with
+  // theirs. Otherwise, without the arbitrary-precision extension, e.g. i2 and
+  // i4 get separate keys for one "OpTypeInt 8". Width 1 is not exempt here.
+  BitWidth = adjustOpTypeIntWidth(BitWidth);
   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
 }
+
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
   LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
 #include "llvm/IR/TypedPointerType.h"
 
 namespace llvm {
+class SPIRVSubtarget;
 using SPIRVType = const MachineInstr;
 
 class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
-  SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+  const Type *adjustIntTypeByWidth(const Type *Ty) const;
+  unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+  SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emission
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+  ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)

>From 098728d45f99e604e4c688f152b309fc75d2077a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 06:48:40 -0700
Subject: [PATCH 2/3] harden the test case

---
 llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
 ; CHECK-SPIRV: OpCapability BitInstructions
 ; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
 ; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]

 define spir_func signext i2 @foo(i2 noundef signext %a) {
 entry:

>From 27bdca9f3ca052613311075ffe438bc29ccb0904 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 3 Jun 2024 09:11:06 -0700
Subject: [PATCH 3/3] apply suggestions after code review

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp               | 5 +++--
 llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll | 3 +++
 llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll      | 3 +++
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 2f2d5efc5e3ba..f96c3a2b0a770 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -91,7 +91,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }

 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
-  assert(Width <= 64 && "Unsupported integer width!");
+  if (Width > 64)
+    report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
@@ -102,7 +103,7 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
     Width = 16;
   else if (Width <= 32)
     Width = 32;
-  else if (Width <= 64)
+  else
     Width = 64;
   return Width;
 }
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
index 92045cc6d7619..fe71ce862dfc3 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -13,6 +13,9 @@
 ; CHECK-SPIRV-NOT: %[[#]] = OpTypeInt 8 0
 ; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]

+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+;       between the LLVM-IR and SPIR-V for i2 and i4
+
 define spir_func void @foo(i2 %a, i4 %b) {
 entry:
   %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index 1ffb762aafaa6..1840ad5411f47 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -15,6 +15,9 @@
 ; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
 ; CHECK-SPIRV: OpReturnValue %[[#Res]]
 
+; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
+;       between the LLVM-IR and SPIR-V for i2
+
 define spir_func signext i2 @foo(i2 noundef signext %a) {
 entry:
   %b = tail call i2 @llvm.bitreverse.i2(i2 %a)



More information about the llvm-commits mailing list