[llvm] [SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (PR #96895)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 27 11:13:31 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/96895

>From d54f3a10269c82c43ac518502e4e03f3744ab91b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 03:43:34 -0700
Subject: [PATCH 1/3] Improve type inference for a known instruction's builtin

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 71 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h         |  6 ++
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 25 ++++++-
 .../SPIRV/transcoding/spirv-event-null.ll     | 23 +++++-
 4 files changed, 119 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b93a4d85eedf..bdada81cef89f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2377,9 +2377,76 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
 namespace SPIRV {
+// Try to find builtin function attributes by a demangled function name and
+// return a tuple <builtin group, op code, ext instruction number>, or a special
+// tuple value <-1, 0, 0> if the builtin function is not found.
+// Not all builtin functions are supported, only those with a ready-to-use op
+// code or instruction number defined in TableGen.
+// TODO: consider a major rework of mapping demangled calls into a builtin
+// functions to unify search and decrease number of individual cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set) {
+  const SPIRV::DemangledBuiltin *Builtin =
+      SPIRV::lookupBuiltin(DemangledCall, Set);
+  if (!Builtin)
+    return std::make_tuple(-1, 0, 0);
+
+  switch (Builtin->Group) {
+  case SPIRV::Relational:
+  case SPIRV::Atomic:
+  case SPIRV::Barrier:
+  case SPIRV::CastToPtr:
+  case SPIRV::ImageMiscQuery:
+  case SPIRV::SpecConstant:
+  case SPIRV::Enqueue:
+  case SPIRV::AsyncCopy:
+  case SPIRV::LoadStore:
+  case SPIRV::CoopMatr:
+    if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::Extended:
+    if (const auto *R =
+            SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(Builtin->Group, 0, R->Number);
+    break;
+  case SPIRV::VectorLoadStore:
+    if (const auto *R =
+            SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(SPIRV::Extended, 0, R->Number);
+    break;
+  case SPIRV::Group:
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::AtomicFloating:
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::IntelSubgroups:
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::GroupUniform:
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::WriteImage:
+    return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+  case SPIRV::Select:
+    return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+  case SPIRV::Construct:
+    return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+  case SPIRV::KernelClock:
+    return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+  default:
+    return std::make_tuple(-1, 0, 0);
+  }
+  return std::make_tuple(-1, 0, 0);
+}
+
 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRV::InstructionSet::InstructionSet Set,
                                  MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 649f5bfd1d7c2..482b2ed853f7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Helper external function for finding builtin function attributes
+/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set);
+
 /// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
 /// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
 /// element type (e.g. i8 for char*), or builtin type (TargetExtType).
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index dd5884096b85d..53f00eb2d1f10 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
+  SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end())
+      if (AsArgIt != ResTypeByArg.end()) {
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited);
+      }
     }
   }
 
@@ -544,6 +546,25 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
       KnownElemTy = ElemTy1;
       Ops.push_back(std::make_pair(Op0, 0));
     }
+  } else if (auto *CI = dyn_cast<CallInst>(I)) {
+    if (Function *CalledF = CI->getCalledFunction()) {
+      std::string DemangledName =
+          getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      auto [Grp, Opcode, ExtNo] =
+          SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+      if (Opcode == SPIRV::OpGroupAsyncCopy) {
+        for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+             ++i) {
+          Value *Op = CI->getArgOperand(i);
+          if (!isPointerTy(Op->getType()))
+            continue;
+          ++PtrCnt;
+          if (Type *ElemTy = GR->findDeducedElementType(Op))
+            KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+          Ops.push_back(std::make_pair(Op, i));
+        }
+      }
+    }
   }
 
   // There is no enough info to deduce types or all is valid.
@@ -1385,6 +1406,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
 
   const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
   GR = ST.getSPIRVGlobalRegistry();
+  InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+                              : SPIRV::InstructionSet::GLSL_std_450;
 
   F = &Func;
   IRBuilder<> B(Func.getContext());
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index fe0d96f2773ec..82eff6f14469d 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -9,6 +9,7 @@
 ; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+
 ; CHECK: OpFunction
 ; CHECK: OpFunctionParameter
 ; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -18,11 +19,11 @@
 ; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
 
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
 
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
-  %var = alloca %"class.sycl::_V1::device_event"
+  %var = alloca %StructEvent
   %dev_event.i.sroa.0 = alloca target("spirv.Event")
   %add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
   %call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +32,19 @@ entry:
 }
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg_In) {
+entry:
+  %E.i = alloca %StructEvent
+  %add.ptr.i42 = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg_In, i64 0
+  %call3.i.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %add.ptr.i42, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  store target("spirv.Event") %call3.i.i, ptr %E.i
+  %E.ascast.i = addrspacecast ptr %E.i to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))

>From ebf6d801693fd1ca0e7d86f5ff75efd667911e90 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 10:41:46 -0700
Subject: [PATCH 2/3] deduce type of arguments of OpGroupAsyncCopy

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 44 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 26 ++++++-----
 2 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index bdada81cef89f..c99165e786216 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2388,12 +2388,14 @@ namespace SPIRV {
 std::tuple<int, unsigned, unsigned>
 mapBuiltinToOpcode(const StringRef DemangledCall,
                    SPIRV::InstructionSet::InstructionSet Set) {
-  const SPIRV::DemangledBuiltin *Builtin =
-      SPIRV::lookupBuiltin(DemangledCall, Set);
-  if (!Builtin)
+  Register Reg;
+  SmallVector<Register> Args;
+  std::unique_ptr<const IncomingCall> Call =
+      lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+  if (!Call)
     return std::make_tuple(-1, 0, 0);
 
-  switch (Builtin->Group) {
+  switch (Call->Builtin->Group) {
   case SPIRV::Relational:
   case SPIRV::Atomic:
   case SPIRV::Barrier:
@@ -2404,43 +2406,43 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
   case SPIRV::AsyncCopy:
   case SPIRV::LoadStore:
   case SPIRV::CoopMatr:
-    if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::Extended:
     if (const auto *R =
-            SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
-      return std::make_tuple(Builtin->Group, 0, R->Number);
+            SPIRV::lookupExtendedBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, 0, R->Number);
     break;
   case SPIRV::VectorLoadStore:
     if (const auto *R =
-            SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+            SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name, Call->Builtin->Set))
       return std::make_tuple(SPIRV::Extended, 0, R->Number);
     break;
   case SPIRV::Group:
-    if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::AtomicFloating:
-    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::IntelSubgroups:
-    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::GroupUniform:
-    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::WriteImage:
-    return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
   case SPIRV::Select:
-    return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+    return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
   case SPIRV::Construct:
-    return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct, 0);
   case SPIRV::KernelClock:
-    return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
   default:
     return std::make_tuple(-1, 0, 0);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 53f00eb2d1f10..5e7891496ce93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,18 +550,20 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
     if (Function *CalledF = CI->getCalledFunction()) {
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      auto [Grp, Opcode, ExtNo] =
-          SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
-      if (Opcode == SPIRV::OpGroupAsyncCopy) {
-        for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
-             ++i) {
-          Value *Op = CI->getArgOperand(i);
-          if (!isPointerTy(Op->getType()))
-            continue;
-          ++PtrCnt;
-          if (Type *ElemTy = GR->findDeducedElementType(Op))
-            KnownElemTy = ElemTy; // src will rewrite dest if both are defined
-          Ops.push_back(std::make_pair(Op, i));
+      if (!StringRef(DemangledName).starts_with("llvm.")) {
+        auto [Grp, Opcode, ExtNo] =
+            SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+        if (Opcode == SPIRV::OpGroupAsyncCopy) {
+          for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+               ++i) {
+            Value *Op = CI->getArgOperand(i);
+            if (!isPointerTy(Op->getType()))
+              continue;
+            ++PtrCnt;
+            if (Type *ElemTy = GR->findDeducedElementType(Op))
+              KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+            Ops.push_back(std::make_pair(Op, i));
+          }
         }
       }
     }

>From b448cb76dc4a5005f8137a70e288490a76357081 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 11:13:20 -0700
Subject: [PATCH 3/3] fix functions name condition

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5e7891496ce93..04e29b24d8fd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,7 +550,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
     if (Function *CalledF = CI->getCalledFunction()) {
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      if (!StringRef(DemangledName).starts_with("llvm.")) {
+      if (DemangledName.length() > 0 &&
+          !StringRef(DemangledName).starts_with("llvm.")) {
         auto [Grp, Opcode, ExtNo] =
             SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
         if (Opcode == SPIRV::OpGroupAsyncCopy) {



More information about the llvm-commits mailing list