[llvm] [SPIR-V] Improve type inference, fix mismatched machine function context (PR #88254)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 04:34:10 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/88254

>From 617bd2d5e0f2ddf1a9fb7e6a155ef794bf13edc5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 03:35:26 -0700
Subject: [PATCH 1/2] improve type inference, fix mismatched machine function
 context

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 47 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  9 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  8 +++-
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 23 +++++----
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  3 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  7 +++
 6 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..8113de6d7fd181 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // Deduce element types of I's untyped-pointer operands into Collected.
+  void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// Propagate the element type deduced for I to its untyped-pointer operands.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Value *I, DenseMap<Value *, Type *> &Collected) {
+  Type *KnownTy = GR->findDeducedElementType(I);
+  if (!KnownTy)
+    return; // nothing to propagate while I's own element type is unknown
+
+  // Only PHIs so far: untyped-pointer incoming values get the PHI's type.
+  if (auto *Ref = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Value *Op = Ref->getIncomingValue(i);
+      if (!isUntypedPointerTy(Op->getType()))
+        continue; // skip operands that are not untyped pointers
+      Type *Ty = GR->findDeducedElementType(Op);
+      if (!Ty) { // never overwrite a type already deduced for Op
+        Collected[Op] = KnownTy; // newly typed operands go back to the caller
+        GR->addDeducedElementType(Op, KnownTy); // and are recorded in GR
+      }
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -1126,6 +1151,28 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
+  for (auto &I : instructions(Func)) {
+    Type *ITy = I.getType();
+    if (!isPointerTy(ITy))
+      continue;
+    DenseMap<Value *, Type *> CollectedTys;
+    deduceOperandElementType(&I, CollectedTys);
+    if (CollectedTys.size() == 0)
+      continue;
+    for (const auto &Rec : CollectedTys) {
+      if (!Rec.first->use_empty()) {
+        Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
+        if (!User)
+          continue;
+        Type *OpTy = Rec.first->getType();
+        setInsertPointSkippingPhis(B, User->getNextNode());
+        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                        UndefValue::get(Rec.second), Rec.first,
+                        {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+      }
+    }
+  }
+
   // check if function parameter types are set
   if (!F->isIntrinsic())
     processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..bd14da0ecc557b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -589,7 +589,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -782,8 +783,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..4dcc66f741edd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -273,8 +273,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The optional MF argument
+  // names the machine function whose VReg-to-type map is searched. This
+  // allows lookups in a machine function other than the current one without
+  // switching the "current" machine function; nullptr means the current one.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c87c1293c622fc..07ce9d9078de27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
 }
 
 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
-  return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  return toTypedPointer(ElementTy, N->getContext()); // ptr -> i8 typed ptr
 }
 
 // The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
   return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
 }
 
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+  return isUntypedPointerTy(Ty) // -> pointer to i8, same address space
+             ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+                                     getPointerAddressSpace(Ty))
+             : Ty; // any other type is returned unchanged
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

>From b07b588de5c9f12e1383648c1959acfb3ec56c9a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 04:33:58 -0700
Subject: [PATCH 2/2] re-format code

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 21 ++++++++-----------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8113de6d7fd181..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1157,19 +1157,16 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
     DenseMap<Value *, Type *> CollectedTys;
     deduceOperandElementType(&I, CollectedTys);
-    if (CollectedTys.size() == 0)
-      continue;
+    Instruction *User;
     for (const auto &Rec : CollectedTys) {
-      if (!Rec.first->use_empty()) {
-        Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
-        if (!User)
-          continue;
-        Type *OpTy = Rec.first->getType();
-        setInsertPointSkippingPhis(B, User->getNextNode());
-        buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
-                        UndefValue::get(Rec.second), Rec.first,
-                        {B.getInt32(getPointerAddressSpace(OpTy))}, B);
-      }
+      if (Rec.first->use_empty() ||
+          !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+        continue;
+      Type *OpTy = Rec.first->getType();
+      setInsertPointSkippingPhis(B, User->getNextNode());
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                      UndefValue::get(Rec.second), Rec.first,
+                      {B.getInt32(getPointerAddressSpace(OpTy))}, B);
     }
   }
 



More information about the llvm-commits mailing list