[llvm] [SPIRV] Improve builtins matching and type inference in SPIR-V Backend (PR #89948)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 24 09:55:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR improves builtin matching and type inference in the SPIR-V Backend. The model test case is a printf call from OpenCL.std, which has several features that allow a broader look at builtin support and type inference:
(1) a call in the "SPIR-V friendly" style (prefixed by __spirv_ocl_);
(2) a restricted type of the first argument.
The attached test case checks several possible inputs. Support for the SPV_EXT_relaxed_printf_string_address_space extension is left as future work (see: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_relaxed_printf_string_address_space.asciidoc).
---
Full diff: https://github.com/llvm/llvm-project/pull/89948.diff
3 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+5-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+64-23)
- (added) llvm/test/CodeGen/SPIRV/printf.ll (+40)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4b07d7e61fa113..64227ae062dfde 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -189,6 +189,10 @@ lookupBuiltin(StringRef DemangledCall,
std::string BuiltinName =
DemangledCall.substr(0, DemangledCall.find('(')).str();
+ // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
+ if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
+ BuiltinName = BuiltinName.substr(12);
+
// Check if the extracted name contains type information between angle
// brackets. If so, the builtin is an instantiated template - needs to have
// the information after angle brackets and return type removed.
@@ -2306,7 +2310,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
// parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
// base types.
if (TypeStr.ends_with("*"))
- TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+ TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
Ctx);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 472bc8638c9af1..204a662240e54f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -98,6 +98,8 @@ class SPIRVEmitIntrinsics
return B.CreateIntrinsic(IntrID, {Types}, Args);
}
+ void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+
void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -111,6 +113,7 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
+ void processParamTypesByFunHeader(Function *F, IRBuilder<> &B);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
@@ -194,6 +197,17 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
false);
}
+void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
+ Value *Arg) {
+ CallInst *AssignPtrTyCI =
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+ Constant::getNullValue(ElemTy), Arg,
+ {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+ GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+ GR->addDeducedElementType(Arg, ElemTy);
+ AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
+}
+
// Set element pointer type to the given value of ValueTy and tries to
// specify this type further (recursively) by Operand value, if needed.
Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
@@ -232,6 +246,19 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
return nullptr;
}
+// Returns the pointee type known in advance for formal argument OpIdx of a
+// builtin call: only the first argument (OpIdx == 0) of printf and
+// __spirv_ocl_printf is mapped (to i8); returns nullptr otherwise. TODO:
+// consider generalizing by encoding builtin knowledge as specification rules.
+static Type *getPointeeTypeByCallInst(StringRef DemangledName,
+ Function *CalledF, unsigned OpIdx) {
+ if ((DemangledName.starts_with("__spirv_ocl_printf(") ||
+ DemangledName.starts_with("printf(")) &&
+ OpIdx == 0)
+ return IntegerType::getInt8Ty(CalledF->getContext());
+ return nullptr;
+}
+
// Deduce and return a successfully deduced Type of the Instruction,
// or nullptr otherwise.
Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -795,6 +822,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
return;
// collect information about formal parameter types
+ std::string DemangledName =
+ getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
Function *CalledF = CI->getCalledFunction();
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
@@ -811,10 +840,15 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!ElemTy && hasPointeeTypeAttr(CalledArg))
ElemTy = getPointeeTypeByAttr(CalledArg);
if (!ElemTy) {
- for (User *U : CalledArg->users()) {
- if (Instruction *Inst = dyn_cast<Instruction>(U)) {
- if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
- break;
+ ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
+ if (ElemTy) {
+ GR->addDeducedElementType(CalledArg, ElemTy);
+ } else {
+ for (User *U : CalledArg->users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+ if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+ break;
+ }
}
}
}
@@ -823,8 +857,6 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
}
}
- std::string DemangledName =
- getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
if (DemangledName.empty() && !HaveTypes)
return;
@@ -835,8 +867,14 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
continue;
// Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
- if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
- continue;
+    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) {
+      // However, we may have assumptions about the formal argument's type and
+      // may need to insert a ptr cast for the actual parameter of this call.
+      // Operands passed through "..." (OpIdx >= arg_size()) have no formal arg.
+      if (OpIdx >= CalledF->arg_size() ||
+          !GR->findDeducedElementType(CalledF->getArg(OpIdx)))
+        continue;
+    }
Type *ExpectedType =
OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
@@ -1179,28 +1217,29 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
return nullptr;
}
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
+ IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (!isUntypedPointerTy(Arg->getType()))
continue;
+ Type *ElemTy = GR->findDeducedElementType(Arg);
+ if (!ElemTy && hasPointeeTypeAttr(Arg) &&
+ (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
+ buildAssignPtr(B, ElemTy, Arg);
+ }
+}
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+ B.SetInsertPointPastAllocas(F);
+ for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+ Argument *Arg = F->getArg(OpIdx);
+ if (!isUntypedPointerTy(Arg->getType()))
+ continue;
Type *ElemTy = GR->findDeducedElementType(Arg);
- if (!ElemTy) {
- if (hasPointeeTypeAttr(Arg) &&
- (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
- GR->addDeducedElementType(Arg, ElemTy);
- } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
- CallInst *AssignPtrTyCI = buildIntrWithMD(
- Intrinsic::spv_assign_ptr_type, {Arg->getType()},
- Constant::getNullValue(ElemTy), Arg,
- {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
- GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
- GR->addDeducedElementType(Arg, ElemTy);
- AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
- }
- }
+ if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr)
+ buildAssignPtr(B, ElemTy, Arg);
}
}
@@ -1217,6 +1256,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
AggrConstTypes.clear();
AggrStores.clear();
+ processParamTypesByFunHeader(F, B);
+
// StoreInst's operand type can be changed during the next transformations,
// so we need to store it in the set. Also store already transformed types.
for (auto &I : instructions(Func)) {
diff --git a/llvm/test/CodeGen/SPIRV/printf.ll b/llvm/test/CodeGen/SPIRV/printf.ll
new file mode 100644
index 00000000000000..483fc1f244e57c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/printf.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#ExtImport:]] = OpExtInstImport "OpenCL.std"
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#CharPtr:]] = OpTypePointer UniformConstant %[[#Char]]
+; CHECK: %[[#GV:]] = OpVariable %[[#]] UniformConstant %[[#]]
+; CHECK: OpFunction
+; CHECK: %[[#Arg1:]] = OpFunctionParameter
+; CHECK: %[[#Arg2:]] = OpFunctionParameter
+; CHECK: %[[#CastedGV:]] = OpBitcast %[[#CharPtr]] %[[#GV]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst]]
+; CHECK-NEXT: %[[#CastedArg2:]] = OpBitcast %[[#CharPtr]] %[[#Arg2]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK: OpFunctionEnd
+
+%struct = type { [6 x i8] }
+
+@FmtStr = internal addrspace(2) constant [6 x i8] c"c=%c\0A\00", align 1
+
+define spir_kernel void @foo(ptr addrspace(2) %_arg_fmt1, ptr addrspace(2) byval(%struct) %_arg_fmt2) {
+entry:
+ %r1 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+ %r2 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+ %r3 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+ %r4 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+ %r5 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+ %r6 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+ ret void
+}
+
+declare dso_local spir_func i32 @_Z6printfPU3AS2Kcz(ptr addrspace(2), ...)
+declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2), ...)
``````````
</details>
https://github.com/llvm/llvm-project/pull/89948
More information about the llvm-commits
mailing list