[llvm] [SPIRV] Improve builtins matching and type inference in SPIR-V Backend (PR #89948)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 25 06:26:11 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/89948
>From c5102cb39cf8b154b81648a09c8d0584b989af38 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 24 Apr 2024 09:20:26 -0700
Subject: [PATCH 1/2] Improve builtins matching and type inference
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 6 +-
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 87 ++++++++++++++-----
llvm/test/CodeGen/SPIRV/printf.ll | 40 +++++++++
3 files changed, 109 insertions(+), 24 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/printf.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4b07d7e61fa113..64227ae062dfde 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -189,6 +189,10 @@ lookupBuiltin(StringRef DemangledCall,
std::string BuiltinName =
DemangledCall.substr(0, DemangledCall.find('(')).str();
+ // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
+ if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
+ BuiltinName = BuiltinName.substr(12);
+
// Check if the extracted name contains type information between angle
// brackets. If so, the builtin is an instantiated template - needs to have
// the information after angle brackets and return type removed.
@@ -2306,7 +2310,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
// parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
// base types.
if (TypeStr.ends_with("*"))
- TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+ TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
Ctx);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 472bc8638c9af1..204a662240e54f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -98,6 +98,8 @@ class SPIRVEmitIntrinsics
return B.CreateIntrinsic(IntrID, {Types}, Args);
}
+ void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+
void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -111,6 +113,7 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
+ void processParamTypesByFunHeader(Function *F, IRBuilder<> &B);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
@@ -194,6 +197,17 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
false);
}
+void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
+ Value *Arg) {
+ CallInst *AssignPtrTyCI =
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+ Constant::getNullValue(ElemTy), Arg,
+ {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+ GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+ GR->addDeducedElementType(Arg, ElemTy);
+ AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
+}
+
// Set element pointer type to the given value of ValueTy and tries to
// specify this type further (recursively) by Operand value, if needed.
Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
@@ -232,6 +246,19 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
return nullptr;
}
+// Implements what we know in advance about intrinsics and builtin calls
+// TODO: consider feasibility of this particular case to be generalized by
+// encoding knowledge about intrinsics and builtin calls by corresponding
+// specification rules
+static Type *getPointeeTypeByCallInst(StringRef DemangledName,
+ Function *CalledF, unsigned OpIdx) {
+ if ((DemangledName.starts_with("__spirv_ocl_printf(") ||
+ DemangledName.starts_with("printf(")) &&
+ OpIdx == 0)
+ return IntegerType::getInt8Ty(CalledF->getContext());
+ return nullptr;
+}
+
// Deduce and return a successfully deduced Type of the Instruction,
// or nullptr otherwise.
Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -795,6 +822,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
return;
// collect information about formal parameter types
+ std::string DemangledName =
+ getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
Function *CalledF = CI->getCalledFunction();
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
@@ -811,10 +840,15 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!ElemTy && hasPointeeTypeAttr(CalledArg))
ElemTy = getPointeeTypeByAttr(CalledArg);
if (!ElemTy) {
- for (User *U : CalledArg->users()) {
- if (Instruction *Inst = dyn_cast<Instruction>(U)) {
- if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
- break;
+ ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
+ if (ElemTy) {
+ GR->addDeducedElementType(CalledArg, ElemTy);
+ } else {
+ for (User *U : CalledArg->users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+ if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+ break;
+ }
}
}
}
@@ -823,8 +857,6 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
}
}
- std::string DemangledName =
- getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
if (DemangledName.empty() && !HaveTypes)
return;
@@ -835,8 +867,14 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
continue;
// Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
- if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
- continue;
+ if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) {
+ // However, we may have assumptions about the formal argument's type and
+ // may have a need to insert a ptr cast for the actual parameter of this
+ // call.
+ Argument *CalledArg = CalledF->getArg(OpIdx);
+ if (!GR->findDeducedElementType(CalledArg))
+ continue;
+ }
Type *ExpectedType =
OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
@@ -1179,28 +1217,29 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
return nullptr;
}
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
+ IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (!isUntypedPointerTy(Arg->getType()))
continue;
+ Type *ElemTy = GR->findDeducedElementType(Arg);
+ if (!ElemTy && hasPointeeTypeAttr(Arg) &&
+ (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
+ buildAssignPtr(B, ElemTy, Arg);
+ }
+}
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+ B.SetInsertPointPastAllocas(F);
+ for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+ Argument *Arg = F->getArg(OpIdx);
+ if (!isUntypedPointerTy(Arg->getType()))
+ continue;
Type *ElemTy = GR->findDeducedElementType(Arg);
- if (!ElemTy) {
- if (hasPointeeTypeAttr(Arg) &&
- (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
- GR->addDeducedElementType(Arg, ElemTy);
- } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
- CallInst *AssignPtrTyCI = buildIntrWithMD(
- Intrinsic::spv_assign_ptr_type, {Arg->getType()},
- Constant::getNullValue(ElemTy), Arg,
- {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
- GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
- GR->addDeducedElementType(Arg, ElemTy);
- AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
- }
- }
+ if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr)
+ buildAssignPtr(B, ElemTy, Arg);
}
}
@@ -1217,6 +1256,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
AggrConstTypes.clear();
AggrStores.clear();
+ processParamTypesByFunHeader(F, B);
+
// StoreInst's operand type can be changed during the next transformations,
// so we need to store it in the set. Also store already transformed types.
for (auto &I : instructions(Func)) {
diff --git a/llvm/test/CodeGen/SPIRV/printf.ll b/llvm/test/CodeGen/SPIRV/printf.ll
new file mode 100644
index 00000000000000..483fc1f244e57c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/printf.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#ExtImport:]] = OpExtInstImport "OpenCL.std"
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#CharPtr:]] = OpTypePointer UniformConstant %[[#Char]]
+; CHECK: %[[#GV:]] = OpVariable %[[#]] UniformConstant %[[#]]
+; CHECK: OpFunction
+; CHECK: %[[#Arg1:]] = OpFunctionParameter
+; CHECK: %[[#Arg2:]] = OpFunctionParameter
+; CHECK: %[[#CastedGV:]] = OpBitcast %[[#CharPtr]] %[[#GV]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst]]
+; CHECK-NEXT: %[[#CastedArg2:]] = OpBitcast %[[#CharPtr]] %[[#Arg2]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK: OpFunctionEnd
+
+%struct = type { [6 x i8] }
+
+ at FmtStr = internal addrspace(2) constant [6 x i8] c"c=%c\0A\00", align 1
+
+define spir_kernel void @foo(ptr addrspace(2) %_arg_fmt1, ptr addrspace(2) byval(%struct) %_arg_fmt2) {
+entry:
+ %r1 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+ %r2 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+ %r3 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+ %r4 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+ %r5 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+ %r6 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+ ret void
+}
+
+declare dso_local spir_func i32 @_Z6printfPU3AS2Kcz(ptr addrspace(2), ...)
+declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2), ...)
>From d0ede5c5dd576cca8d4585cc231d3b68dad477c2 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 25 Apr 2024 06:25:58 -0700
Subject: [PATCH 2/2] fix target ext type constants; fix
OpGroupAsyncCopy/OpGroupWaitEvents generation
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 7 ++++
llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 4 +-
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 10 +++--
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 10 +++++
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 40 +++++++++++++------
.../SPIRV/transcoding/spirv-event-null.ll | 33 +++++++++++++++
6 files changed, 87 insertions(+), 17 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 64227ae062dfde..7439d0fefa9800 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2012,6 +2012,13 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
unsigned Opcode =
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+ bool IsSet = Opcode == SPIRV::OpGroupAsyncCopy;
+ Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+ if (Call->isSpirvOp())
+ return buildOpFromWrapper(MIRBuilder, Opcode, Call,
+ IsSet ? TypeReg : Register(0));
+
auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
switch (Opcode) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 660000fb548d79..564028547821ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -585,9 +585,9 @@ defm : DemangledNativeBuiltin<"__spirv_SpecConstantComposite", OpenCL_std, SpecC
// Async Copy and Prefetch builtin records:
defm : DemangledNativeBuiltin<"async_work_group_copy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>;
-defm : DemangledNativeBuiltin<"__spirv_GroupAsyncCopy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>;
+defm : DemangledNativeBuiltin<"__spirv_GroupAsyncCopy", OpenCL_std, AsyncCopy, 6, 6, OpGroupAsyncCopy>;
defm : DemangledNativeBuiltin<"wait_group_events", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>;
-defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>;
+defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, 3, 3, OpGroupWaitEvents>;
// Load and store builtin records:
defm : DemangledNativeBuiltin<"__spirv_Load", OpenCL_std, LoadStore, 1, 3, OpLoad>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 204a662240e54f..0d539b1ed9a889 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1140,9 +1140,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
(II->paramHasAttr(OpNo, Attribute::ImmArg))))
continue;
B.SetInsertPoint(I);
- auto *NewOp =
- buildIntrWithMD(Intrinsic::spv_track_constant,
- {Op->getType(), Op->getType()}, Op, Op, {}, B);
+ Value *OpTyVal = Op;
+ if (Op->getType()->isTargetExtTy())
+ OpTyVal = Constant::getNullValue(
+ IntegerType::get(I->getContext(), GR->getPointerSize()));
+ auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
+ {Op->getType(), OpTyVal->getType()}, Op,
+ OpTyVal, {}, B);
I->setOperand(OpNo, NewOp);
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index b8296c3f6eeaee..96b4a570a26b1d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -314,6 +314,16 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpConstantI: {
+ SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
+ if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
+ MI.getOperand(2).getImm() == 0) {
+ // Validate the null constant of a target extension type
+ MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
+ for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
+ MI.removeOperand(i);
+ }
+ } break;
}
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 9ee0b38d22332a..84508fb5fe09eb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -38,7 +38,9 @@ class SPIRVPreLegalizer : public MachineFunctionPass {
};
} // namespace
-static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
+static void
+addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
MachineRegisterInfo &MRI = MF.getRegInfo();
DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
@@ -47,6 +49,7 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
continue;
ToErase.push_back(&MI);
+ Register SrcReg = MI.getOperand(2).getReg();
auto *Const =
cast<Constant>(cast<ConstantAsMetadata>(
MI.getOperand(3).getMetadata()->getOperand(0))
@@ -54,14 +57,14 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
if (auto *GV = dyn_cast<GlobalValue>(Const)) {
Register Reg = GR->find(GV, &MF);
if (!Reg.isValid())
- GR->add(GV, &MF, MI.getOperand(2).getReg());
+ GR->add(GV, &MF, SrcReg);
else
RegsAlreadyAddedToDT[&MI] = Reg;
} else {
Register Reg = GR->find(Const, &MF);
if (!Reg.isValid()) {
if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
- auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
+ auto *BuildVec = MRI.getVRegDef(SrcReg);
assert(BuildVec &&
BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
@@ -75,7 +78,13 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
BuildVec->getOperand(1 + i).setReg(ElemReg);
}
}
- GR->add(Const, &MF, MI.getOperand(2).getReg());
+ GR->add(Const, &MF, SrcReg);
+ if (Const->getType()->isTargetExtTy()) {
+ // remember association so that we can restore it when assigning types
+ MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
+ if (SrcMI && SrcMI->getOpcode() == TargetOpcode::G_CONSTANT)
+ TargetExtConstTypes[SrcMI] = Const->getType();
+ }
} else {
RegsAlreadyAddedToDT[&MI] = Reg;
// This MI is unused and will be removed. If the MI uses
@@ -364,8 +373,10 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
}
} // namespace llvm
-static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+static void
+generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB,
+ DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
@@ -422,11 +433,14 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
continue;
}
Type *Ty = nullptr;
- if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
- Ty = MI.getOperand(1).getCImm()->getType();
- else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
+ if (MI.getOpcode() == TargetOpcode::G_CONSTANT) {
+ auto TargetExtIt = TargetExtConstTypes.find(&MI);
+ Ty = TargetExtIt == TargetExtConstTypes.end()
+ ? MI.getOperand(1).getCImm()->getType()
+ : TargetExtIt->second;
+ } else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT) {
Ty = MI.getOperand(1).getFPImm()->getType();
- else {
+ } else {
assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
Type *ElemTy = nullptr;
MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
@@ -616,10 +630,12 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
- addConstantsToTrack(MF, GR);
+ // a registry of target extension constants
+ DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
+ addConstantsToTrack(MF, GR, TargetExtConstTypes);
foldConstantsIntoIntrinsics(MF);
insertBitcasts(MF, GR, MIB);
- generateAssignInstrs(MF, GR, MIB);
+ generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
processSwitches(MF, GR, MIB);
processInstrsWithTypeFolding(MF, GR, MIB);
removeImplicitFallthroughs(MF, MIB);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
new file mode 100644
index 00000000000000..fe0d96f2773ec6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TyEvent:]] = OpTypeEvent
+; CHECK-DAG: %[[#TyStruct:]] = OpTypeStruct %[[#TyEvent]]
+; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
+; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
+; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+; CHECK: OpFunction
+; CHECK: OpFunctionParameter
+; CHECK: %[[#Src:]] = OpFunctionParameter
+; CHECK: OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#EventVar:]] = OpVariable %[[#TyEventPtr]] Function
+; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
+; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+
+%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+entry:
+ %var = alloca %"class.sycl::_V1::device_event"
+ %dev_event.i.sroa.0 = alloca target("spirv.Event")
+ %add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
+ %call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+ store target("spirv.Event") %call3.i, ptr %dev_event.i.sroa.0
+ ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
More information about the llvm-commits
mailing list