[llvm] [SPIR-V] Update type inference and instruction selection (PR #88254)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 05:06:02 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/88254

>From c65e1c31429f0fb08250069bdf1bd094fb7493bc Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 03:35:26 -0700
Subject: [PATCH 1/5] improve type inference, fix mismatched machine function
 context

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 47 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  9 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  8 +++-
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 23 +++++----
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  3 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  7 +++
 6 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..8113de6d7fd181 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // deduce Types of operands of the Instruction if possible
+  void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// Deduce Types of operands of the Instruction if possible.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Value *I, DenseMap<Value *, Type *> &Collected) {
+  Type *KnownTy = GR->findDeducedElementType(I);
+  if (!KnownTy)
+    return;
+
+  // look for known basic patterns of type inference
+  if (auto *Ref = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Value *Op = Ref->getIncomingValue(i);
+      if (!isUntypedPointerTy(Op->getType()))
+        continue;
+      Type *Ty = GR->findDeducedElementType(Op);
+      if (!Ty) {
+        Collected[Op] = KnownTy;
+        GR->addDeducedElementType(Op, KnownTy);
+      }
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -1126,6 +1151,28 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
+  for (auto &I : instructions(Func)) {
+    Type *ITy = I.getType();
+    if (!isPointerTy(ITy))
+      continue;
+    DenseMap<Value *, Type *> CollectedTys;
+    deduceOperandElementType(&I, CollectedTys);
+    if (CollectedTys.size() == 0)
+      continue;
+    for (const auto &Rec : CollectedTys) {
+      if (!Rec.first->use_empty()) {
+        Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
+        if (!User)
+          continue;
+        Type *OpTy = Rec.first->getType();
+        setInsertPointSkippingPhis(B, User->getNextNode());
+        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                        UndefValue::get(Rec.second), Rec.first,
+                        {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+      }
+    }
+  }
+
   // check if function parameter types are set
   if (!F->isIntrinsic())
     processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 70197e948c6582..c61a135f12fcf0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -726,7 +726,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +920,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2e3e69456ac260..c31d945a1e180b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -276,8 +276,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The optional second argument MF
+  // allows looking up the association in the context of a machine function
+  // other than the current one, without switching the "current" machine
+  // function.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 299a4341193bfd..2e44c208ed8e04 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
 }
 
 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
-  return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  return toTypedPointer(ElementTy, N->getContext());
 }
 
 // The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
   return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
 }
 
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+  return isUntypedPointerTy(Ty)
+             ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+                                     getPointerAddressSpace(Ty))
+             : Ty;
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

>From 67f1d092a24665361baa0f060cf0f23ffb82c116 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 04:33:58 -0700
Subject: [PATCH 2/5] re-format code

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 21 ++++++++-----------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8113de6d7fd181..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1157,19 +1157,16 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
     DenseMap<Value *, Type *> CollectedTys;
     deduceOperandElementType(&I, CollectedTys);
-    if (CollectedTys.size() == 0)
-      continue;
+    Instruction *User;
     for (const auto &Rec : CollectedTys) {
-      if (!Rec.first->use_empty()) {
-        Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
-        if (!User)
-          continue;
-        Type *OpTy = Rec.first->getType();
-        setInsertPointSkippingPhis(B, User->getNextNode());
-        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
-                        UndefValue::get(Rec.second), Rec.first,
-                        {B.getInt32(getPointerAddressSpace(OpTy))}, B);
-      }
+      if (Rec.first->use_empty() ||
+          !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+        continue;
+      Type *OpTy = Rec.first->getType();
+      setInsertPointSkippingPhis(B, User->getNextNode());
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                      UndefValue::get(Rec.second), Rec.first,
+                      {B.getInt32(getPointerAddressSpace(OpTy))}, B);
     }
   }
 

>From e71d88d6e10758177f485f7b155e0e40035c5f00 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 07:18:01 -0700
Subject: [PATCH 3/5] fix crash when OpConstantComposite doesn't reuse an
 existing constant

---
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 13 ++++++++---
 llvm/test/CodeGen/SPIRV/const-composite.ll  | 26 +++++++++++++++++++++
 2 files changed, 36 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/const-composite.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7e155a36aadbc4..2c964595fc39e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
             assert(BuildVec &&
                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
-            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
-              GR->add(ConstVec->getElementAsConstant(i), &MF,
-                      BuildVec->getOperand(1 + i).getReg());
+            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
+              // Ensure that OpConstantComposite reuses a constant when it's
+              // already created and available in the same machine function.
+              Constant *ElemConst = ConstVec->getElementAsConstant(i);
+              Register ElemReg = GR->find(ElemConst, &MF);
+              if (!ElemReg.isValid())
+                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
+              else
+                BuildVec->getOperand(1 + i).setReg(ElemReg);
+            }
           }
           GR->add(Const, &MF, MI.getOperand(2).getReg());
         } else {
diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll
new file mode 100644
index 00000000000000..4e304bb9516702
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/const-composite.ll
@@ -0,0 +1,26 @@
+; This test ensures that OpConstantComposite reuses a constant that has already
+; been created and is available in the same machine function. In this test case
+; it's the constant `1`: it appears implicitly as the array length in the type
+; of `foo`'s byval argument and is also an element of a composite constant.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0
+; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1
+; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1:]]
+; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0
+; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]]
+
+%struct = type { [1 x i64] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+entry:
+  call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+  ret void
+}
+
+define spir_func void @bar(<2 x i32> noundef) {
+entry:
+  ret void
+}

>From 15aa918de59a5e103b09c0ca4b7f6acea4064f1f Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 14:10:41 -0700
Subject: [PATCH 4/5] Fix OpSelect to support ptr type / Fix TableGen typos /
 Introduce vector-of-pointer ID register class in TableGen

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 25 +++++++--
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp      |  2 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       | 13 ++++-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  2 +-
 .../Target/SPIRV/SPIRVRegisterBankInfo.cpp    |  2 +
 llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td   |  1 +
 llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td    | 19 +++++--
 .../CodeGen/SPIRV/instructions/select-phi.ll  | 52 +++++++++++++++++++
 .../test/CodeGen/SPIRV/instructions/select.ll | 15 ++++++
 9 files changed, 120 insertions(+), 11 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/instructions/select-phi.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e5d327780d4e9f..3b2f2df2c8b582 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
+  // a registry of created Intrinsic::spv_assign_ptr_type instructions
+  DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
+
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
   Type *deduceElementTypeHelper(Value *I);
@@ -655,6 +658,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
     GR->addDeducedElementType(CI, ExpectedElementType);
     GR->addDeducedElementType(Pointer, ExpectedElementType);
+    AssignPtrTypeInstr[Pointer] = CI;
     return;
   }
 
@@ -939,6 +943,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
   GR->addDeducedElementType(CI, ElemTy);
+  AssignPtrTypeInstr[I] = CI;
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1095,6 +1100,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
             {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
         GR->addDeducedElementType(Arg, ElemTy);
+        AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
       }
     }
   }
@@ -1158,15 +1164,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     DenseMap<Value *, Type *> CollectedTys;
     deduceOperandElementType(&I, CollectedTys);
     Instruction *User;
+    LLVMContext &Ctx = F->getContext();
     for (const auto &Rec : CollectedTys) {
       if (Rec.first->use_empty() ||
           !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
         continue;
       Type *OpTy = Rec.first->getType();
-      setInsertPointSkippingPhis(B, User->getNextNode());
-      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
-                      UndefValue::get(Rec.second), Rec.first,
-                      {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+      Value *OpTyVal = Constant::getNullValue(Rec.second);
+      // check if there is existing Intrinsic::spv_assign_ptr_type instruction
+      auto It = AssignPtrTypeInstr.find(Rec.first);
+      if (It == AssignPtrTypeInstr.end()) {
+        setInsertPointSkippingPhis(B, User->getNextNode());
+        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal,
+                        Rec.first, {B.getInt32(getPointerAddressSpace(OpTy))},
+                        B);
+      } else {
+        It->second->setArgOperand(
+            1,
+            MetadataAsValue::get(
+                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+      }
     }
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index e3f76419f13137..aacfecc1e313f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
 bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
       MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
-      MI.getOpcode() == SPIRV::GET_vID) {
+      MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
     auto &MRI = MI.getMF()->getRegInfo();
     MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
     MI.eraseFromParent();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 99c57dac4141d8..a3f981457c8daa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -22,6 +22,7 @@ let isCodeGenOnly=1 in {
   def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>;
   def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>;
   def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>;
+  def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins ANYID:$src)>;
 }
 
 def SPVTypeBin : SDTypeProfile<1, 2, []>;
@@ -55,7 +56,7 @@ multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genF = 0
   }
 }
 
-multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> {
+multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP = 1, bit genI = 1, bit genF = 0, bit genV = 0> {
   if genF then {
     def SFSCond: TernOpTyped<name, opCode, ID, fID, node>;
     def SFVCond: TernOpTyped<name, opCode, vID, fID, node>;
@@ -64,6 +65,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI =
     def SISCond: TernOpTyped<name, opCode, ID, ID, node>;
     def SIVCond: TernOpTyped<name, opCode, vID, ID, node>;
   }
+  if genP then {
+    def SPSCond: TernOpTyped<name, opCode, ID, pID, node>;
+    def SPVCond: TernOpTyped<name, opCode, vID, pID, node>;
+  }
   if genV then {
     if genF then {
       def VFSCond: TernOpTyped<name, opCode, ID, vfID, node>;
@@ -73,6 +78,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI =
       def VISCond: TernOpTyped<name, opCode, ID, vID, node>;
       def VIVCond: TernOpTyped<name, opCode, vID, vID, node>;
     }
+    if genP then {
+      def VPSCond: TernOpTyped<name, opCode, ID, vpID, node>;
+      def VPVCond: TernOpTyped<name, opCode, vID, vpID, node>;
+    }
   }
 }
 
@@ -552,7 +561,7 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>;
 def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>;
 def OpLogicalNot: UnOp<"OpLogicalNot", 168>;
 
-defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>;
+defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1, 1>;
 
 def OpIEqual: BinOp<"OpIEqual", 170>;
 def OpINotEqual: BinOp<"OpINotEqual", 171>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index b9d66de9555b11..f069a92ac68683 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -56,7 +56,7 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
 static bool isMetaInstrGET(unsigned Opcode) {
   return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
          Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
-         Opcode == SPIRV::GET_vfID;
+         Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID;
 }
 
 static bool mayBeInserted(unsigned Opcode) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
index 9bf9d7fe5b39e8..5983c9229cb3c2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
@@ -39,6 +39,8 @@ SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
     return SPIRV::vIDRegBank;
   case SPIRV::vfIDRegClassID:
     return SPIRV::vfIDRegBank;
+  case SPIRV::vpIDRegClassID:
+    return SPIRV::vpIDRegBank;
   case SPIRV::ANYIDRegClassID:
   case SPIRV::ANYRegClassID:
     return SPIRV::IDRegBank;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
index 90c7f3a6e67265..c7f1e172f3d4f1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
@@ -12,4 +12,5 @@ def IDRegBank : RegisterBank<"IDBank", [ID]>;
 def fIDRegBank : RegisterBank<"fIDBank", [fID]>;
 def vIDRegBank : RegisterBank<"vIDBank", [vID]>;
 def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>;
+def vpIDRegBank : RegisterBank<"vpIDBank", [vpID]>;
 def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
index d0b64b6895d035..6d2bfb91a97f12 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
@@ -12,6 +12,17 @@
 
 let Namespace = "SPIRV" in {
   def p0 : PtrValueType <i32, 0>;
+
+  class P0Vec<ValueType scalar>
+      : PtrValueType <scalar, 0> {
+    let nElem = 2;
+    let ElementType = p0;
+    let isInteger = false;
+    let isFP = false;
+    let isVector = true;
+  }
+
+  def v2p0 : P0Vec<i32>;
   // All registers are for 32-bit identifiers, so have a single dummy register
 
   // Class for registers that are the result of OpTypeXXX instructions
@@ -21,14 +32,16 @@ let Namespace = "SPIRV" in {
   // Class for every other non-type ID
   def ID0 : Register<"ID0">;
   def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
-  def fID0 : Register<"FID0">;
+  def fID0 : Register<"fID0">;
   def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
   def pID0 : Register<"pID0">;
   def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>;
-  def vID0 : Register<"pID0">;
+  def vID0 : Register<"vID0">;
   def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
-  def vfID0 : Register<"pID0">;
+  def vfID0 : Register<"vfID0">;
   def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>;
+  def vpID0 : Register<"vpID0">;
+  def vpID : RegisterClass<"SPIRV", [v2p0], 32, (add vpID0)>;
 
   def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>;
 
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
new file mode 100644
index 00000000000000..849276a1028ad5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
@@ -0,0 +1,52 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]]
+; CHECK-DAG: %[[Struct:.*]] = OpTypeStruct %[[Array]]
+; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct]]
+
+; CHECK: %[[Branch1:.*]] = OpLabel
+; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
+; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
+; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[#]] %[[#]]
+; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
+; CHECK: %[[SelectRes:.*]] = OpSelect %[[StructPtr]] %[[#]] %[[#]] %[[#]]
+; CHECK: OpLabel
+; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]]
+
+%struct = type { %array }
+%array = type { [1 x i64] }
+%array3 = type { [3 x i32] }
+
+define spir_kernel void @foo(ptr addrspace(1) noundef align 1 %arg1, ptr noundef byval(%struct) align 8 %arg2, i1 noundef zeroext %expected) {
+entry:
+  %agg = alloca %array3, align 8
+  %r0 = load i64, ptr %arg2, align 8
+  %add.ptr = getelementptr inbounds i8, ptr %agg, i64 12
+  %r1 = load i32, ptr %agg, align 4
+  %tobool0 = icmp slt i32 %r1, 0
+  br i1 %tobool0, label %exit, label %sw1
+
+sw1:                            ; preds = %entry
+  %incdec1 = getelementptr inbounds i8, ptr %agg, i64 4
+  %r2 = load i32, ptr %incdec1, align 4
+  %tobool1 = icmp slt i32 %r2, 0
+  br i1 %tobool1, label %exit, label %sw2
+
+sw2:                            ; preds = %sw1
+  %incdec2 = getelementptr inbounds i8, ptr %agg, i64 8
+  %r3 = load i32, ptr %incdec2, align 4
+  %tobool2 = icmp slt i32 %r3, 0
+  %spec.select = select i1 %tobool2, ptr %incdec2, ptr %add.ptr
+  br label %exit
+
+exit: ; preds = %sw2, %sw1, %entry
+  %retval.0 = phi ptr [ %agg, %entry ], [ %incdec1, %sw1 ], [ %spec.select, %sw2 ]
+  %add.ptr.i = getelementptr inbounds i8, ptr addrspace(1) %arg1, i64 %r0
+  %r4 = icmp eq ptr %retval.0, %add.ptr
+  %cmp = xor i1 %r4, %expected
+  %frombool6.i = zext i1 %cmp to i8
+  store i8 %frombool6.i, ptr addrspace(1) %add.ptr.i, align 1
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select.ll b/llvm/test/CodeGen/SPIRV/instructions/select.ll
index f54ef21f208596..c4176b17abb449 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select.ll
@@ -1,6 +1,8 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG:  OpName [[SCALARi32:%.+]] "select_i32"
+; CHECK-DAG:  OpName [[SCALARPTR:%.+]] "select_ptr"
 ; CHECK-DAG:  OpName [[VEC2i32:%.+]] "select_i32v2"
 ; CHECK-DAG:  OpName [[VEC2i32v2:%.+]] "select_v2i32v2"
 
@@ -17,6 +19,19 @@ define i32 @select_i32(i1 %c, i32 %t, i32 %f) {
   ret i32 %r
 }
 
+; CHECK:      [[SCALARPTR]] = OpFunction
+; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter
+; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter
+; CHECK-NEXT: [[F:%.+]] = OpFunctionParameter
+; CHECK:      OpLabel
+; CHECK:      [[R:%.+]] = OpSelect {{%.+}} [[C]] [[T]] [[F]]
+; CHECK:      OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define ptr @select_ptr(i1 %c, ptr %t, ptr %f) {
+  %r = select i1 %c, ptr %t, ptr %f
+  ret ptr %r
+}
+
 ; CHECK:      [[VEC2i32]] = OpFunction
 ; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter
 ; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter

>From fa73cdc9b68ddac4863f1105001ee6ab1e450c24 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 11 Apr 2024 05:04:57 -0700
Subject: [PATCH 5/5] improve and fix type deduction

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 136 ++++++++++++------
 .../CodeGen/SPIRV/instructions/select-phi.ll  |  12 +-
 2 files changed, 101 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 3b2f2df2c8b582..95f6c703004bfc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -79,7 +79,7 @@ class SPIRVEmitIntrinsics
                                std::unordered_set<Value *> &Visited);
 
   // deduce Types of operands of the Instruction if possible
-  void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+  void deduceOperandElementType(Instruction *I);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -275,6 +275,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (Ty)
         break;
     }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
+      Ty = deduceElementTypeByUsersDeep(Op, Visited);
+      if (Ty)
+        break;
+    }
   }
 
   // remember the found relationship
@@ -374,24 +380,92 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
-// Deduce Types of operands of the Instruction if possible.
-void SPIRVEmitIntrinsics::deduceOperandElementType(
-    Value *I, DenseMap<Value *, Type *> &Collected) {
-  Type *KnownTy = GR->findDeducedElementType(I);
-  if (!KnownTy)
-    return;
-
+// For a PHI or select (using the result's deduced pointee type) or an icmp
+// (using the other operand's), propagate that type to pointer operands. An
+// operand without a deduced type gets it, via a new or updated
+// spv_assign_ptr_type; one whose type differs is replaced by an spv_ptrcast.
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+  SmallVector<std::pair<Value *, unsigned>> Ops;
+  Type *KnownElemTy = nullptr;
   // look for known basic patterns of type inference
   if (auto *Ref = dyn_cast<PHINode>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
     for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
       Value *Op = Ref->getIncomingValue(i);
-      if (!isUntypedPointerTy(Op->getType()))
-        continue;
-      Type *Ty = GR->findDeducedElementType(Op);
-      if (!Ty) {
-        Collected[Op] = KnownTy;
-        GR->addDeducedElementType(Op, KnownTy);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
+    for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
+      Value *Op = Ref->getOperand(i);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
+    if (!isPointerTy(Ref->getOperand(0)->getType()))
+      return;
+    Value *Op0 = Ref->getOperand(0);
+    Value *Op1 = Ref->getOperand(1);
+    Type *ElemTy0 = GR->findDeducedElementType(Op0);
+    Type *ElemTy1 = GR->findDeducedElementType(Op1);
+    if (ElemTy0) {
+      KnownElemTy = ElemTy0;
+      Ops.push_back(std::make_pair(Op1, 1));
+    } else if (ElemTy1) {
+      KnownElemTy = ElemTy1;
+      Ops.push_back(std::make_pair(Op0, 0));
+    }
+  }
+
+  // Nothing to do: no known pointee type, or no pointer operands to update.
+  if (!KnownElemTy || Ops.size() == 0)
+    return;
+
+  Instruction *User = nullptr;
+  LLVMContext &Ctx = F->getContext();
+  IRBuilder<> B(Ctx);
+  for (auto &OpIt : Ops) {
+    Value *Op = OpIt.first;
+    // Use::get() yields Op itself, so only Instruction operands are handled.
+    Type *Ty = GR->findDeducedElementType(Op);
+    if (Ty == KnownElemTy)
+      continue;
+    if (Op->use_empty() ||
+        !(User = dyn_cast<Instruction>(Op->use_begin()->get())))
+      continue;
+
+    setInsertPointSkippingPhis(B, User->getNextNode());
+    Value *OpTyVal = Constant::getNullValue(KnownElemTy);
+    Type *OpTy = Op->getType();
+    if (!Ty) {
+      GR->addDeducedElementType(Op, KnownElemTy);
+      // update Op's existing spv_assign_ptr_type, or create a new one
+      auto It = AssignPtrTypeInstr.find(Op);
+      if (It == AssignPtrTypeInstr.end()) {
+        CallInst *CI =
+            buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
+                            {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+        AssignPtrTypeInstr[Op] = CI;
+      } else {
+        It->second->setArgOperand(
+            1,
+            MetadataAsValue::get(
+                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
       }
+    } else {
+      SmallVector<Type *, 2> Types = {OpTy, OpTy};
+      MetadataAsValue *VMD = MetadataAsValue::get(
+          Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
+      SmallVector<Value *, 2> Args = {Op, VMD,
+                                      B.getInt32(getPointerAddressSpace(OpTy))};
+      CallInst *PtrCastI =
+          B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+      I->setOperand(OpIt.second, PtrCastI);
     }
   }
 }
@@ -1145,6 +1219,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
   }
+
+  for (auto &I : instructions(Func))
+    deduceOperandElementType(&I);
+
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1157,36 +1235,6 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
-  for (auto &I : instructions(Func)) {
-    Type *ITy = I.getType();
-    if (!isPointerTy(ITy))
-      continue;
-    DenseMap<Value *, Type *> CollectedTys;
-    deduceOperandElementType(&I, CollectedTys);
-    Instruction *User;
-    LLVMContext &Ctx = F->getContext();
-    for (const auto &Rec : CollectedTys) {
-      if (Rec.first->use_empty() ||
-          !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
-        continue;
-      Type *OpTy = Rec.first->getType();
-      Value *OpTyVal = Constant::getNullValue(Rec.second);
-      // check if there is existing Intrinsic::spv_assign_ptr_type instruction
-      auto It = AssignPtrTypeInstr.find(Rec.first);
-      if (It == AssignPtrTypeInstr.end()) {
-        setInsertPointSkippingPhis(B, User->getNextNode());
-        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal,
-                        Rec.first, {B.getInt32(getPointerAddressSpace(OpTy))},
-                        B);
-      } else {
-        It->second->setArgOperand(
-            1,
-            MetadataAsValue::get(
-                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
-      }
-    }
-  }
-
   // check if function parameter types are set
   if (!F->isIntrinsic())
     processParamTypes(F, B);
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
index 849276a1028ad5..afc75c616f023b 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
@@ -1,19 +1,24 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %}
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
+; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[Long:.*]] = OpTypeInt 32 0
 ; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]]
 ; CHECK-DAG: %[[Struct:.*]] = OpTypeStruct %[[Array]]
 ; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct]]
+; CHECK-DAG: %[[CharPtr:.*]] = OpTypePointer Function %[[Char]]
 
 ; CHECK: %[[Branch1:.*]] = OpLabel
 ; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
 ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
-; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[#]] %[[#]]
+; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]]
+; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]]
 ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
-; CHECK: %[[SelectRes:.*]] = OpSelect %[[StructPtr]] %[[#]] %[[#]] %[[#]]
+; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]]
+; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]]
 ; CHECK: OpLabel
-; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]]
+; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]]
 
 %struct = type { %array }
 %array = type { [1 x i64] }
@@ -48,5 +53,6 @@ exit: ; preds = %sw2, %sw1, %entry
   %cmp = xor i1 %r4, %expected
   %frombool6.i = zext i1 %cmp to i8
   store i8 %frombool6.i, ptr addrspace(1) %add.ptr.i, align 1
+  %r5 = icmp eq ptr %add.ptr, %retval.0
   ret void
 }



More information about the llvm-commits mailing list