[llvm] [SPIR-V] Fix tracking ptr null constants of builtin types in calls with mangling (PR #94263)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 10:24:19 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Michal Paszkowski (michalpaszkowski)
<details>
<summary>Changes</summary>
This change makes sure that ptr null (constant) arguments of builtin
function calls are assigned proper builtin types if such are deduced
from mangled names.
Two tests demonstrating the expected behavior (matching the SPIR-V
Translator) are added.
WIP:
- The test builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
  is currently failing and requires additional work.
- The processInstrAfterVisit method must be simplified; the
  TrackConstants flag can then be removed.
---
Full diff: https://github.com/llvm/llvm-project/pull/94263.diff
8 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+14-3)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+35-8)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+3-3)
- (added) llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll (+13)
- (added) llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll (+13)
- (added) llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll (+15)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f4daab7d06eb5..7d4d9801c7ce2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -211,12 +211,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}
- // In case OriginalArgType is of untyped pointer type, there are three
+ // In case OriginalArgType is of untyped pointer type, there are four
// possibilities:
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
// intrinsic assigning a TargetExtType.
- // 3) This is a pointer, try to retrieve pointer element type from a
+ // 3) This is an OpenCL/SPIR-V builtin type if the mangled function name
+ // contains type information (the Arg's function is a builtin, has no
+ // body).
+ // 4) This is a pointer, try to retrieve pointer element type from a
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
if (hasPointeeTypeAttr(Arg)) {
@@ -255,6 +258,14 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
}
+ std::string DemangledFuncName = demangleBuiltinCall(F.getName());
+ if (!DemangledFuncName.empty()) {
+ Type *BuiltinType = SPIRV::parseBuiltinCallArgumentBaseType(
+ DemangledFuncName, ArgIdx, F.getContext());
+ if (BuiltinType && BuiltinType->isTargetExtTy())
+ return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
+ }
+
// Replace PointerType with TypedPointerType to be able to map SPIR-V types to
// LLVM types in a consistent manner
if (isUntypedPointerTy(OriginalArgType)) {
@@ -509,7 +520,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// globally later.
if (Info.Callee.isGlobal()) {
std::string FuncName = Info.Callee.getGlobal()->getName().str();
- DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
+ DemangledName = demangleBuiltinCall(FuncName);
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index ffbd1e17bad5e..7e9347fa2fbc9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -851,6 +851,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
I->setOperand(OperandToReplace, PtrCastI);
}
+static bool inline isDirectNonIntrinsicCall(CallInst *CI) {
+ return CI && !CI->isIndirectCall() && !CI->isInlineAsm() &&
+ CI->getCalledFunction() && !CI->getCalledFunction()->isIntrinsic();
+}
+
void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
IRBuilder<> &B) {
// Handle basic instructions:
@@ -874,14 +879,12 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
// Handle calls to builtins (non-intrinsics):
CallInst *CI = dyn_cast<CallInst>(I);
- if (!CI || CI->isIndirectCall() || CI->isInlineAsm() ||
- !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic())
+ if (!isDirectNonIntrinsicCall(CI))
return;
- // collect information about formal parameter types
- std::string DemangledName =
- getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
+ // Collect information about formal parameter types
Function *CalledF = CI->getCalledFunction();
+ std::string DemangledName = demangleBuiltinCall(CalledF->getName());
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
@@ -1195,12 +1198,21 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
I->replaceAllUsesWith(NewOp);
NewOp->setArgOperand(0, I);
}
+
+ auto *CI = dyn_cast<CallInst>(I);
+ bool IsCall = CI && isDirectNonIntrinsicCall(CI);
+ std::string DemangledCall =
+ IsCall ? demangleBuiltinCall(CI->getCalledFunction()->getName()) : "";
+
bool IsPhi = isa<PHINode>(I), BPrepared = false;
+
for (const auto &Op : I->operands()) {
if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) ||
isa<PHINode>(I) || isa<SwitchInst>(I))
TrackConstants = false;
if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) {
+ Constant *OpConst = cast<Constant>(Op);
+
unsigned OpNo = Op.getOperandNo();
if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) ||
(II->paramHasAttr(OpNo, Attribute::ImmArg))))
@@ -1210,12 +1222,27 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
: B.SetInsertPoint(I);
BPrepared = true;
}
+
Value *OpTyVal = Op;
- if (Op->getType()->isTargetExtTy())
+ Value *ConstVal = Op;
+
+ if (Op->getType()->isTargetExtTy() ||
+ (Op->getType()->isPointerTy() && !DemangledCall.empty() &&
+ OpConst->isNullValue())) {
OpTyVal = PoisonValue::get(Op->getType());
+ }
+
+ if (Op->getType()->isPointerTy() && !DemangledCall.empty() &&
+ OpConst->isNullValue()) {
+ Type *DemangledTy = SPIRV::parseBuiltinCallArgumentBaseType(
+ DemangledCall, Op.getOperandNo(), I->getContext());
+ if (DemangledTy)
+ ConstVal = Constant::getNullValue(DemangledTy);
+ }
+
auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
- {Op->getType(), OpTyVal->getType()}, Op,
- OpTyVal, {}, B);
+ {Op->getType(), OpTyVal->getType()},
+ ConstVal, OpTyVal, {}, B);
I->setOperand(OpNo, NewOp);
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 624899600693a..3ce3b9b3b1e64 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -81,13 +81,14 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
}
GR->add(Const, &MF, SrcReg);
if (Const->getType()->isTargetExtTy()) {
- // remember association so that we can restore it when assign types
+ // Remember association so that we can restore it when assign types.
MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
TargetExtConstTypes[SrcMI] = Const->getType();
if (Const->isNullValue()) {
MachineIRBuilder MIB(MF);
+ MIB.setInsertPt(*MI.getParent(), MI);
SPIRVType *ExtType =
GR->getOrCreateSPIRVType(Const->getType(), MIB);
SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c20f3546a3e55..b8cd83d1ca975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -332,7 +332,7 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
Name == "__translate_sampler_initializer";
}
-std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
+std::string demangleBuiltinCall(StringRef Name) {
bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 33cb509dc4a59..a539e960f22a2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -90,9 +90,9 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
// Get type of i-th operand of the metadata node.
Type *getMDOperandAsType(const MDNode *N, unsigned I);
-// If OpenCL or SPIR-V builtin function name is recognized, return a demangled
-// name, otherwise return an empty string.
-std::string getOclOrSpirvBuiltinDemangledName(StringRef Name);
+// If SPIR-V builtin function name is recognized, return a demangled name,
+// otherwise return an empty string.
+std::string demangleBuiltinCall(StringRef Name);
// Check if a string contains a builtin prefix.
bool hasBuiltinTypePrefix(StringRef Name);
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
new file mode 100644
index 0000000000000..103fab2838607
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
@@ -0,0 +1,13 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]]
+
+define spir_kernel void @foo() {
+ %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr null, ptr null, i64 1, i64 1, ptr null)
+ ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll
new file mode 100644
index 0000000000000..51d094e398216
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll
@@ -0,0 +1,13 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]]
+
+define spir_kernel void @foo(ptr %a, ptr %b) {
+ %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null)
+ ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll
new file mode 100644
index 0000000000000..ff386affdc4c5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll
@@ -0,0 +1,15 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#PTR_INT8:]] = OpTypePointer Function %[[#INT8]]
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#FUNC_TY:]] = OpTypeFunction %[[#]] %[[#PTR_INT8]] %[[#PTR_INT8]] %[[#]] %[[#]] %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunction %[[#]] None %[[#FUNC_TY]]
+
+define spir_kernel void @foo(ptr %a, ptr %b) {
+ %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null)
+ ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)
``````````
</details>
https://github.com/llvm/llvm-project/pull/94263
More information about the llvm-commits
mailing list