[llvm] [SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (PR #96895)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 02:17:58 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/96895

>From d54f3a10269c82c43ac518502e4e03f3744ab91b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 03:43:34 -0700
Subject: [PATCH 1/7] Improve type inference for a known instruction's builtin

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 71 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h         |  6 ++
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 25 ++++++-
 .../SPIRV/transcoding/spirv-event-null.ll     | 23 +++++-
 4 files changed, 119 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b93a4d85eedf..bdada81cef89f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2377,9 +2377,76 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
 namespace SPIRV {
+// Try to find a builtin funtion attributes by a demangled function name and
+// return a tuple <builtin group, op code, ext instruction number>, or a special
+// tuple value <-1, 0, 0> if the builtin funtion is not found.
+// Not all builtin funtions are supported, only those with a ready-to-use op
+// code or instruction number defined in TableGen.
+// TODO: consider a major rework of mapping demangled calls into a builtin
+// functions to unify search and decrease number of individual cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set) {
+  const SPIRV::DemangledBuiltin *Builtin =
+      SPIRV::lookupBuiltin(DemangledCall, Set);
+  if (!Builtin)
+    return std::make_tuple(-1, 0, 0);
+
+  switch (Builtin->Group) {
+  case SPIRV::Relational:
+  case SPIRV::Atomic:
+  case SPIRV::Barrier:
+  case SPIRV::CastToPtr:
+  case SPIRV::ImageMiscQuery:
+  case SPIRV::SpecConstant:
+  case SPIRV::Enqueue:
+  case SPIRV::AsyncCopy:
+  case SPIRV::LoadStore:
+  case SPIRV::CoopMatr:
+    if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::Extended:
+    if (const auto *R =
+            SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(Builtin->Group, 0, R->Number);
+    break;
+  case SPIRV::VectorLoadStore:
+    if (const auto *R =
+            SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+      return std::make_tuple(SPIRV::Extended, 0, R->Number);
+    break;
+  case SPIRV::Group:
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::AtomicFloating:
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::IntelSubgroups:
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::GroupUniform:
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
+      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::WriteImage:
+    return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+  case SPIRV::Select:
+    return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+  case SPIRV::Construct:
+    return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+  case SPIRV::KernelClock:
+    return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+  default:
+    return std::make_tuple(-1, 0, 0);
+  }
+  return std::make_tuple(-1, 0, 0);
+}
+
 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRV::InstructionSet::InstructionSet Set,
                                  MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 649f5bfd1d7c2..482b2ed853f7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Helper external function for finding a builtin funtion attributes
+/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set);
+
 /// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
 /// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
 /// element type (e.g. i8 for char*), or builtin type (TargetExtType).
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index dd5884096b85d..53f00eb2d1f10 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
+  SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end())
+      if (AsArgIt != ResTypeByArg.end()) {
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited);
+      }
     }
   }
 
@@ -544,6 +546,25 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
       KnownElemTy = ElemTy1;
       Ops.push_back(std::make_pair(Op0, 0));
     }
+  } else if (auto *CI = dyn_cast<CallInst>(I)) {
+    if (Function *CalledF = CI->getCalledFunction()) {
+      std::string DemangledName =
+          getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      auto [Grp, Opcode, ExtNo] =
+          SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+      if (Opcode == SPIRV::OpGroupAsyncCopy) {
+        for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+             ++i) {
+          Value *Op = CI->getArgOperand(i);
+          if (!isPointerTy(Op->getType()))
+            continue;
+          ++PtrCnt;
+          if (Type *ElemTy = GR->findDeducedElementType(Op))
+            KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+          Ops.push_back(std::make_pair(Op, i));
+        }
+      }
+    }
   }
 
   // There is no enough info to deduce types or all is valid.
@@ -1385,6 +1406,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
 
   const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
   GR = ST.getSPIRVGlobalRegistry();
+  InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+                              : SPIRV::InstructionSet::GLSL_std_450;
 
   F = &Func;
   IRBuilder<> B(Func.getContext());
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index fe0d96f2773ec..82eff6f14469d 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -9,6 +9,7 @@
 ; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+
 ; CHECK: OpFunction
 ; CHECK: OpFunctionParameter
 ; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -18,11 +19,11 @@
 ; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
 
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
 
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
-  %var = alloca %"class.sycl::_V1::device_event"
+  %var = alloca %StructEvent
   %dev_event.i.sroa.0 = alloca target("spirv.Event")
   %add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
   %call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +32,19 @@ entry:
 }
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg_In) {
+entry:
+  %E.i = alloca %StructEvent
+  %add.ptr.i42 = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg_In, i64 0
+  %call3.i.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %add.ptr.i42, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  store target("spirv.Event") %call3.i.i, ptr %E.i
+  %E.ascast.i = addrspacecast ptr %E.i to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))

>From ebf6d801693fd1ca0e7d86f5ff75efd667911e90 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 10:41:46 -0700
Subject: [PATCH 2/7] deduce type of arguments of OpGroupAsyncCopy

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 44 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 26 ++++++-----
 2 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index bdada81cef89f..c99165e786216 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2388,12 +2388,14 @@ namespace SPIRV {
 std::tuple<int, unsigned, unsigned>
 mapBuiltinToOpcode(const StringRef DemangledCall,
                    SPIRV::InstructionSet::InstructionSet Set) {
-  const SPIRV::DemangledBuiltin *Builtin =
-      SPIRV::lookupBuiltin(DemangledCall, Set);
-  if (!Builtin)
+  Register Reg;
+  SmallVector<Register> Args;
+  std::unique_ptr<const IncomingCall> Call =
+      lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+  if (!Call)
     return std::make_tuple(-1, 0, 0);
 
-  switch (Builtin->Group) {
+  switch (Call->Builtin->Group) {
   case SPIRV::Relational:
   case SPIRV::Atomic:
   case SPIRV::Barrier:
@@ -2404,43 +2406,43 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
   case SPIRV::AsyncCopy:
   case SPIRV::LoadStore:
   case SPIRV::CoopMatr:
-    if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::Extended:
     if (const auto *R =
-            SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
-      return std::make_tuple(Builtin->Group, 0, R->Number);
+            SPIRV::lookupExtendedBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, 0, R->Number);
     break;
   case SPIRV::VectorLoadStore:
     if (const auto *R =
-            SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+            SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name, Call->Builtin->Set))
       return std::make_tuple(SPIRV::Extended, 0, R->Number);
     break;
   case SPIRV::Group:
-    if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::AtomicFloating:
-    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::IntelSubgroups:
-    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::GroupUniform:
-    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
-      return std::make_tuple(Builtin->Group, R->Opcode, 0);
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::WriteImage:
-    return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
   case SPIRV::Select:
-    return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+    return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
   case SPIRV::Construct:
-    return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct, 0);
   case SPIRV::KernelClock:
-    return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
   default:
     return std::make_tuple(-1, 0, 0);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 53f00eb2d1f10..5e7891496ce93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,18 +550,20 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
     if (Function *CalledF = CI->getCalledFunction()) {
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      auto [Grp, Opcode, ExtNo] =
-          SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
-      if (Opcode == SPIRV::OpGroupAsyncCopy) {
-        for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
-             ++i) {
-          Value *Op = CI->getArgOperand(i);
-          if (!isPointerTy(Op->getType()))
-            continue;
-          ++PtrCnt;
-          if (Type *ElemTy = GR->findDeducedElementType(Op))
-            KnownElemTy = ElemTy; // src will rewrite dest if both are defined
-          Ops.push_back(std::make_pair(Op, i));
+      if (!StringRef(DemangledName).starts_with("llvm.")) {
+        auto [Grp, Opcode, ExtNo] =
+            SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+        if (Opcode == SPIRV::OpGroupAsyncCopy) {
+          for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+               ++i) {
+            Value *Op = CI->getArgOperand(i);
+            if (!isPointerTy(Op->getType()))
+              continue;
+            ++PtrCnt;
+            if (Type *ElemTy = GR->findDeducedElementType(Op))
+              KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+            Ops.push_back(std::make_pair(Op, i));
+          }
         }
       }
     }

>From b448cb76dc4a5005f8137a70e288490a76357081 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 11:13:20 -0700
Subject: [PATCH 3/7] fix functions name condition

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5e7891496ce93..04e29b24d8fd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,7 +550,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
     if (Function *CalledF = CI->getCalledFunction()) {
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      if (!StringRef(DemangledName).starts_with("llvm.")) {
+      if (DemangledName.length() > 0 &&
+          !StringRef(DemangledName).starts_with("llvm.")) {
         auto [Grp, Opcode, ExtNo] =
             SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
         if (Opcode == SPIRV::OpGroupAsyncCopy) {

>From be174126a27c02d242ee813c65628d4d311de38c Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 1 Jul 2024 04:58:07 -0700
Subject: [PATCH 4/7] harden the test case

---
 .../SPIRV/transcoding/spirv-event-null.ll     | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index 82eff6f14469d..8b8c1d1341e1b 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -8,8 +8,17 @@
 ; CHECK-DAG: %[[#TyStruct:]] = OpTypeStruct %[[#TyEvent]]
 ; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
+; CHECK-DAG: %[[#TyEventPtrGen:]] = OpTypePointer Generic %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
 
+; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
+; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
+
+; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
+
 ; CHECK: OpFunction
 ; CHECK: OpFunctionParameter
 ; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -18,6 +27,7 @@
 ; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
 ; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: OpFunctionEnd
 
 %StructEvent = type { target("spirv.Event") }
 
@@ -33,6 +43,24 @@ entry:
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
 
+; Check correct type inference when calling __spirv_GroupAsyncCopy:
+; we expect that the Backend is able to deduce a type of the %_arg_Local
+; given facts that it's possible to deduce a type of the %_arg_In
+; and %_arg_Local and %_arg_In are source/destination arguments in OpGroupAsyncCopy
+
+; CHECK: OpFunction
+; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
+; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
+; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1]] %[[#SrcBar]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
+; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBarCasted2]]
+; CHECK: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
+; CHECK: OpFunctionEnd
+
 %Vec4 = type { <4 x i8> }
 
 define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg_In) {

>From fff0081b0a5ae2ba9e6fc2a60d2ce0dffde57283 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 1 Jul 2024 06:49:28 -0700
Subject: [PATCH 5/7] add validation of GroupAsyncCopy arguments

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 14 ++++---
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 37 +++++++++++++++++++
 .../SPIRV/transcoding/spirv-event-null.ll     | 23 +++++++-----
 3 files changed, 58 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c99165e786216..5e5fa26a0f5f7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2406,17 +2406,18 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
   case SPIRV::AsyncCopy:
   case SPIRV::LoadStore:
   case SPIRV::CoopMatr:
-    if (const auto *R = SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+    if (const auto *R =
+            SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
     break;
   case SPIRV::Extended:
-    if (const auto *R =
-            SPIRV::lookupExtendedBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+    if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
+                                                     Call->Builtin->Set))
       return std::make_tuple(Call->Builtin->Group, 0, R->Number);
     break;
   case SPIRV::VectorLoadStore:
-    if (const auto *R =
-            SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+    if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
+                                                            Call->Builtin->Set))
       return std::make_tuple(SPIRV::Extended, 0, R->Number);
     break;
   case SPIRV::Group:
@@ -2440,7 +2441,8 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
   case SPIRV::Select:
     return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
   case SPIRV::Construct:
-    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
+                           0);
   case SPIRV::KernelClock:
     return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
   default:
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 4383d1c5c0e25..2344ec529e16d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -203,6 +203,39 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
 
+static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
+                                      MachineRegisterInfo *MRI,
+                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                      unsigned OpIdx) {
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
+      ElemType->getNumOperands() != 2)
+    return;
+  // It's a structure-wrapper around another type with a single member field.
+  SPIRVType *MemberType =
+      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
+  if (!MemberType)
+    return;
+  unsigned MemberTypeOp = MemberType->getOpcode();
+  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
+      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
+    return;
+  // It's a structure-wrapper around a valid type. Insert a bitcast before the
+  // instruction to keep SPIR-V code valid.
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
 // Insert a bitcast before the function call instruction to keep SPIR-V code
 // valid when there is a type mismatch between actual and expected types of an
 // argument:
@@ -380,6 +413,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                       SPIRV::OpTypeBool))
           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
         break;
+      case SPIRV::OpGroupAsyncCopy:
+        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
+        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+        break;
       case SPIRV::OpGroupWaitEvents:
         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index 8b8c1d1341e1b..df11565ca8180 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -10,12 +10,13 @@
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtrGen:]] = OpTypePointer Generic %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
-
 ; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
 ; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
 ; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
 ; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrV4_W:]] = OpTypePointer Workgroup %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyV4]]
 
 ; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
 
@@ -45,15 +46,17 @@ declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU
 
 ; Check correct type inference when calling __spirv_GroupAsyncCopy:
 ; we expect that the Backend is able to deduce a type of the %_arg_Local
-; given facts that it's possible to deduce a type of the %_arg_In
-; and %_arg_Local and %_arg_In are source/destination arguments in OpGroupAsyncCopy
+; given facts that it's possible to deduce a type of the %_arg
+; and %_arg_Local and %_arg are source/destination arguments in OpGroupAsyncCopy
 
 ; CHECK: OpFunction
 ; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
 ; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
 ; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
 ; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
-; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1]] %[[#SrcBar]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
+; CHECK-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
 ; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
 ; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
@@ -63,13 +66,13 @@ declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU
 
 %Vec4 = type { <4 x i8> }
 
-define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg_In) {
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg) {
 entry:
-  %E.i = alloca %StructEvent
-  %add.ptr.i42 = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg_In, i64 0
-  %call3.i.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %add.ptr.i42, i64 16, i64 10, target("spirv.Event") zeroinitializer)
-  store target("spirv.Event") %call3.i.i, ptr %E.i
-  %E.ascast.i = addrspacecast ptr %E.i to ptr addrspace(4)
+  %E1 = alloca %StructEvent
+  %srcptr = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg, i64 0
+  %r1 = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %srcptr, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  store target("spirv.Event") %r1, ptr %E1
+  %E.ascast.i = addrspacecast ptr %E1 to ptr addrspace(4)
   call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
   ret void
 }

>From a1f8ccb6e7ccd984ddc31039f69927647ed874bc Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 2 Jul 2024 06:42:47 -0700
Subject: [PATCH 6/7] account for '(anonymous namespace)::' and possibility

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 5e5fa26a0f5f7..6cd8a749912cf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -184,10 +184,16 @@ lookupBuiltin(StringRef DemangledCall,
               SPIRV::InstructionSet::InstructionSet Set,
               Register ReturnRegister, const SPIRVType *ReturnType,
               const SmallVectorImpl<Register> &Arguments) {
+  const static std::string PassPrefix = "(anonymous namespace)::";
+  std::string BuiltinName;
+  // Itanium Demangler result may have "(anonymous namespace)::" prefix
+  if (DemangledCall.starts_with(PassPrefix.c_str()))
+    BuiltinName = DemangledCall.substr(PassPrefix.length());
+  else
+    BuiltinName = DemangledCall;
   // Extract the builtin function name and types of arguments from the call
   // skeleton.
-  std::string BuiltinName =
-      DemangledCall.substr(0, DemangledCall.find('(')).str();
+  BuiltinName = BuiltinName.substr(0, BuiltinName.find('('));
 
   // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
   if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)

>From 6c4dcc466cbd8eca14468632e53291e8d22ea2a7 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 3 Jul 2024 02:16:48 -0700
Subject: [PATCH 7/7] fix typo

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 6 +++---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h   | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6cd8a749912cf..dfec10bec3f9e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2384,10 +2384,10 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
 }
 
 namespace SPIRV {
-// Try to find a builtin funtion attributes by a demangled function name and
+// Try to find a builtin function attributes by a demangled function name and
 // return a tuple <builtin group, op code, ext instruction number>, or a special
-// tuple value <-1, 0, 0> if the builtin funtion is not found.
-// Not all builtin funtions are supported, only those with a ready-to-use op
+// tuple value <-1, 0, 0> if the builtin function is not found.
+// Not all builtin functions are supported, only those with a ready-to-use op
 // code or instruction number defined in TableGen.
 // TODO: consider a major rework of mapping demangled calls into a builtin
 // functions to unify search and decrease number of individual cases.
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 482b2ed853f7d..68bff602d1d10 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -19,7 +19,7 @@
 
 namespace llvm {
 namespace SPIRV {
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
+/// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
 /// \return the lowering success status if the called function is a recognized
@@ -38,7 +38,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
-/// Helper external function for finding a builtin funtion attributes
+/// Helper function for finding a builtin function attributes
 /// by a demangled function name. Defined in SPIRVBuiltins.cpp.
 std::tuple<int, unsigned, unsigned>
 mapBuiltinToOpcode(const StringRef DemangledCall,



More information about the llvm-commits mailing list