[llvm] [SPIR-V] Fix tracking ptr null constants of builtin types in calls with mangling (PR #94263)

Michal Paszkowski via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 10:23:42 PDT 2024


https://github.com/michalpaszkowski created https://github.com/llvm/llvm-project/pull/94263

This change makes sure that ptr null (constant) arguments of builtin
function calls are assigned proper builtin types if such are deduced
from mangled names.

Two tests demonstrating the expected behavior (as in the SPIR-V
Translator) are added.

WIP:
- The test builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
  is failing and requires additional work.

- processInstrAfterVisit method must be simplified. TrackConstants can
  be removed.
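
As a rough illustration of the idea (not the code from this patch), the
sketch below picks the base type of the N-th parameter from a demangled
builtin signature and decides whether a ptr null argument in that position
should be tracked as a null constant of a builtin type. The helpers
splitParams, argumentBaseType, isBuiltinTypeName and shouldRetypeNullPtr,
the short list of builtin type names, and the exact demangled string format
are all assumptions made for this example; the real logic lives in
SPIRV::parseBuiltinCallArgumentBaseType.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: split "name(T0, T1, ...)" into trimmed parameter
// strings. Simplification: commas nested inside templates are not handled.
static std::vector<std::string> splitParams(const std::string &Demangled) {
  std::vector<std::string> Params;
  std::size_t Open = Demangled.find('(');
  std::size_t Close = Demangled.rfind(')');
  if (Open == std::string::npos || Close == std::string::npos || Close < Open)
    return Params;
  std::string List = Demangled.substr(Open + 1, Close - Open - 1);
  std::size_t Start = 0;
  while (Start <= List.size()) {
    std::size_t Comma = List.find(',', Start);
    if (Comma == std::string::npos)
      Comma = List.size();
    std::string P = List.substr(Start, Comma - Start);
    // Trim surrounding spaces and skip empty pieces.
    std::size_t B = P.find_first_not_of(' ');
    std::size_t E = P.find_last_not_of(' ');
    if (B != std::string::npos)
      Params.push_back(P.substr(B, E - B + 1));
    Start = Comma + 1;
  }
  return Params;
}

// Hypothetical helper: type string of parameter ArgIdx, or "" if out of range.
static std::string argumentBaseType(const std::string &Demangled,
                                    std::size_t ArgIdx) {
  std::vector<std::string> Params = splitParams(Demangled);
  return ArgIdx < Params.size() ? Params[ArgIdx] : std::string();
}

// Assumed, deliberately incomplete list of OpenCL builtin type names.
static bool isBuiltinTypeName(const std::string &T) {
  return T == "ocl_event" || T == "ocl_sampler" || T == "ocl_queue";
}

// A ptr null operand at ArgIdx would be retyped only when the mangled name
// says that the parameter is a builtin type.
static bool shouldRetypeNullPtr(const std::string &Demangled,
                                std::size_t ArgIdx) {
  return isBuiltinTypeName(argumentBaseType(Demangled, ArgIdx));
}
```

For the async_work_group_strided_copy declaration used in the tests, only
the last (ocl_event) parameter would qualify under this sketch; a ptr null
passed for either of the first two parameters would stay an ordinary null
pointer.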

>From 7b36f94a55f11b71e46d8bc99ad50143308d7d30 Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Mon, 3 Jun 2024 09:49:09 -0700
Subject: [PATCH 1/2] [SPIR-V] Use mangled names for deducing builtin funcs arg
 types

Before this change, the method for deducing function argument types
assumed that any argument of untyped pointer type must be either:
1) A pointer of an LLVM IR element type, passed byval/byref.
2) An OpenCL/SPIR-V builtin type if there is spv_assign_type intrinsic
   assigning a TargetExtType.
3) Just a pointer (with default size)

This does not take into consideration builtin functions which might
also have arguments of OpenCL/SPIR-V builtin type. Since builtins have
just their prototypes inside a module (no body), no spv_assign_type
intrinsics are generated for their arguments. Hence, a fourth option:
4) An OpenCL/SPIR-V builtin type if the mangled function name contains
   type information.

A test mimicking SPIR-V Translator behavior was added.
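
The resulting order of checks can be sketched as a small standalone
classifier. This is an illustrative model only: ArgKind, ArgFacts and
classifyUntypedPointerArg are hypothetical names, and the boolean fields
stand in for the queries the real getArgSPIRVType performs on the IR. The
sketch follows the order used in the patched code, where the mangled-name
check runs before the plain-pointer fallback.

```cpp
#include <string>

// Hypothetical result categories for an argument of untyped pointer type.
enum class ArgKind {
  ByValOrByRefPointee, // pointer to an LLVM IR element type (byval/byref)
  AssignedBuiltinType, // builtin type given by an spv_assign_type intrinsic
  MangledBuiltinType,  // builtin type recovered from the mangled name
  PlainPointer         // ordinary pointer with a default element type
};

// Hypothetical summary of the facts known about one argument.
struct ArgFacts {
  bool HasPointeeTypeAttr;        // argument is passed byval/byref
  bool HasAssignTypeIntrinsic;    // spv_assign_type assigns a TargetExtType
  std::string MangledBuiltinType; // empty if the name carries no builtin type
};

// Checks run from most to least specific; the first match wins.
ArgKind classifyUntypedPointerArg(const ArgFacts &A) {
  if (A.HasPointeeTypeAttr)
    return ArgKind::ByValOrByRefPointee;
  if (A.HasAssignTypeIntrinsic)
    return ArgKind::AssignedBuiltinType;
  if (!A.MangledBuiltinType.empty())
    return ArgKind::MangledBuiltinType;
  return ArgKind::PlainPointer;
}
```

A builtin declaration has no body, so in this model both boolean fields are
false for its arguments and only the mangled name can select a builtin type.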
---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp      | 16 ++++++++++++++--
 .../builtin-function-ptr-arg-of-builtin-type.ll  | 15 +++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f4daab7d06eb5..51b0bce46691e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -211,12 +211,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
   }
 
-  // In case OriginalArgType is of untyped pointer type, there are three
+  // In case OriginalArgType is of untyped pointer type, there are four
   // possibilities:
   // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
   // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
   //    intrinsic assigning a TargetExtType.
-  // 3) This is a pointer, try to retrieve pointer element type from a
+  // 3) This is an OpenCL/SPIR-V builtin type if the mangled function name
+  //    contains type information (the Arg's function is a builtin, has no
+  //    body).
+  // 4) This is a pointer, try to retrieve pointer element type from a
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
   if (hasPointeeTypeAttr(Arg)) {
@@ -255,6 +258,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
   }
 
+  std::string DemangledFuncName =
+      getOclOrSpirvBuiltinDemangledName(F.getName());
+  if (!DemangledFuncName.empty()) {
+    Type *BuiltinType = SPIRV::parseBuiltinCallArgumentBaseType(
+        DemangledFuncName, ArgIdx, F.getContext());
+    if (BuiltinType && BuiltinType->isTargetExtTy())
+      return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
+  }
+
   // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
   // LLVM types in a consistent manner
   if (isUntypedPointerTy(OriginalArgType)) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll
new file mode 100644
index 0000000000000..ff386affdc4c5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-function-ptr-arg-of-builtin-type.ll
@@ -0,0 +1,15 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#PTR_INT8:]] = OpTypePointer Function %[[#INT8]]
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#FUNC_TY:]] = OpTypeFunction %[[#]] %[[#PTR_INT8]] %[[#PTR_INT8]] %[[#]] %[[#]] %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunction %[[#]] None %[[#FUNC_TY]]
+
+define spir_kernel void @foo(ptr %a, ptr %b) {
+  %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null)
+  ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)

>From 740e5f299cb5485acadd5fddc8d92899d7018223 Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Mon, 3 Jun 2024 10:21:39 -0700
Subject: [PATCH 2/2] [SPIR-V] Fix tracking ptr null constants of builtin types
 in calls with mangling

This change makes sure that ptr null (constant) arguments of builtin
function calls are assigned proper builtin types if such are deduced
from mangled names.

Two tests demonstrating the expected behavior (as in the SPIR-V
Translator) are added.

WIP:
- The test builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
  is failing and requires additional work.

- processInstrAfterVisit method must be simplified. TrackConstants can
  be removed.
---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  5 +--
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 43 +++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  3 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  2 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  6 +--
 ...tiple-ptr-null-args-one-of-builtin-type.ll | 13 ++++++
 ...iltin-call-ptr-null-arg-of-builtin-type.ll | 13 ++++++
 7 files changed, 69 insertions(+), 16 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 51b0bce46691e..7d4d9801c7ce2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -258,8 +258,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
   }
 
-  std::string DemangledFuncName =
-      getOclOrSpirvBuiltinDemangledName(F.getName());
+  std::string DemangledFuncName = demangleBuiltinCall(F.getName());
   if (!DemangledFuncName.empty()) {
     Type *BuiltinType = SPIRV::parseBuiltinCallArgumentBaseType(
         DemangledFuncName, ArgIdx, F.getContext());
@@ -521,7 +520,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // globally later.
   if (Info.Callee.isGlobal()) {
     std::string FuncName = Info.Callee.getGlobal()->getName().str();
-    DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
+    DemangledName = demangleBuiltinCall(FuncName);
     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index ffbd1e17bad5e..7e9347fa2fbc9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -851,6 +851,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
   I->setOperand(OperandToReplace, PtrCastI);
 }
 
+static bool inline isDirectNonIntrinsicCall(CallInst *CI) {
+  return CI && !CI->isIndirectCall() && !CI->isInlineAsm() &&
+         CI->getCalledFunction() && !CI->getCalledFunction()->isIntrinsic();
+}
+
 void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
                                                          IRBuilder<> &B) {
   // Handle basic instructions:
@@ -874,14 +879,12 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
 
   // Handle calls to builtins (non-intrinsics):
   CallInst *CI = dyn_cast<CallInst>(I);
-  if (!CI || CI->isIndirectCall() || CI->isInlineAsm() ||
-      !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic())
+  if (!isDirectNonIntrinsicCall(CI))
     return;
 
-  // collect information about formal parameter types
-  std::string DemangledName =
-      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
+  // Collect information about formal parameter types
   Function *CalledF = CI->getCalledFunction();
+  std::string DemangledName = demangleBuiltinCall(CalledF->getName());
   SmallVector<Type *, 4> CalledArgTys;
   bool HaveTypes = false;
   for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
@@ -1195,12 +1198,21 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
     I->replaceAllUsesWith(NewOp);
     NewOp->setArgOperand(0, I);
   }
+
+  auto *CI = dyn_cast<CallInst>(I);
+  bool IsCall = CI && isDirectNonIntrinsicCall(CI);
+  std::string DemangledCall =
+      IsCall ? demangleBuiltinCall(CI->getCalledFunction()->getName()) : "";
+
   bool IsPhi = isa<PHINode>(I), BPrepared = false;
+
   for (const auto &Op : I->operands()) {
     if ((isa<ConstantAggregateZero>(Op) && Op->getType()->isVectorTy()) ||
         isa<PHINode>(I) || isa<SwitchInst>(I))
       TrackConstants = false;
     if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) {
+      Constant *OpConst = cast<Constant>(Op);
+
       unsigned OpNo = Op.getOperandNo();
       if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) ||
                  (II->paramHasAttr(OpNo, Attribute::ImmArg))))
@@ -1210,12 +1222,27 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
               : B.SetInsertPoint(I);
         BPrepared = true;
       }
+
       Value *OpTyVal = Op;
-      if (Op->getType()->isTargetExtTy())
+      Value *ConstVal = Op;
+
+      if (Op->getType()->isTargetExtTy() ||
+          (Op->getType()->isPointerTy() && !DemangledCall.empty() &&
+           OpConst->isNullValue())) {
         OpTyVal = PoisonValue::get(Op->getType());
+      }
+
+      if (Op->getType()->isPointerTy() && !DemangledCall.empty() &&
+          OpConst->isNullValue()) {
+        Type *DemangledTy = SPIRV::parseBuiltinCallArgumentBaseType(
+            DemangledCall, Op.getOperandNo(), I->getContext());
+        if (DemangledTy)
+          ConstVal = Constant::getNullValue(DemangledTy);
+      }
+
       auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
-                                    {Op->getType(), OpTyVal->getType()}, Op,
-                                    OpTyVal, {}, B);
+                                    {Op->getType(), OpTyVal->getType()},
+                                    ConstVal, OpTyVal, {}, B);
       I->setOperand(OpNo, NewOp);
     }
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 624899600693a..3ce3b9b3b1e64 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -81,13 +81,14 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           }
           GR->add(Const, &MF, SrcReg);
           if (Const->getType()->isTargetExtTy()) {
-            // remember association so that we can restore it when assign types
+            // Remember association so that we can restore it when assign types.
             MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
             if (SrcMI && (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT ||
                           SrcMI->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
               TargetExtConstTypes[SrcMI] = Const->getType();
             if (Const->isNullValue()) {
               MachineIRBuilder MIB(MF);
+              MIB.setInsertPt(*MI.getParent(), MI);
               SPIRVType *ExtType =
                   GR->getOrCreateSPIRVType(Const->getType(), MIB);
               SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c20f3546a3e55..b8cd83d1ca975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -332,7 +332,7 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
          Name == "__translate_sampler_initializer";
 }
 
-std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
+std::string demangleBuiltinCall(StringRef Name) {
   bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
   bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
   bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 33cb509dc4a59..a539e960f22a2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -90,9 +90,9 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
 // Get type of i-th operand of the metadata node.
 Type *getMDOperandAsType(const MDNode *N, unsigned I);
 
-// If OpenCL or SPIR-V builtin function name is recognized, return a demangled
-// name, otherwise return an empty string.
-std::string getOclOrSpirvBuiltinDemangledName(StringRef Name);
+// If SPIR-V builtin function name is recognized, return a demangled name,
+// otherwise return an empty string.
+std::string demangleBuiltinCall(StringRef Name);
 
 // Check if a string contains a builtin prefix.
 bool hasBuiltinTypePrefix(StringRef Name);
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
new file mode 100644
index 0000000000000..103fab2838607
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-multiple-ptr-null-args-one-of-builtin-type.ll
@@ -0,0 +1,13 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]]
+
+define spir_kernel void @foo() {
+  %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr null, ptr null, i64 1, i64 1, ptr null)
+  ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll
new file mode 100644
index 0000000000000..51d094e398216
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-call-ptr-null-arg-of-builtin-type.ll
@@ -0,0 +1,13 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#EVENT:]] = OpTypeEvent
+; CHECK-DAG: %[[#EVENT_NULL:]] = OpConstantNull %[[#EVENT]]
+; CHECK-DAG: %[[#]] = OpFunctionCall %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#EVENT_NULL]]
+
+define spir_kernel void @foo(ptr %a, ptr %b) {
+  %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr %a, ptr %b, i64 1, i64 1, ptr null)
+  ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)


