[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 19 12:55:28 PST 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/116609
>From d89bb33f0adb8d98fe668765353f4d79ec64dbd8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 04:47:27 -0800
Subject: [PATCH 1/3] Improve type inference: return values in call lowering
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 18 +-----
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 25 +++++++++
.../SPIRV/pointers/builtin-ret-reg-type.ll | 55 +++++++++++++++++++
.../SPIRV/transcoding/OpGenericCastToPtr.ll | 2 -
4 files changed, 83 insertions(+), 17 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..e34f6c3c282750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2517,23 +2517,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
SPIRVGlobalRegistry *GR) {
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
- // SPIR-V type and return register.
- Register ReturnRegister = OrigRet;
- SPIRVType *ReturnType = nullptr;
- if (OrigRetTy && !OrigRetTy->isVoidTy()) {
- ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
- if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
- MIRBuilder.getMRI()->setRegClass(ReturnRegister,
- GR->getRegClass(ReturnType));
- } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
- ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
- MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
- ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
- }
-
// Lookup the builtin in the TableGen records.
+ SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+ assert(SpvType && "Inconsistent return register: expected valid type info");
std::unique_ptr<const IncomingCall> Call =
- lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+ lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
if (!Call) {
LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..a7b6b0efa99551 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,31 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (isFunctionDecl && !DemangledName.empty() &&
(canUseGLSL || canUseOpenCL)) {
+ if (ResVReg.isValid()) {
+ if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+ const Type *RetTy = OrigRetTy;
+ if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+ const Value *OrigValue = Info.OrigRet.OrigValue;
+ if (!OrigValue)
+ OrigValue = Info.CB;
+ if (OrigValue)
+ if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+ RetTy =
+ TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+ }
+ SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
+ GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+ if (!MRI->getRegClassOrNull(ResVReg)) {
+ MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
+ MRI->setType(ResVReg, GR->getRegType(SpvType));
+ }
+ }
+ } else {
+ SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
+ ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+ MRI->setType(ResVReg, GR->getRegType(SpvType));
+ GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+ }
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
new file mode 100644
index 00000000000000..afa97ccfc0a69c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers which were
+; used as return values in call lowering. Pass criterion is that spirv-val considers output valid.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpPhi %[[#]] %[[#Ptr:]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#Ptr]] = OpPtrAccessChain %[[#]] %[[#]] %[[#]]
+
+
+%t_half = type { half }
+%t_i17 = type { [17 x i32] }
+%t_h17 = type { [17 x %t_half] }
+
+define internal spir_func void @foo(i64 %arrayinit.cur.add_4, half %r1, ptr addrspace(4) noundef align 8 dereferenceable_or_null(72) %this) {
+entry:
+ %r_3 = alloca %t_h17, align 8
+ %p_src = alloca %t_i17, align 4
+ %p_src4 = addrspacecast ptr %p_src to ptr addrspace(4)
+ %call_2 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %p_src4, i32 noundef 7)
+ br label %l_body
+
+l_body: ; preds = %l_body, %entry
+ %l_done = icmp eq i64 %arrayinit.cur.add_4, 34
+ br i1 %l_done, label %exit, label %l_body
+
+exit: ; preds = %l_body
+ %0 = addrspacecast ptr %call_2 to ptr addrspace(4)
+ %call_6 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %0, i32 noundef 7)
+ br label %for.cond_3
+
+for.cond_3: ; preds = %for.body_3, %exit
+ %lsr.iv1 = phi ptr [ %scevgep2, %for.body_3 ], [ %call_6, %exit ]
+ %lsr.iv = phi ptr [ %scevgep, %for.body_3 ], [ %r_3, %exit ]
+ %i.0_3 = phi i64 [ 0, %exit ], [ %inc_3, %for.body_3 ]
+ %cmp_3 = icmp ult i64 %i.0_3, 17
+ br i1 %cmp_3, label %for.body_3, label %exit2
+
+for.body_3: ; preds = %for.cond_3
+ %call2_5 = call spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef %r1, ptr noundef %lsr.iv1)
+ store half %call2_5, ptr %lsr.iv, align 2
+ %inc_3 = add nuw nsw i64 %i.0_3, 1
+ %scevgep = getelementptr i8, ptr %lsr.iv, i64 2
+ %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
+ br label %for.cond_3
+
+exit2: ; preds = %for.cond_3
+ ret void
+}
+
+declare dso_local spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef, ptr noundef)
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
index 54b2c786747768..2cba0f6ebd74be 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
@@ -2,9 +2,7 @@
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
-; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
-; CHECK-SPIRV-DAG: %[[#PrivateCharPtr:]] = OpTypePointer Function %[[#Char]]
; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
>From 6e23f694f2ab278bcb70c5ec1dd2629a1e0e9a8b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 07:50:54 -0800
Subject: [PATCH 2/3] add and use internal api call to create registers/assign
types; fix v-reg type/class assignments
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 10 ++--
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 12 +----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 2 +-
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 5 +-
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 9 +---
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 49 +++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVUtils.h | 16 ++++++
9 files changed, 79 insertions(+), 32 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e34f6c3c282750..bed34b83d2e546 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -447,12 +447,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR, LLT LowLevelType,
Register DestinationReg = Register(0)) {
- MachineRegisterInfo *MRI = MIRBuilder.getMRI();
- if (!DestinationReg.isValid()) {
- DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
- MRI->setType(DestinationReg, LLT::scalar(64));
- GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
- }
+ if (!DestinationReg.isValid())
+ DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
// TODO: consider using correct address space and alignment (p0 is canonical
// type for selection though).
MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2129,7 +2125,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
for (unsigned I = 0; I < LocalSizeNum; ++I) {
- Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+ Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
MRI->setType(Reg, LLType);
GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
auto GEPInst = MIRBuilder.buildIntrinsic(
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a7b6b0efa99551..3fdaa6aa3257ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -551,18 +551,10 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
RetTy =
TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
}
- SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
- GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
- if (!MRI->getRegClassOrNull(ResVReg)) {
- MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
- MRI->setType(ResVReg, GR->getRegType(SpvType));
- }
+ setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
}
} else {
- SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
- ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
- MRI->setType(ResVReg, GR->getRegType(SpvType));
- GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+ ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
}
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6f222883ee07de..4e539fcd6c9999 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -69,7 +69,7 @@ SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
Register VReg,
- MachineFunction &MF) {
+ const MachineFunction &MF) {
VRegToTypeMap[&MF][VReg] = SpirvType;
}
@@ -578,7 +578,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
if (!Res.isValid()) {
LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
- CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
+ CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 3bb86e8be69500..ff4b0ea8757fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -330,7 +330,7 @@ class SPIRVGlobalRegistry {
// In cases where the SPIR-V type is already known, this function can be
// used to map it to the given VReg via an ASSIGN_TYPE instruction.
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
- MachineFunction &MF);
+ const MachineFunction &MF);
// Either generate a new OpTypeXXX instruction or return an existing one
// corresponding to the given LLVM IR type.
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 59a1bf50b771b9..b53ea1f7edf4a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -111,8 +111,8 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
SPIRVGlobalRegistry &GR, MachineInstr &I,
Register OpReg, unsigned OpIdx,
SPIRVType *NewPtrType) {
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MachineIRBuilder MIB(I);
+ Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(NewPtrType))
@@ -121,8 +121,6 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
*STI.getRegBankInfo());
if (!Res)
report_fatal_error("insert validation bitcast: cannot constrain all uses");
- MRI->setRegClass(NewReg, &SPIRV::iIDRegClass);
- GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
I.getOperand(OpIdx).setReg(NewReg);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 460f0127d4ffcd..bd04b8c1c0b333 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -357,12 +357,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
verify(*ST.getInstrInfo());
}
-static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
SPIRVGlobalRegistry *GR) {
Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
- GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
+ MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
+ GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
.addDef(ConvReg)
.addUse(Reg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 11b9e4f6f6d17b..3373d8e24dab48 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -102,10 +102,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (!ResType) {
// There was no "assign type" actions, let's fix this now
ResType = ScalarType;
- MRI.setRegClass(ResVReg, &SPIRV::iIDRegClass);
- MRI.setType(ResVReg,
- LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
- GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
}
}
} else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
@@ -124,9 +121,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (!ResVType)
continue;
// Set type & class
- MRI.setRegClass(ResVReg, GR->getRegClass(ResVType));
- MRI.setType(ResVReg, GR->getRegType(ResVType));
- GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+ setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
}
// If this is a simple operation that is to be reduced by TableGen
// definition we must apply some of pre-legalizer rules here
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aeb2c29f7b8618..7e5bb1990626ff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -13,6 +13,7 @@
#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
@@ -677,4 +678,52 @@ bool getVacantFunctionName(Module &M, std::string &Name) {
return false;
}
+// Record SpvType as the SPIR-V type of Reg. If Reg has no register class yet
+// (or Force is set), also set its register class and LLT type from SpvType.
+void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
MachineRegisterInfo *MRI, const MachineFunction &MF,
bool Force) {
+ GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+ if (!MRI->getRegClassOrNull(Reg) || Force) {
+ MRI->setRegClass(Reg, GR->getRegClass(SpvType));
+ MRI->setType(Reg, GR->getRegType(SpvType));
+ }
+}
+
+// Get or create the SPIR-V type for Ty and assign it to Reg. If Reg has no
+// register class yet (or Force is set), also set its class and LLT type from
+// that SPIR-V type.
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder, bool Force) {
+ setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+}
+
+// Create a virtual register whose class and LLT type are derived from SpvType,
+// and record SpvType as its SPIR-V type. Returns the new register.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
MachineRegisterInfo *MRI,
const MachineFunction &MF) {
+ Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+ MRI->setType(Reg, GR->getRegType(SpvType));
+ GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+ return Reg;
+}
+
+// Create a virtual register typed by SpvType (class, LLT type, SPIR-V type),
+// taking the MachineRegisterInfo and MachineFunction from MIRBuilder.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder) {
+ return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
MIRBuilder.getMF());
+}
+
+// Get or create the SPIR-V type for Ty, then create a virtual register of that
+// type (register class, LLT type and SPIR-V type all derived from it).
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder) {
+ return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
MIRBuilder);
+}
+
} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 298b0b93b0e4d2..7a292b52bd1d16 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -34,6 +34,7 @@ class Register;
class StringRef;
class SPIRVInstrInfo;
class SPIRVSubtarget;
+class SPIRVGlobalRegistry;
// This class implements a partial ordering visitor, which visits a cyclic graph
// in natural topological-like ordering. Topological ordering is not defined for
@@ -355,5 +356,20 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
#define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
bool getVacantFunctionName(Module &M, std::string &Name);
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIRBuilder, bool Force = false);
+void setRegClassType(Register Reg, const MachineInstr *SpvType,
+ SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
+ const MachineFunction &MF, bool Force = false);
+Register createVirtualRegister(const MachineInstr *SpvType,
+ SPIRVGlobalRegistry *GR,
+ MachineRegisterInfo *MRI,
+ const MachineFunction &MF);
+Register createVirtualRegister(const MachineInstr *SpvType,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIRBuilder);
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
>From c87bad4db5511940a3d061bb9f94b7b9b8fb1535 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 19 Nov 2024 12:55:15 -0800
Subject: [PATCH 3/3] improve type inference: change processing order,
calculate incomplete types, speed up postprocessing of types
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 164 ++++++++++++------
.../fp_two_calls.ll | 12 +-
.../CodeGen/SPIRV/pointers/phi-chain-types.ll | 82 +++++++++
3 files changed, 200 insertions(+), 58 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e6ef40e010dc20..c98c22641273ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -78,6 +78,11 @@ class SPIRVEmitIntrinsics
// a register of Instructions that don't have a complete type definition
DenseMap<Value *, unsigned> UncompleteTypeInfo;
SmallVector<Value *> PostprocessWorklist;
+ void addToUncompleteTypeInfo(Value *Op) { // records Op at most once
+ auto It = UncompleteTypeInfo.try_emplace(Op, PostprocessWorklist.size());
+ if (It.second) // newly recorded: its map value is the index pushed below
+ PostprocessWorklist.push_back(Op);
+ }
// well known result types of builtins
enum WellKnownTypes { Event };
@@ -105,8 +110,9 @@ class SPIRVEmitIntrinsics
bool UnknownElemTypeI8);
// deduce Types of operands of the Instruction if possible
- void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0,
- Type *AskTy = 0, CallInst *AssignCI = 0);
+ void deduceOperandElementType(Instruction *I,
+ const SmallPtrSet<Value *, 4> *AskOps = nullptr,
+ SmallPtrSet<Value *, 16> *Completed = nullptr);
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -145,12 +151,20 @@ class SPIRVEmitIntrinsics
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
+
+ bool deduceOperandElementTypeCalledFunction(
+ SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
+ SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy);
+ void deduceOperandElementTypeFunctionPointer(
+ CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+ Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed);
+
void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
CallInst *AssignCI);
void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
bool runOnFunction(Function &F);
- bool postprocessTypes();
+ bool postprocessTypes(Module &M);
bool processFunctionPointers(Module &M);
public:
@@ -286,11 +300,11 @@ void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
if (DeleteOld) {
unsigned Pos = It->second;
UncompleteTypeInfo.erase(Src);
- UncompleteTypeInfo[Dest] = Pos;
- PostprocessWorklist[Pos] = Dest;
+ auto It = UncompleteTypeInfo.try_emplace(Dest, Pos);
+ if (It.second)
+ PostprocessWorklist[Pos] = Dest;
} else {
- UncompleteTypeInfo[Dest] = PostprocessWorklist.size();
- PostprocessWorklist.push_back(Dest);
+ addToUncompleteTypeInfo(Dest);
}
}
@@ -455,10 +469,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
if (isUntypedPointerTy(RefTy)) {
if (!UnknownElemTypeI8)
return;
- if (auto *I = dyn_cast<Instruction>(Op)) {
- UncompleteTypeInfo[I] = PostprocessWorklist.size();
- PostprocessWorklist.push_back(I);
- }
+ addToUncompleteTypeInfo(Op);
}
Ty = RefTy;
}
@@ -661,10 +672,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
return Ty;
if (!UnknownElemTypeI8)
return nullptr;
- if (auto *Instr = dyn_cast<Instruction>(I)) {
- UncompleteTypeInfo[Instr] = PostprocessWorklist.size();
- PostprocessWorklist.push_back(Instr);
- }
+ addToUncompleteTypeInfo(I);
return IntegerType::getInt8Ty(I->getContext());
}
@@ -683,8 +691,7 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
// Try to deduce element type for a call base. Returns false if this is an
// indirect function invocation, and true otherwise.
-static bool deduceOperandElementTypeCalledFunction(
- SPIRVGlobalRegistry *GR, Instruction *I,
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
Function *CalledF = CI->getCalledFunction();
@@ -726,7 +733,7 @@ static bool deduceOperandElementTypeCalledFunction(
case SPIRV::OpAtomicUMax:
case SPIRV::OpAtomicSMin:
case SPIRV::OpAtomicSMax: {
- KnownElemTy = getAtomicElemTy(GR, I, Op);
+ KnownElemTy = getAtomicElemTy(GR, CI, Op);
if (!KnownElemTy)
return true;
Ops.push_back(std::make_pair(Op, 0));
@@ -738,32 +745,44 @@ static bool deduceOperandElementTypeCalledFunction(
}
// Try to deduce element type for a function pointer.
-static void deduceOperandElementTypeFunctionPointer(
- SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
- SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
+ CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+ Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed) {
Value *Op = CI->getCalledOperand();
if (!Op || !isPointerTy(Op->getType()))
return;
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
FunctionType *FTy = CI->getFunctionType();
- bool IsNewFTy = false;
+ bool IsNewFTy = false, IsUncomplete = false;
SmallVector<Type *, 4> ArgTys;
for (Value *Arg : CI->args()) {
Type *ArgTy = Arg->getType();
- if (ArgTy->isPointerTy())
+ if (ArgTy->isPointerTy()) {
if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
IsNewFTy = true;
ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+ if (UncompleteTypeInfo.contains(Arg))
+ IsUncomplete = true;
+ } else {
+ IsUncomplete = true;
}
+ }
ArgTys.push_back(ArgTy);
}
Type *RetTy = FTy->getReturnType();
- if (I->getType()->isPointerTy())
- if (Type *ElemTy = GR->findDeducedElementType(I)) {
+ if (CI->getType()->isPointerTy()) {
+ if (Type *ElemTy = GR->findDeducedElementType(CI)) {
IsNewFTy = true;
RetTy =
- TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
+ TypedPointerType::get(ElemTy, getPointerAddressSpace(CI->getType()));
+ if (UncompleteTypeInfo.contains(CI))
+ IsUncomplete = true;
+ } else {
+ IsUncomplete = true;
}
+ }
+ if (!Completed && IsUncomplete)
+ addToUncompleteTypeInfo(Op);
KnownElemTy =
IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
}
@@ -772,10 +791,9 @@ static void deduceOperandElementTypeFunctionPointer(
// tries to deduce them. If the Instruction has Pointer operands with known
// types which differ from expected, this function tries to insert a bitcast to
// resolve the issue.
-void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
- Instruction *AskOp,
- Type *AskTy,
- CallInst *AskCI) {
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+ Instruction *I, const SmallPtrSet<Value *, 4> *AskOps,
+ SmallPtrSet<Value *, 16> *Completed) {
SmallVector<std::pair<Value *, unsigned>> Ops;
Type *KnownElemTy = nullptr;
// look for known basic patterns of type inference
@@ -875,10 +893,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (!CI->isIndirectCall())
- deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
- KnownElemTy);
+ deduceOperandElementTypeCalledFunction(InstrSet, CI, Ops, KnownElemTy);
else if (HaveFunPtrs)
- deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
+ deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy, Completed);
}
// There is no enough info to deduce types or all is valid.
@@ -889,9 +906,19 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
IRBuilder<> B(Ctx);
for (auto &OpIt : Ops) {
Value *Op = OpIt.first;
- if (Op->use_empty() || (AskOp && Op != AskOp))
+ if (Op->use_empty())
continue;
- Type *Ty = AskOp ? AskTy : GR->findDeducedElementType(Op);
+ Type *AskTy = nullptr;
+ CallInst *AskCI = nullptr;
+ if (AskOps) {
+ auto It = AskOps->find(Op);
+ if (It == AskOps->end())
+ continue;
+ AskTy = GR->findDeducedElementType(Op);
+ AskCI = GR->findAssignPtrTypeInstr(Op);
+ assert(AskTy && AskCI);
+ }
+ Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
if (Ty == KnownElemTy)
continue;
Value *OpTyVal = PoisonValue::get(KnownElemTy);
@@ -899,6 +926,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
if (!Ty || AskTy || isUntypedPointerTy(Ty) ||
UncompleteTypeInfo.contains(Op)) {
GR->addDeducedElementType(Op, KnownElemTy);
+ // check if KnownElemTy is complete
+ if (!Completed && UncompleteTypeInfo.contains(I))
+ addToUncompleteTypeInfo(Op);
// check if there is existing Intrinsic::spv_assign_ptr_type instruction
CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op);
if (AssignCI == nullptr) {
@@ -910,6 +940,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
GR->addAssignPtrTypeInstr(Op, CI);
} else {
updateAssignType(AssignCI, Op, OpTyVal);
+ if (Completed)
+ Completed->insert(Op);
}
} else {
if (auto *OpI = dyn_cast<Instruction>(Op)) {
@@ -1878,6 +1910,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
for (auto &I : instructions(Func))
Worklist.push_back(&I);
+ // Pass forward: use operands to deduce instruction results.
for (auto &I : Worklist) {
// Don't emit intrinsincs for convergence intrinsics.
if (isConvergenceIntrinsic(I))
@@ -1894,9 +1927,17 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
insertAssignPtrTypeIntrs(I, B, true);
}
- for (auto &I : instructions(Func))
+ // Pass backward: use instruction results to specify/update/cast operands
+ // where needed.
+ for (auto &I : llvm::reverse(instructions(Func)))
deduceOperandElementType(&I);
+ // Pass forward for PHIs only: their operands do not necessarily precede the
+ // instruction in the order of `instructions(Func)`.
+ for (BasicBlock &BB : Func)
+ for (PHINode &Phi : BB.phis())
+ deduceOperandElementType(&Phi);
+
for (auto *I : Worklist) {
TrackConstants = true;
if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1938,16 +1979,19 @@ void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
}
// Try to deduce a better type for pointers to untyped ptr.
-bool SPIRVEmitIntrinsics::postprocessTypes() {
- bool Changed = false;
- if (!GR)
- return Changed;
+bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
+ if (!GR || UncompleteTypeInfo.size() == 0)
+ return false;
+
+ DenseMap<Value *, SmallPtrSet<Value *, 4>> ToProcess;
+ SmallPtrSet<Value *, 16> Completed;
for (auto IB = PostprocessWorklist.rbegin(), IE = PostprocessWorklist.rend();
IB != IE; ++IB) {
CallInst *AssignCI = GR->findAssignPtrTypeInstr(*IB);
Type *KnownTy = GR->findDeducedElementType(*IB);
- if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
+ if (!KnownTy || !AssignCI)
continue;
+ assert(AssignCI->getArgOperand(0) == *IB);
// Try to improve the type deduced after all Functions are processed.
if (auto *CI = dyn_cast<CallInst>(*IB)) {
if (Function *CalledF = CI->getCalledFunction()) {
@@ -1955,24 +1999,37 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
// Fix inconsistency between known type and function's return type.
if (RetElemTy && RetElemTy != KnownTy) {
replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
- Changed = true;
+ Completed.insert(CI);
continue;
}
}
}
- Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
- for (User *U : I->users()) {
+ Value *Op = AssignCI->getArgOperand(0);
+ for (User *U : Op->users()) {
Instruction *Inst = dyn_cast<Instruction>(U);
- if (!Inst || isa<IntrinsicInst>(Inst))
+ if (Inst && !isa<IntrinsicInst>(Inst))
+ ToProcess[Inst].insert(Op);
+ }
+ }
+ if (Completed.size() >= UncompleteTypeInfo.size())
+ return true;
+
+ for (auto &F : M) {
+ for (auto &I : llvm::reverse(instructions(F))) {
+ auto It = ToProcess.find(&I);
+ if (It == ToProcess.end())
continue;
- deduceOperandElementType(Inst, I, KnownTy, AssignCI);
- if (KnownTy != GR->findDeducedElementType(I)) {
- Changed = true;
- break;
- }
+ It->second.remove_if(
+ [&Completed](Value *V) { return Completed.contains(V); });
+ if (It->second.size() == 0)
+ continue;
+ deduceOperandElementType(&I, &It->second, &Completed);
+ if (Completed.size() >= UncompleteTypeInfo.size())
+ return true;
}
}
- return Changed;
+
+ return Completed.size() > 0;
}
bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
@@ -1983,17 +2040,16 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
for (auto &F : M)
Changed |= runOnFunction(F);
+ // Specify function parameters after all functions were processed.
for (auto &F : M) {
// check if function parameter types are set
if (!F.isDeclaration() && !F.isIntrinsic()) {
- const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
- GR = ST.getSPIRVGlobalRegistry();
IRBuilder<> B(F.getContext());
processParamTypes(&F, B);
}
}
- Changed |= postprocessTypes();
+ Changed |= postprocessTypes(M);
if (HaveFunPtrs)
Changed |= processFunctionPointers(M);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
index eb7b1dffaee501..621d06aa4aadee 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -1,4 +1,4 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-DAG: OpCapability Int8
@@ -15,10 +15,14 @@
; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
-; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
-; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrFp]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyUncompleteBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrUncompleteBar:.*]] = OpTypePointer Function %[[TyUncompleteBar]]
+; CHECK-DAG: %[[TyUncompleteFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrUncompleteBar]]
+; CHECK-DAG: %[[TyPtrUncompleteFp:.*]] = OpTypePointer Function %[[TyUncompleteFp]]
+; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrUncompleteFp]] %[[TyPtrInt8]]
; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]]
+; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrBar]]
+; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
; CHECK-DAG: %[[TyTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFp]] %[[TyPtrInt8]] %[[TyPtrBar]]
; CHECK: %[[test]] = OpFunction %[[TyVoid]] None %[[TyTest]]
; CHECK: %[[fp]] = OpFunctionParameter %[[TyPtrFp]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll
new file mode 100644
index 00000000000000..a9e79df259c4fb
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll
@@ -0,0 +1,82 @@
+; The goal of the test case is to ensure that correct types are applied to PHIs used as arguments of other PHIs.
+; Pass criteria: spirv-val (when available) considers the output valid, and the FileCheck patterns below match.
+
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
+; CHECK-DAG: OpName %[[#Foo:]] "foo"
+; CHECK-DAG: OpName %[[#FooVal1:]] "val1"
+; CHECK-DAG: OpName %[[#FooVal2:]] "val2"
+; CHECK-DAG: OpName %[[#FooVal3:]] "val3"
+; CHECK-DAG: OpName %[[#Bar:]] "bar"
+; CHECK-DAG: OpName %[[#BarVal1:]] "val1"
+; CHECK-DAG: OpName %[[#BarVal2:]] "val2"
+; CHECK-DAG: OpName %[[#BarVal3:]] "val3"
+
+; CHECK-DAG: %[[#Short:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#ShortGenPtr:]] = OpTypePointer Generic %[[#Short]]
+; CHECK-DAG: %[[#ShortWrkPtr:]] = OpTypePointer Workgroup %[[#Short]]
+; CHECK-DAG: %[[#G1:]] = OpVariable %[[#ShortWrkPtr]] Workgroup
+
+; CHECK: %[[#Foo:]] = OpFunction %[[#]] None %[[#]]
+; CHECK: %[[#FooArgP:]] = OpFunctionParameter %[[#ShortGenPtr]]
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: %[[#FooG1:]] = OpPtrCastToGeneric %[[#ShortGenPtr]] %[[#G1]]
+; CHECK: %[[#FooVal2]] = OpPhi %[[#ShortGenPtr]] %[[#FooArgP]] %[[#]] %[[#FooVal3]] %[[#]]
+; CHECK: %[[#FooVal1]] = OpPhi %[[#ShortGenPtr]] %[[#FooG1]] %[[#]] %[[#FooVal2]] %[[#]]
+; CHECK: %[[#FooVal3]] = OpLoad %[[#ShortGenPtr]] %[[#]]
+
+; CHECK: %[[#Bar:]] = OpFunction %[[#]] None %[[#]]
+; CHECK: %[[#BarArgP:]] = OpFunctionParameter %[[#ShortGenPtr]]
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: %[[#BarVal3]] = OpLoad %[[#ShortGenPtr]] %[[#]]
+; CHECK: %[[#BarG1:]] = OpPtrCastToGeneric %[[#ShortGenPtr]] %[[#G1]]
+; CHECK: %[[#BarVal1]] = OpPhi %[[#ShortGenPtr]] %[[#BarG1]] %[[#]] %[[#BarVal2]] %[[#]]
+; CHECK: %[[#BarVal2]] = OpPhi %[[#ShortGenPtr]] %[[#BarArgP]] %[[#]] %[[#BarVal3]] %[[#]]
+
+@G1 = internal addrspace(3) global i16 undef, align 8 ; i16 global in addrspace(3); its address is cast to addrspace(4) as a PHI incoming value
+@G2 = internal unnamed_addr addrspace(3) global ptr addrspace(4) undef, align 8 ; addrspace(3) slot holding an addrspace(4) pointer, loaded into %val3
+
+define spir_kernel void @foo(ptr addrspace(4) %p, i1 %f1, i1 %f2, i1 %f3) { ; PHI chain in which %val2 uses %val3, defined in a later block
+entry:
+ br label %l1
+
+l1:
+ br i1 %f1, label %l2, label %exit
+
+l2:
+ %val2 = phi ptr addrspace(4) [ %p, %l1 ], [ %val3, %l3 ] ; %val3 is defined textually later, in %l3 (back edge)
+ %val1 = phi ptr addrspace(4) [ addrspacecast (ptr addrspace(3) @G1 to ptr addrspace(4)), %l1 ], [ %val2, %l3 ] ; a PHI used as an incoming value of another PHI
+ br i1 %f2, label %l3, label %exit
+
+l3:
+ %val3 = load ptr addrspace(4), ptr addrspace(3) @G2, align 8 ; addrspace(4) pointer loaded from @G2
+ br i1 %f3, label %l2, label %exit
+
+exit:
+ ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(4) %p, i1 %f1, i1 %f2, i1 %f3) { ; same PHI chain as @foo, with PHIs and blocks in a different textual order
+entry:
+ %val3 = load ptr addrspace(4), ptr addrspace(3) @G2, align 8 ; here %val3 is defined before the PHIs that use it
+ br label %l1
+
+l3: ; placed textually before %l1 and %l2
+ br i1 %f3, label %l2, label %exit
+
+l1:
+ br i1 %f1, label %l2, label %exit
+
+l2:
+ %val1 = phi ptr addrspace(4) [ addrspacecast (ptr addrspace(3) @G1 to ptr addrspace(4)), %l1 ], [ %val2, %l3 ] ; %val2 is defined by the next PHI in this block
+ %val2 = phi ptr addrspace(4) [ %p, %l1 ], [ %val3, %l3 ]
+ br i1 %f2, label %l3, label %exit
+
+exit:
+ ret void
+}
More information about the llvm-commits
mailing list