[llvm] [SPIR-V] Type inference must realize that a <1 x Type> vector type is not a legal vector type in LLT (PR #124560)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 27 07:22:14 PST 2025
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/124560
In this PR we account for possible <1 x LLVM Type> input to ensure that we produce legal vector types during type inference.
If an LLVM type is a <1 x Type> vector, we replace it with its element type so that it conforms with later transformations in IRTranslator: a <1 x Type> vector is not a legal vector type in LLT, and IRTranslator eventually represents it as a scalar anyway.
From f4c6727837303b5a630d00b19c0dd38949a5eb95 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 27 Jan 2025 07:18:21 -0800
Subject: [PATCH] account for possible <1 x LLVM Type> input
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 42 ++++++++++---------
llvm/lib/Target/SPIRV/SPIRVUtils.h | 17 ++++++++
2 files changed, 40 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 702206b8e0dc56..96f67d6117e973 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -427,7 +427,7 @@ Type *SPIRVEmitIntrinsics::reconstructType(Value *Op, bool UnknownElemTypeI8,
void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
Value *Arg) {
- Value *OfType = PoisonValue::get(Ty);
+ Value *OfType = getNormalizedPoisonValue(Ty);
CallInst *AssignCI = nullptr;
if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
allowEmitFakeUse(Arg)) {
@@ -447,6 +447,7 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
Value *Arg) {
+ ElemTy = normalizeType(ElemTy);
Value *OfType = PoisonValue::get(ElemTy);
CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
if (AssignPtrTyCI == nullptr ||
@@ -470,7 +471,7 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
return;
// update association with the pointee type
- Type *ElemTy = OfType->getType();
+ Type *ElemTy = normalizeType(OfType->getType());
GR->addDeducedElementType(AssignCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
}
@@ -490,7 +491,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
}
Type *OpTy = Op->getType();
SmallVector<Type *, 2> Types = {OpTy, OpTy};
- SmallVector<Value *, 2> Args = {Op, buildMD(PoisonValue::get(ElemTy)),
+ SmallVector<Value *, 2> Args = {Op, buildMD(getNormalizedPoisonValue(ElemTy)),
B.getInt32(getPointerAddressSpace(OpTy))};
CallInst *PtrCasted =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -766,7 +767,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
// remember the found relationship
if (Ty && !IgnoreKnownType) {
// specify nested types if needed, otherwise return unchanged
- GR->addDeducedElementType(I, Ty);
+ GR->addDeducedElementType(I, normalizeType(Ty));
}
return Ty;
@@ -852,7 +853,7 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
}
if (Ty != OpTy) {
Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
- GR->addDeducedCompositeType(U, NewTy);
+ GR->addDeducedCompositeType(U, normalizeType(NewTy));
return NewTy;
}
}
@@ -990,6 +991,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
if (KnownElemTy)
return false;
if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+ OpElemTy = normalizeType(OpElemTy);
GR->addDeducedElementType(F, OpElemTy);
GR->addReturnType(
F, TypedPointerType::get(OpElemTy,
@@ -1002,7 +1004,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
continue;
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
- updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+ updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy));
propagateElemType(CI, PrevElemTy, VisitedSubst);
}
}
@@ -1162,11 +1164,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
if (Ty == KnownElemTy)
continue;
- Value *OpTyVal = PoisonValue::get(KnownElemTy);
+ Value *OpTyVal = getNormalizedPoisonValue(KnownElemTy);
Type *OpTy = Op->getType();
if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
Type *PrevElemTy = GR->findDeducedElementType(Op);
- GR->addDeducedElementType(Op, KnownElemTy);
+ GR->addDeducedElementType(Op, normalizeType(KnownElemTy));
// check if KnownElemTy is complete
if (!Uncomplete)
eraseTodoType(Op);
@@ -1492,7 +1494,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
// Our previous guess about the type seems to be wrong, let's update
// inferred type according to a new, more precise type information.
- updateAssignType(AssignCI, V, PoisonValue::get(AssignedType));
+ updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
}
void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -1507,7 +1509,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
return;
setInsertPointSkippingPhis(B, I);
- Value *ExpectedElementVal = PoisonValue::get(ExpectedElementType);
+ Value *ExpectedElementVal = getNormalizedPoisonValue(ExpectedElementType);
MetadataAsValue *VMD = buildMD(ExpectedElementVal);
unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
bool FirstPtrCastOrAssignPtrType = true;
@@ -1653,7 +1655,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!ElemTy) {
ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
if (ElemTy) {
- GR->addDeducedElementType(CalledArg, ElemTy);
+ GR->addDeducedElementType(CalledArg, normalizeType(ElemTy));
} else {
for (User *U : CalledArg->users()) {
if (Instruction *Inst = dyn_cast<Instruction>(U)) {
@@ -1984,8 +1986,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
Type *ElemTy = GR->findDeducedElementType(Op);
buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
} else {
- CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
- {OpTy}, Op, Op, {}, B);
+ CallInst *AssignCI =
+ buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy},
+ getNormalizedPoisonValue(OpTy), Op, {}, B);
GR->addAssignPtrTypeInstr(Op, AssignCI);
}
}
@@ -2034,7 +2037,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
Type *OpTy = Op->getType();
Value *OpTyVal = Op;
if (OpTy->isTargetExtTy())
- OpTyVal = PoisonValue::get(OpTy);
+ OpTyVal = getNormalizedPoisonValue(OpTy);
CallInst *NewOp =
buildIntrWithMD(Intrinsic::spv_track_constant,
{OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
@@ -2045,7 +2048,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
SmallVector<Type *, 2> Types = {OpTy, OpTy};
SmallVector<Value *, 2> Args = {
- NewOp, buildMD(PoisonValue::get(OpElemTy)),
+ NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
B.getInt32(getPointerAddressSpace(OpTy))};
CallInst *PtrCasted =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -2178,7 +2181,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) {
DenseSet<std::pair<Value *, Value *>> VisitedSubst;
- updateAssignType(AssignCI, Arg, PoisonValue::get(ElemTy));
+ updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()),
VisitedSubst);
} else {
@@ -2232,7 +2235,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
continue;
if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
- updateAssignType(II, &F, PoisonValue::get(FPElemTy));
+ updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
break;
}
}
@@ -2256,7 +2259,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
for (Function *F : Worklist) {
SmallVector<Value *> Args;
for (const auto &Arg : F->args())
- Args.push_back(PoisonValue::get(Arg.getType()));
+ Args.push_back(getNormalizedPoisonValue(Arg.getType()));
IRB.CreateCall(F, Args);
}
IRB.CreateRetVoid();
@@ -2286,7 +2289,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
buildAssignPtr(B, ElemTy, Arg);
}
} else if (isa<Instruction>(Param)) {
- GR->addDeducedElementType(Param, ElemTy);
+ GR->addDeducedElementType(Param, normalizeType(ElemTy));
// insertAssignTypeIntrs() will complete buildAssignPtr()
} else {
B.SetInsertPoint(CI->getParent()
@@ -2302,6 +2305,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
if (!RefF || !isPointerTy(RefF->getReturnType()) ||
GR->findDeducedElementType(RefF))
continue;
+ ElemTy = normalizeType(ElemTy);
GR->addDeducedElementType(RefF, ElemTy);
GR->addReturnType(
RefF, TypedPointerType::get(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index fd48098257065a..ed7b2ef1becd95 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -383,6 +383,23 @@ inline const Type *unifyPtrType(const Type *Ty) {
return toTypedPointer(const_cast<Type *>(Ty));
}
+// Modify an LLVM type to conform with later transformations in IRTranslator:
+// a <1 x Type> vector is not a legal LLT vector type (IRTranslator represents
+// it as a scalar), so it is replaced by its element type. Other types and null
+// are returned unchanged.
+inline Type *normalizeType(Type *Ty) {
+ auto *FVTy = dyn_cast_if_present<FixedVectorType>(Ty);
+ if (!FVTy || FVTy->getNumElements() != 1)
+ return Ty;
+ return FVTy->getElementType(); // vector elements are never vectors
+}
+
+// Poison of normalizeType(Ty), for operands that only carry a type. Its type
+// differs from Ty for <1 x Type>, so don't use it where a Ty value is needed.
+inline PoisonValue *getNormalizedPoisonValue(Type *Ty) {
+ return PoisonValue::get(normalizeType(Ty));
+}
+
MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
#define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
More information about the llvm-commits
mailing list