[llvm] [SPIR-V] Improve type inference, fix mismatched machine function context (PR #88254)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 10 07:18:15 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/88254
>From 617bd2d5e0f2ddf1a9fb7e6a155ef794bf13edc5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 03:35:26 -0700
Subject: [PATCH 1/3] improve type inference, fix mismatched machine function
context
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 47 +++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 9 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 8 +++-
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 23 +++++----
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVUtils.h | 7 +++
6 files changed, 83 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..8113de6d7fd181 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
Type *deduceNestedTypeHelper(User *U, Type *Ty,
std::unordered_set<Value *> &Visited);
+ // Deduce element types for untyped-pointer operands of I from the element
+ // type already deduced for I itself, when possible (currently only PHI nodes
+ // are handled). Each newly deduced operand is recorded in Collected.
+ void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
return IntegerType::getInt8Ty(I->getContext());
}
+// Propagate the element type already deduced for I to those operands of I
+// whose element type is still unknown. Only PHI nodes are handled for now:
+// every incoming value that is an untyped pointer with no deduced element type
+// is given I's deduced element type, both in GR and in Collected.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Value *I, DenseMap<Value *, Type *> &Collected) {
+  Type *KnownTy = GR->findDeducedElementType(I);
+  if (!KnownTy)
+    return;
+  auto *Phi = dyn_cast<PHINode>(I);
+  if (!Phi)
+    return;
+
+  for (Value *Incoming : Phi->incoming_values()) {
+    // Skip operands that are not untyped pointers, as well as operands whose
+    // element type has already been deduced.
+    if (!isUntypedPointerTy(Incoming->getType()) ||
+        GR->findDeducedElementType(Incoming))
+      continue;
+    Collected[Incoming] = KnownTy;
+    GR->addDeducedElementType(Incoming, KnownTy);
+  }
+}
+
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
@@ -1126,6 +1151,28 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
processInstrAfterVisit(I, B);
}
+ for (auto &I : instructions(Func)) {
+ Type *ITy = I.getType();
+ if (!isPointerTy(ITy))
+ continue;
+ DenseMap<Value *, Type *> CollectedTys;
+ deduceOperandElementType(&I, CollectedTys);
+ if (CollectedTys.size() == 0)
+ continue;
+ for (const auto &Rec : CollectedTys) {
+ if (!Rec.first->use_empty()) {
+ Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
+ if (!User)
+ continue;
+ Type *OpTy = Rec.first->getType();
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+ UndefValue::get(Rec.second), Rec.first,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+ }
+ }
+ }
+
// check if function parameter types are set
if (!F->isIntrinsic())
processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..bd14da0ecc557b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -589,7 +589,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -782,8 +783,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return SpirvType;
}
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
- auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF) const {
+ auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..4dcc66f741edd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -273,8 +273,12 @@ class SPIRVGlobalRegistry {
SPIRV::AccessQualifier::ReadWrite);
// Return the SPIR-V type instruction corresponding to the given VReg, or
- // nullptr if no such type instruction exists.
- SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+ // nullptr if no such type instruction exists. The optional second argument
+ // MF selects the machine function whose VReg-to-type map is searched; when
+ // MF is null, the current machine function is used. This allows looking up
+ // registers of another machine function without switching the "current" one.
+ SPIRVType *getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF = nullptr) const;
// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
MachineInstr &I, unsigned OpIdx,
SPIRVType *ResType, const Type *ResTy = nullptr) {
+ // Get operand type
+ MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+ Register OpTypeReg =
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
? TypeInst->getOperand(1).getReg()
- : OpReg);
+ : OpReg;
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
- SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ // Get operand's pointee type
+ Register ElemTypeReg = OpType->getOperand(2).getReg();
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
if (!ElemType)
return;
- bool IsSameMF =
- ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+ // Check if we need a bitcast to make a statement valid
+ bool IsSameMF = MF == ResType->getParent()->getParent();
bool IsEqualTypes = IsSameMF ? ElemType == ResType
: GR.getTypeForSPIRVType(ElemType) == ResTy;
if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
SPIRVType *DefElemType =
DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
- ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+ ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+ DefPtrType->getParent()->getParent())
: nullptr;
if (DefElemType) {
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
// with a processed definition. Return Function pointer if it's a forward
// call (ahead of definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
+ MachineRegisterInfo *CallMRI,
SPIRVGlobalRegistry &GR,
MachineInstr &FunCall) {
const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
if (!FunDef)
return F;
- validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+ MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+ validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
return nullptr;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c87c1293c622fc..07ce9d9078de27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
}
Type *getMDOperandAsType(const MDNode *N, unsigned I) {
- return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ // Return the type of the value wrapped by N's I-th operand. An untyped
+ // pointer type is replaced with a typed pointer to i8 in the same address
+ // space (see toTypedPointer); any other type is returned as-is.
+ Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ return toTypedPointer(ElementTy, N->getContext());
}
// The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
}
+// Return a typed pointer to i8 in the same address space when Ty is an
+// untyped pointer (as reported by isUntypedPointerTy); otherwise return Ty
+// unchanged.
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+  if (!isUntypedPointerTy(Ty))
+    return Ty;
+  return TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+                               getPointerAddressSpace(Ty));
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
>From b07b588de5c9f12e1383648c1959acfb3ec56c9a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 04:33:58 -0700
Subject: [PATCH 2/3] re-format code
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 21 ++++++++-----------
1 file changed, 9 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8113de6d7fd181..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1157,19 +1157,16 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
continue;
DenseMap<Value *, Type *> CollectedTys;
deduceOperandElementType(&I, CollectedTys);
- if (CollectedTys.size() == 0)
- continue;
+ Instruction *User;
for (const auto &Rec : CollectedTys) {
- if (!Rec.first->use_empty()) {
- Instruction *User = dyn_cast<Instruction>(Rec.first->use_begin()->get());
- if (!User)
- continue;
- Type *OpTy = Rec.first->getType();
- setInsertPointSkippingPhis(B, User->getNextNode());
- buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
- UndefValue::get(Rec.second), Rec.first,
- {B.getInt32(getPointerAddressSpace(OpTy))}, B);
- }
+ if (Rec.first->use_empty() ||
+ !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+ continue;
+ Type *OpTy = Rec.first->getType();
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+ UndefValue::get(Rec.second), Rec.first,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
}
}
>From 37bc8eed9716ac8ccd3bcebb2c059134d7c89174 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 10 Apr 2024 07:18:01 -0700
Subject: [PATCH 3/3] fix crash when OpConstantComposite does not reuse an
existing constant
---
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 13 ++++++++---
llvm/test/CodeGen/SPIRV/const-composite.ll | 26 +++++++++++++++++++++
2 files changed, 36 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/const-composite.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7e155a36aadbc4..2c964595fc39e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
assert(BuildVec &&
BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
- for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
- GR->add(ConstVec->getElementAsConstant(i), &MF,
- BuildVec->getOperand(1 + i).getReg());
+ for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
+ // Ensure that OpConstantComposite reuses a constant when it's
+ // already created and available in the same machine function.
+ Constant *ElemConst = ConstVec->getElementAsConstant(i);
+ Register ElemReg = GR->find(ElemConst, &MF);
+ if (!ElemReg.isValid())
+ GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
+ else
+ BuildVec->getOperand(1 + i).setReg(ElemReg);
+ }
}
GR->add(Const, &MF, MI.getOperand(2).getReg());
} else {
diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll
new file mode 100644
index 00000000000000..4e304bb9516702
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/const-composite.ll
@@ -0,0 +1,26 @@
+; This test ensures that OpConstantComposite reuses a constant that already
+; exists in the same machine function. Here the constant `1` is first created
+; implicitly as the length of the `[1 x i64]` array in the byval argument type
+; of `foo`, and is then reused as an element of the composite passed to `bar`.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0
+; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1
+; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1]]
+; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0
+; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]]
+
+%struct = type { [1 x i64] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+entry:
+ call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+ ret void
+}
+
+define spir_func void @bar(<2 x i32> noundef) {
+entry:
+ ret void
+}
More information about the llvm-commits
mailing list