[llvm] [SPIR-V] Type inference must realize that a <1 x Type> vector type is not a legal vector type in LLT (PR #124560)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 27 07:22:14 PST 2025


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/124560

This PR accounts for possible <1 x Type> vector inputs, to ensure that type inference produces only legal vector types.

If an LLVM type is a <1 x Type> vector type, we replace it with its element type so that it conforms with future transformations in IRTranslator: a <1 x Type> vector type is not a legal vector type in LLT, and IRTranslator will eventually represent it as a scalar.

>From f4c6727837303b5a630d00b19c0dd38949a5eb95 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 27 Jan 2025 07:18:21 -0800
Subject: [PATCH] account for possible <1 x LLVM Type> input

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 42 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVUtils.h            | 17 ++++++++
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 702206b8e0dc56..96f67d6117e973 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -427,7 +427,7 @@ Type *SPIRVEmitIntrinsics::reconstructType(Value *Op, bool UnknownElemTypeI8,
 
 void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
                                           Value *Arg) {
-  Value *OfType = PoisonValue::get(Ty);
+  Value *OfType = getNormalizedPoisonValue(Ty);
   CallInst *AssignCI = nullptr;
   if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
       allowEmitFakeUse(Arg)) {
@@ -447,6 +447,7 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
 
 void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
                                          Value *Arg) {
+  ElemTy = normalizeType(ElemTy);
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
@@ -470,7 +471,7 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
     return;
 
   // update association with the pointee type
-  Type *ElemTy = OfType->getType();
+  Type *ElemTy = normalizeType(OfType->getType());
   GR->addDeducedElementType(AssignCI, ElemTy);
   GR->addDeducedElementType(Arg, ElemTy);
 }
@@ -490,7 +491,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
   }
   Type *OpTy = Op->getType();
   SmallVector<Type *, 2> Types = {OpTy, OpTy};
-  SmallVector<Value *, 2> Args = {Op, buildMD(PoisonValue::get(ElemTy)),
+  SmallVector<Value *, 2> Args = {Op, buildMD(getNormalizedPoisonValue(ElemTy)),
                                   B.getInt32(getPointerAddressSpace(OpTy))};
   CallInst *PtrCasted =
       B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -766,7 +767,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   // remember the found relationship
   if (Ty && !IgnoreKnownType) {
     // specify nested types if needed, otherwise return unchanged
-    GR->addDeducedElementType(I, Ty);
+    GR->addDeducedElementType(I, normalizeType(Ty));
   }
 
   return Ty;
@@ -852,7 +853,7 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       }
       if (Ty != OpTy) {
         Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
-        GR->addDeducedCompositeType(U, NewTy);
+        GR->addDeducedCompositeType(U, normalizeType(NewTy));
         return NewTy;
       }
     }
@@ -990,6 +991,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
   if (KnownElemTy)
     return false;
   if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+    OpElemTy = normalizeType(OpElemTy);
     GR->addDeducedElementType(F, OpElemTy);
     GR->addReturnType(
         F, TypedPointerType::get(OpElemTy,
@@ -1002,7 +1004,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
         continue;
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
         if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-          updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+          updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy));
           propagateElemType(CI, PrevElemTy, VisitedSubst);
         }
       }
@@ -1162,11 +1164,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
-    Value *OpTyVal = PoisonValue::get(KnownElemTy);
+    Value *OpTyVal = getNormalizedPoisonValue(KnownElemTy);
     Type *OpTy = Op->getType();
     if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       Type *PrevElemTy = GR->findDeducedElementType(Op);
-      GR->addDeducedElementType(Op, KnownElemTy);
+      GR->addDeducedElementType(Op, normalizeType(KnownElemTy));
       // check if KnownElemTy is complete
       if (!Uncomplete)
         eraseTodoType(Op);
@@ -1492,7 +1494,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
 
   // Our previous guess about the type seems to be wrong, let's update
   // inferred type according to a new, more precise type information.
-  updateAssignType(AssignCI, V, PoisonValue::get(AssignedType));
+  updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
 }
 
 void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -1507,7 +1509,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     return;
 
   setInsertPointSkippingPhis(B, I);
-  Value *ExpectedElementVal = PoisonValue::get(ExpectedElementType);
+  Value *ExpectedElementVal = getNormalizedPoisonValue(ExpectedElementType);
   MetadataAsValue *VMD = buildMD(ExpectedElementVal);
   unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
   bool FirstPtrCastOrAssignPtrType = true;
@@ -1653,7 +1655,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       if (!ElemTy) {
         ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
         if (ElemTy) {
-          GR->addDeducedElementType(CalledArg, ElemTy);
+          GR->addDeducedElementType(CalledArg, normalizeType(ElemTy));
         } else {
           for (User *U : CalledArg->users()) {
             if (Instruction *Inst = dyn_cast<Instruction>(U)) {
@@ -1984,8 +1986,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           Type *ElemTy = GR->findDeducedElementType(Op);
           buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
         } else {
-          CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
-                                               {OpTy}, Op, Op, {}, B);
+          CallInst *AssignCI =
+              buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy},
+                              getNormalizedPoisonValue(OpTy), Op, {}, B);
           GR->addAssignPtrTypeInstr(Op, AssignCI);
         }
       }
@@ -2034,7 +2037,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
       Type *OpTy = Op->getType();
       Value *OpTyVal = Op;
       if (OpTy->isTargetExtTy())
-        OpTyVal = PoisonValue::get(OpTy);
+        OpTyVal = getNormalizedPoisonValue(OpTy);
       CallInst *NewOp =
           buildIntrWithMD(Intrinsic::spv_track_constant,
                           {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
@@ -2045,7 +2048,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
         buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
         SmallVector<Type *, 2> Types = {OpTy, OpTy};
         SmallVector<Value *, 2> Args = {
-            NewOp, buildMD(PoisonValue::get(OpElemTy)),
+            NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
             B.getInt32(getPointerAddressSpace(OpTy))};
         CallInst *PtrCasted =
             B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -2178,7 +2181,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
     if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) {
         DenseSet<std::pair<Value *, Value *>> VisitedSubst;
-        updateAssignType(AssignCI, Arg, PoisonValue::get(ElemTy));
+        updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
         propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()),
                           VisitedSubst);
       } else {
@@ -2232,7 +2235,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
           continue;
         if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
             II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
-          updateAssignType(II, &F, PoisonValue::get(FPElemTy));
+          updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
           break;
         }
       }
@@ -2256,7 +2259,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
   for (Function *F : Worklist) {
     SmallVector<Value *> Args;
     for (const auto &Arg : F->args())
-      Args.push_back(PoisonValue::get(Arg.getType()));
+      Args.push_back(getNormalizedPoisonValue(Arg.getType()));
     IRB.CreateCall(F, Args);
   }
   IRB.CreateRetVoid();
@@ -2286,7 +2289,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
             buildAssignPtr(B, ElemTy, Arg);
           }
         } else if (isa<Instruction>(Param)) {
-          GR->addDeducedElementType(Param, ElemTy);
+          GR->addDeducedElementType(Param, normalizeType(ElemTy));
           // insertAssignTypeIntrs() will complete buildAssignPtr()
         } else {
           B.SetInsertPoint(CI->getParent()
@@ -2302,6 +2305,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
         if (!RefF || !isPointerTy(RefF->getReturnType()) ||
             GR->findDeducedElementType(RefF))
           continue;
+        ElemTy = normalizeType(ElemTy);
         GR->addDeducedElementType(RefF, ElemTy);
         GR->addReturnType(
             RefF, TypedPointerType::get(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index fd48098257065a..ed7b2ef1becd95 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -383,6 +383,23 @@ inline const Type *unifyPtrType(const Type *Ty) {
   return toTypedPointer(const_cast<Type *>(Ty));
 }
 
+// Adjust an LLVM type to conform with later transformations in IRTranslator.
+// Currently only a <1 x Type> fixed vector is rewritten; any other type is
+// returned unchanged. Extend when/if needed.
+inline Type *normalizeType(Type *Ty) {
+  auto *FVTy = dyn_cast<FixedVectorType>(Ty);
+  if (!FVTy || FVTy->getNumElements() != 1)
+    return Ty;
+  // A <1 x Type> vector is not a legal vector type in LLT and IRTranslator
+  // will eventually represent it as a scalar, so use the element type instead.
+  // The element type itself is normalized by the recursive call.
+  return normalizeType(FVTy->getElementType());
+}
+
+inline PoisonValue *getNormalizedPoisonValue(Type *Ty) {
+  return PoisonValue::get(normalizeType(Ty)); // poison of normalized type
+}
+
 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"



More information about the llvm-commits mailing list