[llvm] [SPIR-V] Validate and fix bit width of scalar registers (PR #95147)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 11 12:25:10 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/95147
>From d8e17b64ecf0dad9ba7dcd6174e0eae9ada3c1eb Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Jun 2024 09:58:13 -0700
Subject: [PATCH 1/2] validate bit width of scalar registers
---
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 25 +++++++--
.../CodeGen/SPIRV/trunc-nonstd-bitwidth.ll | 56 +++++++++++++++++++
2 files changed, 76 insertions(+), 5 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index adc5b36af6f18..aaba6e873e2c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -271,6 +271,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
return SpirvTy;
}
+// Normalize the LLT of a scalar register to a bit width the current lowering
+// supports: 1-bit and zero-width scalars are left untouched (0 has no "next
+// power of two"); any other width is rounded up to a power of two in [8, 64].
+static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
+ LLT RegType = MRI.getType(Reg);
+ if (!RegType.isScalar())
+ return;
+ unsigned Sz = RegType.getScalarSizeInBits();
+ if (Sz <= 1) // Sz == 0 would make 1u << Log2_32_Ceil(0) == 1u << 32 (UB).
+ return;
+ unsigned NewSz = std::clamp<uint64_t>(PowerOf2Ceil(Sz), 8, 64);
+ if (NewSz != Sz)
+ MRI.setType(Reg, LLT::scalar(NewSz));
+}
+
static std::pair<Register, unsigned>
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
const SPIRVGlobalRegistry &GR) {
@@ -406,6 +421,11 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineInstr &MI = *MII;
unsigned MIOp = MI.getOpcode();
+ // validate bit width of scalar registers
+ for (const auto& MOP : MI.operands())
+ if (MOP.isReg())
+ widenScalarLLTNextPow2(MOP.getReg(), MRI);
+
if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
Register Reg = MI.getOperand(1).getReg();
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
@@ -475,11 +495,6 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
} else if (MIOp == TargetOpcode::G_GLOBAL_VALUE) {
propagateSPIRVType(&MI, GR, MRI, MIB);
- } else if (MIOp == TargetOpcode::G_BITREVERSE) {
- Register Reg = MI.getOperand(0).getReg();
- LLT RegType = MRI.getType(Reg);
- if (RegType.getSizeInBits() < 32)
- MRI.setType(Reg, LLT::scalar(32));
}
if (MII == Begin)
diff --git a/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
new file mode 100644
index 0000000000000..437e161864eca
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/trunc-nonstd-bitwidth.ll
@@ -0,0 +1,56 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOEXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXT
+
+; CHECK-DAG: OpName %[[#Struct:]] "struct"
+; CHECK-DAG: OpName %[[#Arg:]] "arg"
+; CHECK-DAG: OpName %[[#QArg:]] "qarg"
+; CHECK-DAG: OpName %[[#R:]] "r"
+; CHECK-DAG: OpName %[[#Q:]] "q"
+; CHECK-DAG: OpName %[[#Tr:]] "tr"
+; CHECK-DAG: OpName %[[#Tq:]] "tq"
+; CHECK-DAG: %[[#Struct]] = OpTypeStruct %[[#]] %[[#]] %[[#]]
+; CHECK-DAG: %[[#PtrStruct:]] = OpTypePointer CrossWorkgroup %[[#Struct]]
+; CHECK-EXT-DAG: %[[#Int40:]] = OpTypeInt 40 0
+; CHECK-EXT-DAG: %[[#Int50:]] = OpTypeInt 50 0
+; CHECK-NOEXT-DAG: %[[#Int40:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#PtrInt40:]] = OpTypePointer CrossWorkgroup %[[#Int40]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tr]] = OpUConvert %[[#Int40]] %[[#R]]
+; CHECK-EXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-EXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-EXT: OpStore %[[#StoreAsInt40]] %[[#Tr]]
+
+; CHECK-NOEXT: %[[#Store:]] = OpInBoundsPtrAccessChain %[[#PtrStruct]] %[[#Arg]] %[[#]]
+; CHECK-NOEXT: %[[#StoreAsInt40:]] = OpBitcast %[[#PtrInt40]] %[[#Store]]
+; CHECK-NOEXT: OpStore %[[#StoreAsInt40]] %[[#R]]
+
+; CHECK: OpFunction
+
+; CHECK-EXT: %[[#Tq]] = OpUConvert %[[#Int40]] %[[#Q]]
+; CHECK-EXT: OpStore %[[#QArg]] %[[#Tq]]
+
+; CHECK-NOEXT: OpStore %[[#QArg]] %[[#Q]]
+
+%struct = type <{ i32, i8, [3 x i8] }>
+; @foo: truncate an i64 argument to i40 and store it through a %struct pointer.
+define spir_kernel void @foo(ptr addrspace(1) %arg, i64 %r) {
+ %tr = trunc i64 %r to i40
+ %addr = getelementptr inbounds %struct, ptr addrspace(1) %arg, i64 0
+ store i40 %tr, ptr addrspace(1) %addr
+ ret void
+}
+; @bar: truncate a non-standard i50 argument to i40 and store it directly.
+define spir_kernel void @bar(ptr addrspace(1) %qarg, i50 %q) {
+ %tq = trunc i50 %q to i40
+ store i40 %tq, ptr addrspace(1) %qarg
+ ret void
+}
>From e401c96b087b74aef9a8370e842335e98faa987f Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Jun 2024 12:24:58 -0700
Subject: [PATCH 2/2] clang-format
---
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index aaba6e873e2c1..53e0432192ca9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -422,7 +422,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
unsigned MIOp = MI.getOpcode();
// validate bit width of scalar registers
- for (const auto& MOP : MI.operands())
+ for (const auto &MOP : MI.operands())
if (MOP.isReg())
widenScalarLLTNextPow2(MOP.getReg(), MRI);
More information about the llvm-commits
mailing list