[llvm] [SPIR-V] Fix inconsistency between previously deduced element type of a pointer and function's return type (PR #109660)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 23 06:19:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR improves type inference and fixes an inconsistency between the previously deduced element type of a pointer and the function's return type. It fixes https://github.com/llvm/llvm-project/issues/109401 by ensuring that OpPhi is consistent with respect to its operand types.


---
Full diff: https://github.com/llvm/llvm-project/pull/109660.diff


4 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+38-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+13-4) 
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll (+55) 
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll (+53) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 795ddf47c40dab..7057cc1fd30242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
+  void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
+                            CallInst *AssignCI);
 
 public:
   static char ID;
@@ -475,10 +477,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (DemangledName.length() > 0)
         DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end()) {
+      if (AsArgIt != ResTypeByArg.end())
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited, UnknownElemTypeI8);
-      }
+      else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF))
+        Ty = KnownRetTy;
     }
   }
 
@@ -808,6 +811,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
       CallInst *PtrCastI =
           B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
       I->setOperand(OpIt.second, PtrCastI);
+      buildAssignPtr(B, KnownElemTy, PtrCastI);
     }
   }
 }
@@ -1706,6 +1710,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   return true;
 }
 
+void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
+                                               Type *KnownElemTy,
+                                               CallInst *AssignCI) {
+  updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy));
+  IRBuilder<> B(CI->getContext());
+  B.SetInsertPoint(*CI->getInsertionPointAfterDef());
+  B.SetCurrentDebugLocation(CI->getDebugLoc());
+  Type *OpTy = CI->getType();
+  SmallVector<Type *, 2> Types = {OpTy, OpTy};
+  SmallVector<Value *, 2> Args = {CI, buildMD(PoisonValue::get(KnownElemTy)),
+                                  B.getInt32(getPointerAddressSpace(OpTy))};
+  CallInst *PtrCasted =
+      B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+  SmallVector<User *> Users(CI->users());
+  for (auto *U : Users)
+    if (U != AssignCI && U != PtrCasted)
+      U->replaceUsesOfWith(CI, PtrCasted);
+  buildAssignPtr(B, KnownElemTy, PtrCasted);
+}
+
 // Try to deduce a better type for pointers to untyped ptr.
 bool SPIRVEmitIntrinsics::postprocessTypes() {
   bool Changed = false;
@@ -1717,6 +1741,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
     Type *KnownTy = GR->findDeducedElementType(*IB);
     if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
       continue;
+    // Try to improve the type deduced after all Functions are processed.
+    if (auto *CI = dyn_cast<CallInst>(*IB)) {
+      if (Function *CalledF = CI->getCalledFunction()) {
+        Type *RetElemTy = GR->findDeducedElementType(CalledF);
+        // Fix inconsistency between known type and function's return type.
+        if (RetElemTy && RetElemTy != KnownTy) {
+          replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
+          Changed = true;
+          continue;
+        }
+      }
+    }
     Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
     for (User *U : I->users()) {
       Instruction *Inst = dyn_cast<Instruction>(U);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index f1b10e264781f2..83f4b92147a231 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
   return {Reg, GetIdOp};
 }
 
+static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
+  MachineBasicBlock &MBB = *Def->getParent();
+  MachineBasicBlock::iterator DefIt =
+      Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
+  // Skip all the PHI and debug instructions.
+  while (DefIt != MBB.end() &&
+         (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
+    DefIt = std::next(DefIt);
+  MIB.setInsertPt(MBB, DefIt);
+}
+
 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -350,11 +361,9 @@ namespace llvm {
 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                            MachineRegisterInfo &MRI) {
-  MachineInstr *Def = MRI.getVRegDef(Reg);
   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
-  MIB.setInsertPt(*Def->getParent(),
-                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
-                                      : Def->getParent()->end()));
+  MachineInstr *Def = MRI.getVRegDef(Reg);
+  setInsertPtAfterDef(MIB, Def);
   SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
new file mode 100644
index 00000000000000..6fa3f4e53cc598
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpBranchConditional
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]]
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]]
+
+define void @f0(ptr %arg) {
+entry:
+  ret void
+}
+
+define ptr @f1() {
+entry:
+  %p = alloca i8
+  store i8 8, ptr %p
+  ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+  %p = alloca i32
+  store i32 32, ptr %p
+  ret ptr %p
+}
+
+define ptr @foo(i1 %arg) {
+entry:
+  %r1 = tail call ptr @f1()
+  %r2 = tail call ptr @f2()
+  br i1 %arg, label %l1, label %l2
+
+l1:
+  br label %exit
+
+l2:
+  br label %exit
+
+exit:
+  %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  tail call void @f0(ptr %ret)
+  ret ptr %ret2
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
new file mode 100644
index 00000000000000..4fbaae25567300
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
@@ -0,0 +1,53 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+
+define ptr @foo(i1 %arg) {
+entry:
+  %r1 = tail call ptr @f1()
+  %r2 = tail call ptr @f2()
+  br i1 %arg, label %l1, label %l2
+
+l1:
+  br label %exit
+
+l2:
+  br label %exit
+
+exit:
+  %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  tail call void @f0(ptr %ret)
+  ret ptr %ret2
+}
+
+define void @f0(ptr %arg) {
+entry:
+  ret void
+}
+
+define ptr @f1() {
+entry:
+  %p = alloca i8
+  store i8 8, ptr %p
+  ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+  %p = alloca i32
+  store i32 32, ptr %p
+  ret ptr %p
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/109660


More information about the llvm-commits mailing list