[llvm] f6aa508 - [SPIR-V]: Fix creation of constants of array types in SPIRV Backend (#96514)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 01:56:59 PDT 2024


Author: Vyacheslav Levytskyy
Date: 2024-06-25T10:56:56+02:00
New Revision: f6aa50873463ebd9a459b7ccd4989460175a6e7f

URL: https://github.com/llvm/llvm-project/commit/f6aa50873463ebd9a459b7ccd4989460175a6e7f
DIFF: https://github.com/llvm/llvm-project/commit/f6aa50873463ebd9a459b7ccd4989460175a6e7f.diff

LOG: [SPIR-V]: Fix creation of constants of array types in SPIRV Backend (#96514)

This PR fixes https://github.com/llvm/llvm-project/issues/96513.

Array-type constants were created incorrectly: instead of creating
[1, 1, 1] or [1, 1, 1, 1, 1, ...] constants, the same single-element
constant [1] was always created, replacing the original composite
constants. As a result, only one such constant could exist in the code
without emitting invalid code: the second constant would eventually be
rewritten to the first, because the lookup key used for both was the
same single-element array (e.g. [1]).

This PR fixes the issue and removes an unneeded copy-pasted clone of
the function that creates an array constant.

Added: 
    llvm/test/CodeGen/SPIRV/var-uniform-const.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
    llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f5f36075d4a31..71168d2d7dacd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1972,7 +1972,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
           .addDef(GlobalWorkSize)
           .addUse(GR->getSPIRVTypeID(SpvFieldTy))
           .addUse(GWSPtr);
-      Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+      const SPIRVSubtarget &ST =
+          cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+      Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
+                                           SpvFieldTy, *ST.getInstrInfo());
     } else {
       Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
     }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index b8710d24bff94..5558c7a5a4a5f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
     unsigned ElemCnt, bool ZeroAsNull) {
-  // Find a constant vector in DT or build a new one.
+  // Find a constant vector or array in DT or build a new one.
   Register Res = DT.find(CA, CurMF);
   // If no values are attached, the composite is null constant.
   bool IsNull = Val->isNullValue() && ZeroAsNull;
@@ -474,20 +474,28 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
                                     ZeroAsNull);
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
-                                             SPIRVType *SpvType,
-                                             const SPIRVInstrInfo &TII) {
+Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
+    uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII) {
   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
   assert(LLVMTy->isArrayTy());
   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
   Type *LLVMBaseTy = LLVMArrTy->getElementType();
-  auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto *ConstArr =
-      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
+  Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+  // The following is a reasonably unique key that is better than [Val]. The naive
+  // alternative would be something along the lines of:
+  //   SmallVector<Constant *> NumCI(Num, CI);
+  //   Constant *UniqueKey =
+  //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
+  // that would be a truly unique but dangerous key, because it could lead to
+  // the creation of constants of arbitrary length (that is, the parameter of
+  // memset) which were missing in the original module.
+  Constant *UniqueKey = ConstantStruct::getAnon(
+      {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
+       ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
+  return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
                                     LLVMArrTy->getNumElements());
 }
 
@@ -545,24 +553,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
                                        SpvType->getOperand(2).getImm());
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
-                                             MachineIRBuilder &MIRBuilder,
-                                             SPIRVType *SpvType, bool EmitIR) {
-  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
-  assert(LLVMTy->isArrayTy());
-  const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
-  Type *LLVMBaseTy = LLVMArrTy->getElementType();
-  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto ConstArr =
-      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
-  SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
-  unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
-                                       ConstArr, BW,
-                                       LLVMArrTy->getNumElements());
-}
-
 Register
 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType) {

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 990d3328f6a30..a45e1ccd0717f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -457,13 +457,11 @@ class SPIRVGlobalRegistry {
   Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
                                   SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                   bool ZeroAsNull = true);
-  Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
-                                   SPIRVType *SpvType,
-                                   const SPIRVInstrInfo &TII);
+  Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I,
+                                    SPIRVType *SpvType,
+                                    const SPIRVInstrInfo &TII);
   Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
                                     SPIRVType *SpvType, bool EmitIR = true);
-  Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                                   SPIRVType *SpvType, bool EmitIR = true);
   Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                    SPIRVType *SpvType);
   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 41a0d2c5e2f35..f5b6bcd64f480 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -846,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
     unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
     SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
     SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
-    Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII);
+    Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
     SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
         ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
     // TODO: check if we have such GV, add init, use buildGlobalVariable.

diff  --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c1b90b0e9d884..927683ad7e32b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -253,7 +253,11 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
 
 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                        const MachineRegisterInfo *MRI) {
-  MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg);
+  MachineInstr *MI = MRI->getVRegDef(ConstReg);
+  MachineInstr *ConstInstr =
+      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
+          ? MRI->getVRegDef(MI->getOperand(1).getReg())
+          : MI;
   if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
     if (GI->is(Intrinsic::spv_track_constant)) {
       ConstReg = ConstInstr->getOperand(2).getReg();

diff  --git a/llvm/test/CodeGen/SPIRV/var-uniform-const.ll b/llvm/test/CodeGen/SPIRV/var-uniform-const.ll
new file mode 100644
index 0000000000000..6f7c91eb09e90
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/var-uniform-const.ll
@@ -0,0 +1,87 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#Long:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#Size3:]] = OpConstant %[[#Int]] 3
+; CHECK-SPIRV-DAG: %[[#Arr3:]] = OpTypeArray %[[#Char]] %[[#Size3]]
+; CHECK-SPIRV-DAG: %[[#Size16:]] = OpConstant %[[#Int]] 16
+; CHECK-SPIRV-DAG: %[[#Arr16:]] = OpTypeArray %[[#Char]] %[[#Size16]]
+; CHECK-SPIRV-DAG: %[[#Const3:]] = OpConstant %[[#Long]] 3
+; CHECK-SPIRV-DAG: %[[#One:]] = OpConstant %[[#Char]] 1
+; CHECK-SPIRV-DAG: %[[#One3:]] = OpConstantComposite %[[#Arr3]] %[[#One]] %[[#One]] %[[#One]]
+; CHECK-SPIRV-DAG: %[[#Zero3:]] = OpConstantNull %[[#Arr3]]
+; CHECK-SPIRV-DAG: %[[#Const16:]] = OpConstant %[[#Long]] 16
+; CHECK-SPIRV-DAG: %[[#One16:]] = OpConstantComposite %[[#Arr16]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]]
+; CHECK-SPIRV-DAG: %[[#Zero16:]] = OpConstantNull %[[#Arr16]]
+
+; The first set of functions.
+; CHECK-SPIRV-DAG: %[[#PtrArr3:]] = OpTypePointer UniformConstant %[[#Arr3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
+; CHECK-SPIRV-DAG: %[[#PtrArr16:]] = OpTypePointer UniformConstant %[[#Arr16]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]
+
+; The second set of functions.
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]
+
+%Vec3 = type { <3 x i8> }
+%Vec16 = type { <16 x i8> }
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+define spir_kernel void @foo(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
+  %a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
+  ret void
+}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+define spir_kernel void @bar(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
+  %a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
+  ret void
+}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+define spir_kernel void @foo_2(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
+  %a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
+  ret void
+}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+define spir_kernel void @bar_2(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
+  %a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p1.i64(ptr addrspace(1) nocapture writeonly, i8, i64, i1 immarg)


        


More information about the llvm-commits mailing list