[llvm] [SPIR-V] Fix inconsistency between previously deduced element type of a pointer and function's return type (PR #109660)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 23 06:19:56 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR improves type inference and fixes an inconsistency between the previously deduced element type of a pointer and a function's return type. It fixes https://github.com/llvm/llvm-project/issues/109401 by ensuring that the operands of OpPhi have consistent types.
---
Full diff: https://github.com/llvm/llvm-project/pull/109660.diff
4 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+38-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+13-4)
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll (+55)
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll (+53)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 795ddf47c40dab..7057cc1fd30242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
+ void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
+ CallInst *AssignCI);
public:
static char ID;
@@ -475,10 +477,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (DemangledName.length() > 0)
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
auto AsArgIt = ResTypeByArg.find(DemangledName);
- if (AsArgIt != ResTypeByArg.end()) {
+ if (AsArgIt != ResTypeByArg.end())
Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
Visited, UnknownElemTypeI8);
- }
+ else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF))
+ Ty = KnownRetTy;
}
}
@@ -808,6 +811,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
CallInst *PtrCastI =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
I->setOperand(OpIt.second, PtrCastI);
+ buildAssignPtr(B, KnownElemTy, PtrCastI);
}
}
}
@@ -1706,6 +1710,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
return true;
}
+void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
+ Type *KnownElemTy,
+ CallInst *AssignCI) {
+ updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy));
+ IRBuilder<> B(CI->getContext());
+ B.SetInsertPoint(*CI->getInsertionPointAfterDef());
+ B.SetCurrentDebugLocation(CI->getDebugLoc());
+ Type *OpTy = CI->getType();
+ SmallVector<Type *, 2> Types = {OpTy, OpTy};
+ SmallVector<Value *, 2> Args = {CI, buildMD(PoisonValue::get(KnownElemTy)),
+ B.getInt32(getPointerAddressSpace(OpTy))};
+ CallInst *PtrCasted =
+ B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+ SmallVector<User *> Users(CI->users());
+ for (auto *U : Users)
+ if (U != AssignCI && U != PtrCasted)
+ U->replaceUsesOfWith(CI, PtrCasted);
+ buildAssignPtr(B, KnownElemTy, PtrCasted);
+}
+
// Try to deduce a better type for pointers to untyped ptr.
bool SPIRVEmitIntrinsics::postprocessTypes() {
bool Changed = false;
@@ -1717,6 +1741,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
Type *KnownTy = GR->findDeducedElementType(*IB);
if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
continue;
+ // Try to improve the type deduced after all Functions are processed.
+ if (auto *CI = dyn_cast<CallInst>(*IB)) {
+ if (Function *CalledF = CI->getCalledFunction()) {
+ Type *RetElemTy = GR->findDeducedElementType(CalledF);
+ // Fix inconsistency between known type and function's return type.
+ if (RetElemTy && RetElemTy != KnownTy) {
+ replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
+ Changed = true;
+ continue;
+ }
+ }
+ }
Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
for (User *U : I->users()) {
Instruction *Inst = dyn_cast<Instruction>(U);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index f1b10e264781f2..83f4b92147a231 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
return {Reg, GetIdOp};
}
+static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
+ MachineBasicBlock &MBB = *Def->getParent();
+ MachineBasicBlock::iterator DefIt =
+ Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
+ // Skip all the PHI and debug instructions.
+ while (DefIt != MBB.end() &&
+ (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
+ DefIt = std::next(DefIt);
+ MIB.setInsertPt(MBB, DefIt);
+}
+
// Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
// a dst of the definition, assign SPIRVType to both registers. If SpvType is
// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -350,11 +361,9 @@ namespace llvm {
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
MachineRegisterInfo &MRI) {
- MachineInstr *Def = MRI.getVRegDef(Reg);
assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
- MIB.setInsertPt(*Def->getParent(),
- (Def->getNextNode() ? Def->getNextNode()->getIterator()
- : Def->getParent()->end()));
+ MachineInstr *Def = MRI.getVRegDef(Reg);
+ setInsertPtAfterDef(MIB, Def);
SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
if (auto *RC = MRI.getRegClassOrNull(Reg)) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
new file mode 100644
index 00000000000000..6fa3f4e53cc598
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpBranchConditional
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]]
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]]
+
+define void @f0(ptr %arg) {
+entry:
+ ret void
+}
+
+define ptr @f1() {
+entry:
+ %p = alloca i8
+ store i8 8, ptr %p
+ ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+ %p = alloca i32
+ store i32 32, ptr %p
+ ret ptr %p
+}
+
+define ptr @foo(i1 %arg) {
+entry:
+ %r1 = tail call ptr @f1()
+ %r2 = tail call ptr @f2()
+ br i1 %arg, label %l1, label %l2
+
+l1:
+ br label %exit
+
+l2:
+ br label %exit
+
+exit:
+ %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+ %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+ tail call void @f0(ptr %ret)
+ ret ptr %ret2
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
new file mode 100644
index 00000000000000..4fbaae25567300
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
@@ -0,0 +1,53 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+
+define ptr @foo(i1 %arg) {
+entry:
+ %r1 = tail call ptr @f1()
+ %r2 = tail call ptr @f2()
+ br i1 %arg, label %l1, label %l2
+
+l1:
+ br label %exit
+
+l2:
+ br label %exit
+
+exit:
+ %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+ %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+ tail call void @f0(ptr %ret)
+ ret ptr %ret2
+}
+
+define void @f0(ptr %arg) {
+entry:
+ ret void
+}
+
+define ptr @f1() {
+entry:
+ %p = alloca i8
+ store i8 8, ptr %p
+ ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+ %p = alloca i32
+ store i32 32, ptr %p
+ ret ptr %p
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/109660
More information about the llvm-commits mailing list