[llvm] [SPIR-V] Improve type inference, fix mismatched machine function context (PR #88254)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 07:21:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR contains several fixes that improve type inference:
* fix the use of the machine function context when type inference/validation needs to switch between different machine functions to infer/validate correct types;
* use TypedPointerType instead of PointerType so that later stages of type inference can distinguish pointer types by their element types. This effectively supports a hierarchy of pointer/pointee types and replaces more complicated recursive type matching at the machine-instruction level with direct pointer comparison of LLVM `Type *` values;
* extract detailed information about operand types using known type rules for some LLVM instructions (for instance, deducing the pointee types of a PHI's operands when the PHI's result type was deduced in an earlier stage of type inference), and add the corresponding `Intrinsic::spv_assign_ptr_type` to preserve that type information across subsequent passes;
* ensure that OpConstantComposite reuses a constant when it has already been created and is available in the same machine function; otherwise, building the dependency graph crashes. The corresponding test case is attached.


---
Full diff: https://github.com/llvm/llvm-project/pull/88254.diff


8 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+44) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+6-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+15-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+10-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7) 
- (added) llvm/test/CodeGen/SPIRV/const-composite.ll (+26) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // deduce Types of operands of the Instruction if possible
+  void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// Propagate the element type already deduced for \p I to those of its
+// pointer operands whose element type is still unknown, recording each newly
+// typed operand in \p Collected. Only PHI nodes are handled at the moment.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Value *I, DenseMap<Value *, Type *> &Collected) {
+  Type *KnownTy = GR->findDeducedElementType(I);
+  if (!KnownTy)
+    return;
+  auto *Phi = dyn_cast<PHINode>(I);
+  if (!Phi)
+    return;
+  // Assume each incoming untyped pointer shares the PHI's element type.
+  for (Value *Op : Phi->incoming_values()) {
+    // Skip operands that are not untyped pointers or already have a type.
+    if (!isUntypedPointerTy(Op->getType()) ||
+        GR->findDeducedElementType(Op))
+      continue;
+    Collected[Op] = KnownTy;
+    GR->addDeducedElementType(Op, KnownTy);
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -1126,6 +1151,25 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
+  for (auto &I : instructions(Func)) {
+    Type *ITy = I.getType();
+    if (!isPointerTy(ITy))
+      continue;
+    DenseMap<Value *, Type *> CollectedTys;
+    deduceOperandElementType(&I, CollectedTys);
+    Instruction *User;
+    for (const auto &Rec : CollectedTys) {
+      if (Rec.first->use_empty() ||
+          !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+        continue;
+      Type *OpTy = Rec.first->getType();
+      setInsertPointSkippingPhis(B, User->getNextNode());
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                      UndefValue::get(Rec.second), Rec.first,
+                      {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+    }
+  }
+
   // check if function parameter types are set
   if (!F->isIntrinsic())
     processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..bd14da0ecc557b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -589,7 +589,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -782,8 +783,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..4dcc66f741edd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -273,8 +273,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The optional second argument
+  // MF allows looking up the association in the context of a machine function
+  // other than the current one, without switching the "current" machine
+  // function.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7e155a36aadbc4..2c964595fc39e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
             assert(BuildVec &&
                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
-            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
-              GR->add(ConstVec->getElementAsConstant(i), &MF,
-                      BuildVec->getOperand(1 + i).getReg());
+            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
+              // Ensure that OpConstantComposite reuses a constant when it's
+              // already created and available in the same machine function.
+              Constant *ElemConst = ConstVec->getElementAsConstant(i);
+              Register ElemReg = GR->find(ElemConst, &MF);
+              if (!ElemReg.isValid())
+                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
+              else
+                BuildVec->getOperand(1 + i).setReg(ElemReg);
+            }
           }
           GR->add(Const, &MF, MI.getOperand(2).getReg());
         } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c87c1293c622fc..07ce9d9078de27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
 }
 
 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
-  return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  const auto *MD = cast<ValueAsMetadata>(N->getOperand(I));
+  return toTypedPointer(MD->getType(), N->getContext()); // opaque ptr -> i8*
 }
 
 // The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
   return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
 }
 
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+  if (!isUntypedPointerTy(Ty))
+    return Ty; // not an untyped pointer: returned unchanged
+  unsigned AddrSpace = getPointerAddressSpace(Ty);
+  return TypedPointerType::get(Type::getInt8Ty(Ctx), AddrSpace); // -> i8*
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll
new file mode 100644
index 00000000000000..4e304bb9516702
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/const-composite.ll
@@ -0,0 +1,26 @@
+; Ensure that OpConstantComposite reuses a constant that has already been
+; created and is available in the same machine function. In this test case the
+; constant `1` is needed both as the length of the `[1 x i64]` array in the type
+; of `foo`'s byval argument and as an element of the composite passed to `bar`.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0
+; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1
+; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1]]
+; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0
+; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]]
+
+%struct = type { [1 x i64] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+entry:
+  call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+  ret void
+}
+
+define spir_func void @bar(<2 x i32> noundef) {
+entry:
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88254


More information about the llvm-commits mailing list