[llvm] [SPIRV] Improve builtins matching and type inference in SPIR-V Backend (PR #89948)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 24 09:24:05 PDT 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/89948

This PR improves builtins matching and type inference in the SPIR-V Backend. The model test case is a printf call from OpenCL.std, which has several features that allow for a broader look at builtins support and type inference:
(1) call in a "spirv-friendly" style (prefixed by __spirv_ocl_)
(2) restricted type of the 1st argument

The attached test cases check several possible inputs. Support for the extension SPV_EXT_relaxed_printf_string_address_space is still to be done (see: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_relaxed_printf_string_address_space.asciidoc).

>From c5102cb39cf8b154b81648a09c8d0584b989af38 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 24 Apr 2024 09:20:26 -0700
Subject: [PATCH] Improve builtins matching and type inference

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  6 +-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 87 ++++++++++++++-----
 llvm/test/CodeGen/SPIRV/printf.ll             | 40 +++++++++
 3 files changed, 109 insertions(+), 24 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/printf.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4b07d7e61fa113..64227ae062dfde 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -189,6 +189,10 @@ lookupBuiltin(StringRef DemangledCall,
   std::string BuiltinName =
       DemangledCall.substr(0, DemangledCall.find('(')).str();
 
+  // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
+  if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
+    BuiltinName = BuiltinName.substr(12);
+
   // Check if the extracted name contains type information between angle
   // brackets. If so, the builtin is an instantiated template - needs to have
   // the information after angle brackets and return type removed.
@@ -2306,7 +2310,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
     // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
     // base types.
     if (TypeStr.ends_with("*"))
-      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
 
     return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
                                                Ctx);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 472bc8638c9af1..204a662240e54f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -98,6 +98,8 @@ class SPIRVEmitIntrinsics
     return B.CreateIntrinsic(IntrID, {Types}, Args);
   }
 
+  void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+
   void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
   void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
   void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -111,6 +113,7 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
+  void processParamTypesByFunHeader(Function *F, IRBuilder<> &B);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
@@ -194,6 +197,17 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
                        false);
 }
 
+void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
+                                         Value *Arg) {
+  CallInst *AssignPtrTyCI =
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+                      Constant::getNullValue(ElemTy), Arg,
+                      {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+  GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+  GR->addDeducedElementType(Arg, ElemTy);
+  AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
+}
+
 // Set element pointer type to the given value of ValueTy and tries to
 // specify this type further (recursively) by Operand value, if needed.
 Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
@@ -232,6 +246,19 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
   return nullptr;
 }
 
+// Implements what we know in advance about intrinsics and builtin calls
+// TODO: consider feasibility of this particular case to be generalized by
+// encoding knowledge about intrinsics and builtin calls by corresponding
+// specification rules
+static Type *getPointeeTypeByCallInst(StringRef DemangledName,
+                                      Function *CalledF, unsigned OpIdx) {
+  if ((DemangledName.starts_with("__spirv_ocl_printf(") ||
+       DemangledName.starts_with("printf(")) &&
+      OpIdx == 0)
+    return IntegerType::getInt8Ty(CalledF->getContext());
+  return nullptr;
+}
+
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -795,6 +822,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
     return;
 
   // collect information about formal parameter types
+  std::string DemangledName =
+      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
   Function *CalledF = CI->getCalledFunction();
   SmallVector<Type *, 4> CalledArgTys;
   bool HaveTypes = false;
@@ -811,10 +840,15 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       if (!ElemTy && hasPointeeTypeAttr(CalledArg))
         ElemTy = getPointeeTypeByAttr(CalledArg);
       if (!ElemTy) {
-        for (User *U : CalledArg->users()) {
-          if (Instruction *Inst = dyn_cast<Instruction>(U)) {
-            if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
-              break;
+        ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
+        if (ElemTy) {
+          GR->addDeducedElementType(CalledArg, ElemTy);
+        } else {
+          for (User *U : CalledArg->users()) {
+            if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+              if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+                break;
+            }
           }
         }
       }
@@ -823,8 +857,6 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
     }
   }
 
-  std::string DemangledName =
-      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
   if (DemangledName.empty() && !HaveTypes)
     return;
 
@@ -835,8 +867,14 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       continue;
 
     // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
-    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
-      continue;
+    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) {
+      // However, we may have assumptions about the formal argument's type and
+      // may have a need to insert a ptr cast for the actual parameter of this
+      // call.
+      Argument *CalledArg = CalledF->getArg(OpIdx);
+      if (!GR->findDeducedElementType(CalledArg))
+        continue;
+    }
 
     Type *ExpectedType =
         OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
@@ -1179,28 +1217,29 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
   return nullptr;
 }
 
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
+                                                       IRBuilder<> &B) {
   B.SetInsertPointPastAllocas(F);
   for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
     Argument *Arg = F->getArg(OpIdx);
     if (!isUntypedPointerTy(Arg->getType()))
       continue;
+    Type *ElemTy = GR->findDeducedElementType(Arg);
+    if (!ElemTy && hasPointeeTypeAttr(Arg) &&
+        (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
+      buildAssignPtr(B, ElemTy, Arg);
+  }
+}
 
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  B.SetInsertPointPastAllocas(F);
+  for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+    Argument *Arg = F->getArg(OpIdx);
+    if (!isUntypedPointerTy(Arg->getType()))
+      continue;
     Type *ElemTy = GR->findDeducedElementType(Arg);
-    if (!ElemTy) {
-      if (hasPointeeTypeAttr(Arg) &&
-          (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
-        GR->addDeducedElementType(Arg, ElemTy);
-      } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
-        CallInst *AssignPtrTyCI = buildIntrWithMD(
-            Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-            Constant::getNullValue(ElemTy), Arg,
-            {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-        GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
-        GR->addDeducedElementType(Arg, ElemTy);
-        AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
-      }
-    }
+    if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr)
+      buildAssignPtr(B, ElemTy, Arg);
   }
 }
 
@@ -1217,6 +1256,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  processParamTypesByFunHeader(F, B);
+
   // StoreInst's operand type can be changed during the next transformations,
   // so we need to store it in the set. Also store already transformed types.
   for (auto &I : instructions(Func)) {
diff --git a/llvm/test/CodeGen/SPIRV/printf.ll b/llvm/test/CodeGen/SPIRV/printf.ll
new file mode 100644
index 00000000000000..483fc1f244e57c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/printf.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#ExtImport:]] = OpExtInstImport "OpenCL.std"
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#CharPtr:]] = OpTypePointer UniformConstant %[[#Char]]
+; CHECK: %[[#GV:]] = OpVariable %[[#]] UniformConstant %[[#]]
+; CHECK: OpFunction
+; CHECK: %[[#Arg1:]] = OpFunctionParameter
+; CHECK: %[[#Arg2:]] = OpFunctionParameter
+; CHECK: %[[#CastedGV:]] = OpBitcast %[[#CharPtr]] %[[#GV]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst:]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst]]
+; CHECK-NEXT: %[[#CastedArg2:]] = OpBitcast %[[#CharPtr]] %[[#Arg2]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
+; CHECK: OpFunctionEnd
+
+%struct = type { [6 x i8] }
+
+@FmtStr = internal addrspace(2) constant [6 x i8] c"c=%c\0A\00", align 1
+
+define spir_kernel void @foo(ptr addrspace(2) %_arg_fmt1, ptr addrspace(2) byval(%struct) %_arg_fmt2) {
+entry:
+  %r1 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+  %r2 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
+  %r3 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+  %r4 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
+  %r5 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+  %r6 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
+  ret void
+}
+
+declare dso_local spir_func i32 @_Z6printfPU3AS2Kcz(ptr addrspace(2), ...)
+declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2), ...)



More information about the llvm-commits mailing list