[llvm] [SPIR-V] Improve type inference, fix mismatched machine function context (PR #88254)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 10 04:34:10 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/88254
>From 617bd2d5e0f2ddf1a9fb7e6a155ef794bf13edc5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 03:35:26 -0700
Subject: [PATCH 1/2] improve type inference, fix mismatched machine function
context
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 47 +++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 9 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 8 +++-
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 23 +++++----
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVUtils.h | 7 +++
6 files changed, 83 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..8113de6d7fd181 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
Type *deduceNestedTypeHelper(User *U, Type *Ty,
std::unordered_set<Value *> &Visited);
+ // deduce Types of operands of the Instruction if possible
+ void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
return IntegerType::getInt8Ty(I->getContext());
}
+// Deduce Types of operands of the Instruction if possible.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+ Value *I, DenseMap<Value *, Type *> &Collected) {
+ Type *KnownTy = GR->findDeducedElementType(I);
+ if (!KnownTy)
+ return;
+
+ // look for known basic patterns of type inference
+ if (auto *Ref = dyn_cast<PHINode>(I)) {
+ for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+ Value *Op = Ref->getIncomingValue(i);
+ if (!isUntypedPointerTy(Op->getType()))
+ continue;
+ Type *Ty = GR->findDeducedElementType(Op);
+ if (!Ty) {
+ Collected[Op] = KnownTy;
+ GR->addDeducedElementType(Op, KnownTy);
+ }
+ }
+ }
+}
+
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
@@ -1126,6 +1151,28 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
processInstrAfterVisit(I, B);
}
+ for (auto &I : instructions(Func)) {
+ Type *ITy = I.getType();
+ if (!isPointerTy(ITy))
+ continue;
+ DenseMap<Value *, Type *> CollectedTys;
+ deduceOperandElementType(&I, CollectedTys);
+ if (CollectedTys.size() == 0)
+ continue;
+ for (const auto &Rec : CollectedTys) {
+ if (!Rec.first->use_empty()) {
+ Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
+ if (!User)
+ continue;
+ Type *OpTy = Rec.first->getType();
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+ UndefValue::get(Rec.second), Rec.first,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+ }
+ }
+ }
+
// check if function parameter types are set
if (!F->isIntrinsic())
processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..bd14da0ecc557b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -589,7 +589,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -782,8 +783,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return SpirvType;
}
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
- auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF) const {
+ auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..4dcc66f741edd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -273,8 +273,12 @@ class SPIRVGlobalRegistry {
SPIRV::AccessQualifier::ReadWrite);
// Return the SPIR-V type instruction corresponding to the given VReg, or
- // nullptr if no such type instruction exists.
- SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+ // nullptr if no such type instruction exists. The second argument MF
+ // allows searching for the association in the context of a machine function
+ // other than the current one, without switching between different "current"
+ // machine functions.
+ SPIRVType *getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF = nullptr) const;
// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
MachineInstr &I, unsigned OpIdx,
SPIRVType *ResType, const Type *ResTy = nullptr) {
+ // Get operand type
+ MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+ Register OpTypeReg =
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
? TypeInst->getOperand(1).getReg()
- : OpReg);
+ : OpReg;
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
- SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ // Get operand's pointee type
+ Register ElemTypeReg = OpType->getOperand(2).getReg();
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
if (!ElemType)
return;
- bool IsSameMF =
- ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+ // Check if we need a bitcast to make a statement valid
+ bool IsSameMF = MF == ResType->getParent()->getParent();
bool IsEqualTypes = IsSameMF ? ElemType == ResType
: GR.getTypeForSPIRVType(ElemType) == ResTy;
if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
SPIRVType *DefElemType =
DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
- ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+ ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+ DefPtrType->getParent()->getParent())
: nullptr;
if (DefElemType) {
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
// with a processed definition. Return Function pointer if it's a forward
// call (ahead of definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
+ MachineRegisterInfo *CallMRI,
SPIRVGlobalRegistry &GR,
MachineInstr &FunCall) {
const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
if (!FunDef)
return F;
- validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+ MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+ validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
return nullptr;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c87c1293c622fc..07ce9d9078de27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
}
Type *getMDOperandAsType(const MDNode *N, unsigned I) {
- return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ return toTypedPointer(ElementTy, N->getContext());
}
// The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
}
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+ return isUntypedPointerTy(Ty)
+ ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+ getPointerAddressSpace(Ty))
+ : Ty;
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
>From b07b588de5c9f12e1383648c1959acfb3ec56c9a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 04:33:58 -0700
Subject: [PATCH 2/2] re-format code
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 21 ++++++++-----------
1 file changed, 9 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8113de6d7fd181..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1157,19 +1157,16 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
continue;
DenseMap<Value *, Type *> CollectedTys;
deduceOperandElementType(&I, CollectedTys);
- if (CollectedTys.size() == 0)
- continue;
+ Instruction *User;
for (const auto &Rec : CollectedTys) {
- if (!Rec.first->use_empty()) {
- Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
- if (!User)
- continue;
- Type *OpTy = Rec.first->getType();
- setInsertPointSkippingPhis(B, User->getNextNode());
- buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
- UndefValue::get(Rec.second), Rec.first,
- {B.getInt32(getPointerAddressSpace(OpTy))}, B);
- }
+ if (Rec.first->use_empty() ||
+ !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+ continue;
+ Type *OpTy = Rec.first->getType();
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+ UndefValue::get(Rec.second), Rec.first,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
}
}
More information about the llvm-commits
mailing list