[llvm] [SPIR-V] Fix early definition of SPIR-V types during call lowering (PR #115192)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 10:50:57 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR fixes the early definition of SPIR-V types during call lowering. Namely, it ensures that correct types are applied to virtual registers that were used as arguments in call lowering and therefore caused SPIR-V types to be defined early.

Reproducers are attached as new test cases.

---
Full diff: https://github.com/llvm/llvm-project/pull/115192.diff


4 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+27-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+2) 
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering-unwrapped.ll (+50) 
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering.ll (+49) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 98cf598a1f031a..8d1b82465d3df2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -546,12 +546,35 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
       ArgVRegs.push_back(ArgReg);
       SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
       if (!SpvType) {
-        SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
-        GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
+        Type *ArgTy = nullptr;
+        if (auto *PtrArgTy = dyn_cast<PointerType>(Arg.Ty)) {
+          // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
+          // don't have access to original value in LLVM IR or info about
+          // deduced pointee type, then we should wait with setting the type for
+          // the virtual register until pre-legalizer step when we access
+          // @llvm.spv.assign.ptr.type.p...(...)'s info.
+          if (Arg.OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(Arg.OrigValue))
+              ArgTy = TypedPointerType::get(ElemTy, PtrArgTy->getAddressSpace());
+        } else {
+          ArgTy = Arg.Ty;
+        }
+        if (ArgTy) {
+          SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder);
+          GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
+        }
       }
       if (!MRI->getRegClassOrNull(ArgReg)) {
-        MRI->setRegClass(ArgReg, GR->getRegClass(SpvType));
-        MRI->setType(ArgReg, GR->getRegType(SpvType));
+        // Either we have SpvType created, or Arg.Ty is an untyped pointer and
+        // we know its virtual register's class and type even if we don't know
+        // pointee type.
+        MRI->setRegClass(ArgReg, SpvType ? GR->getRegClass(SpvType)
+                                         : &SPIRV::pIDRegClass);
+        MRI->setType(
+            ArgReg,
+            SpvType ? GR->getRegType(SpvType)
+                    : LLT::pointer(cast<PointerType>(Arg.Ty)->getAddressSpace(),
+                                   GR->getPointerSize()));
       }
     }
     auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8b7e9c48de6c75..191533a52cccd9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1219,6 +1219,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
   SmallVector<Value *, 2> Args = {Pointer, VMD, B.getInt32(AddressSpace)};
   auto *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
   I->setOperand(OperandToReplace, PtrCastI);
+  // We need to set up a pointee type for the newly created spv_ptrcast.
+  buildAssignPtr(B, ExpectedElementType, PtrCastI);
 }
 
 void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering-unwrapped.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering-unwrapped.ll
new file mode 100644
index 00000000000000..09c0a92d596ce3
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering-unwrapped.ll
@@ -0,0 +1,50 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers
+; which were used as arguments in call lowering and so caused early definition of SPIR-V types.
+
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+%t_id = type { %t_arr }
+%t_arr = type { [1 x i64] }
+%t_bf16 = type { i16 }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) align 4 %_arg_ERR, ptr byval(%t_id) align 8 %_arg_ERR3) {
+entry:
+  %FloatArray.i = alloca [4 x float], align 4
+  %BF16Array.i = alloca [4 x %t_bf16], align 2
+  %0 = load i64, ptr %_arg_ERR3, align 8
+  %add.ptr.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_ERR, i64 %0
+  %FloatArray.ascast.i = addrspacecast ptr %FloatArray.i to ptr addrspace(4)
+  %BF16Array.ascast.i = addrspacecast ptr %BF16Array.i to ptr addrspace(4)
+  call spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4) %FloatArray.ascast.i, ptr addrspace(4) %BF16Array.ascast.i)
+  br label %for.cond.i
+
+for.cond.i:                                       ; preds = %for.inc.i, %entry
+  %lsr.iv1 = phi ptr [ %scevgep2, %for.inc.i ], [ %FloatArray.i, %entry ]
+  %lsr.iv = phi ptr addrspace(4) [ %scevgep, %for.inc.i ], [ %BF16Array.ascast.i, %entry ]
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.inc.i ]
+  %cmp.i = icmp ult i32 %i.0.i, 4
+  br i1 %cmp.i, label %for.body.i, label %exit
+
+for.body.i:                                       ; preds = %for.cond.i
+  %1 = load float, ptr %lsr.iv1, align 4
+  %call.i.i = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2) %lsr.iv)
+  %cmp5.i = fcmp une float %1, %call.i.i
+  br i1 %cmp5.i, label %if.then.i, label %for.inc.i
+
+if.then.i:                                        ; preds = %for.body.i
+  store i32 1, ptr addrspace(1) %add.ptr.i, align 4
+  br label %for.inc.i
+
+for.inc.i:                                        ; preds = %if.then.i, %for.body.i
+  %inc.i = add nuw nsw i32 %i.0.i, 1
+  %scevgep = getelementptr i8, ptr addrspace(4) %lsr.iv, i64 2
+  %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
+  br label %for.cond.i
+
+exit:                                             ; preds = %for.cond.i
+  ret void
+}
+
+declare void @llvm.memcpy.p0.p1.i64(ptr noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg)
+declare dso_local spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4), ptr addrspace(4))
+declare dso_local spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2))
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering.ll
new file mode 100644
index 00000000000000..a31638e0b87043
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-vs-calllowering.ll
@@ -0,0 +1,49 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers
+; which were used as arguments in call lowering and so caused early definition of SPIR-V types.
+
+; RUN: %if spirv-tools %{ llc -O2 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+%t_id = type { %t_arr }
+%t_arr = type { [1 x i64] }
+%t_bf16 = type { i16 }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) align 4 %_arg_ERR, ptr byval(%t_id) align 8 %_arg_ERR3) {
+entry:
+  %FloatArray.i = alloca [4 x float], align 4
+  %BF16Array.i = alloca [4 x %t_bf16], align 2
+  %0 = load i64, ptr %_arg_ERR3, align 8
+  %add.ptr.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_ERR, i64 %0
+  %FloatArray.ascast.i = addrspacecast ptr %FloatArray.i to ptr addrspace(4)
+  %BF16Array.ascast.i = addrspacecast ptr %BF16Array.i to ptr addrspace(4)
+  call spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4) %FloatArray.ascast.i, ptr addrspace(4) %BF16Array.ascast.i)
+  br label %for.cond.i
+
+for.cond.i:                                       ; preds = %for.inc.i, %entry
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.inc.i ]
+  %cmp.i = icmp ult i32 %i.0.i, 4
+  br i1 %cmp.i, label %for.body.i, label %exit
+
+for.body.i:                                       ; preds = %for.cond.i
+  %idxprom.i = zext nneg i32 %i.0.i to i64
+  %arrayidx.i = getelementptr inbounds [4 x float], ptr %FloatArray.i, i64 0, i64 %idxprom.i
+  %1 = load float, ptr %arrayidx.i, align 4
+  %arrayidx4.i = getelementptr inbounds [4 x %t_bf16], ptr addrspace(4) %BF16Array.ascast.i, i64 0, i64 %idxprom.i
+  %call.i.i = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2) %arrayidx4.i)
+  %cmp5.i = fcmp une float %1, %call.i.i
+  br i1 %cmp5.i, label %if.then.i, label %for.inc.i
+
+if.then.i:                                        ; preds = %for.body.i
+  store i32 1, ptr addrspace(1) %add.ptr.i, align 4
+  br label %for.inc.i
+
+for.inc.i:                                        ; preds = %if.then.i, %for.body.i
+  %inc.i = add nuw nsw i32 %i.0.i, 1
+  br label %for.cond.i
+
+exit: ; preds = %for.cond.i
+  ret void
+}
+
+declare void @llvm.memcpy.p0.p1.i64(ptr noalias nocapture writeonly, ptr addrspace(1) noalias nocapture readonly, i64, i1 immarg)
+declare dso_local spir_func void @__devicelib_ConvertFToBF16INTELVec4(ptr addrspace(4), ptr addrspace(4))
+declare dso_local spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) align 2 dereferenceable(2))

``````````

</details>


https://github.com/llvm/llvm-project/pull/115192


More information about the llvm-commits mailing list