[llvm] [SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (PR #96895)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 2 04:57:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR improves type inference for the builtin of a known instruction, OpGroupAsyncCopy. It now:
* deduces the type of one source/destination pointer when it is possible to deduce the type of the other one, and
* validates the source and destination types, and tries to unfold a parameter if it is a structure wrapper around a scalar/vector type.

---
Full diff: https://github.com/llvm/llvm-project/pull/96895.diff


5 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+73-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+27-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+37) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll (+51-3) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b93a4d85eedf..5e5fa26a0f5f7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2377,9 +2377,80 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
 namespace SPIRV {
+// Try to find a builtin funtion attributes by a demangled function name and
+// return a tuple <builtin group, op code, ext instruction number>, or a special
+// tuple value <-1, 0, 0> if the builtin funtion is not found.
+// Not all builtin funtions are supported, only those with a ready-to-use op
+// code or instruction number defined in TableGen.
+// TODO: consider a major rework of mapping demangled calls into a builtin
+// functions to unify search and decrease number of individual cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set) {
+  Register Reg;
+  SmallVector<Register> Args;
+  std::unique_ptr<const IncomingCall> Call =
+      lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+  if (!Call)
+    return std::make_tuple(-1, 0, 0);
+
+  switch (Call->Builtin->Group) {
+  case SPIRV::Relational:
+  case SPIRV::Atomic:
+  case SPIRV::Barrier:
+  case SPIRV::CastToPtr:
+  case SPIRV::ImageMiscQuery:
+  case SPIRV::SpecConstant:
+  case SPIRV::Enqueue:
+  case SPIRV::AsyncCopy:
+  case SPIRV::LoadStore:
+  case SPIRV::CoopMatr:
+    if (const auto *R =
+            SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::Extended:
+    if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
+                                                     Call->Builtin->Set))
+      return std::make_tuple(Call->Builtin->Group, 0, R->Number);
+    break;
+  case SPIRV::VectorLoadStore:
+    if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
+                                                            Call->Builtin->Set))
+      return std::make_tuple(SPIRV::Extended, 0, R->Number);
+    break;
+  case SPIRV::Group:
+    if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::AtomicFloating:
+    if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::IntelSubgroups:
+    if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::GroupUniform:
+    if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+      return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+    break;
+  case SPIRV::WriteImage:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
+  case SPIRV::Select:
+    return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
+  case SPIRV::Construct:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
+                           0);
+  case SPIRV::KernelClock:
+    return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
+  default:
+    return std::make_tuple(-1, 0, 0);
+  }
+  return std::make_tuple(-1, 0, 0);
+}
+
 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRV::InstructionSet::InstructionSet Set,
                                  MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 649f5bfd1d7c2..482b2ed853f7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Helper external function for finding a builtin funtion attributes
+/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+                   SPIRV::InstructionSet::InstructionSet Set);
+
 /// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
 /// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
 /// element type (e.g. i8 for char*), or builtin type (TargetExtType).
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index dd5884096b85d..04e29b24d8fd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
+  SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end())
+      if (AsArgIt != ResTypeByArg.end()) {
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited);
+      }
     }
   }
 
@@ -544,6 +546,28 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
       KnownElemTy = ElemTy1;
       Ops.push_back(std::make_pair(Op0, 0));
     }
+  } else if (auto *CI = dyn_cast<CallInst>(I)) {
+    if (Function *CalledF = CI->getCalledFunction()) {
+      std::string DemangledName =
+          getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      if (DemangledName.length() > 0 &&
+          !StringRef(DemangledName).starts_with("llvm.")) {
+        auto [Grp, Opcode, ExtNo] =
+            SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+        if (Opcode == SPIRV::OpGroupAsyncCopy) {
+          for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+               ++i) {
+            Value *Op = CI->getArgOperand(i);
+            if (!isPointerTy(Op->getType()))
+              continue;
+            ++PtrCnt;
+            if (Type *ElemTy = GR->findDeducedElementType(Op))
+              KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+            Ops.push_back(std::make_pair(Op, i));
+          }
+        }
+      }
+    }
   }
 
   // There is no enough info to deduce types or all is valid.
@@ -1385,6 +1409,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
 
   const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
   GR = ST.getSPIRVGlobalRegistry();
+  InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+                              : SPIRV::InstructionSet::GLSL_std_450;
 
   F = &Func;
   IRBuilder<> B(Func.getContext());
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 4383d1c5c0e25..2344ec529e16d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -203,6 +203,39 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
 
+static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
+                                      MachineRegisterInfo *MRI,
+                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                      unsigned OpIdx) {
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
+      ElemType->getNumOperands() != 2)
+    return;
+  // It's a structure-wrapper around another type with a single member field.
+  SPIRVType *MemberType =
+      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
+  if (!MemberType)
+    return;
+  unsigned MemberTypeOp = MemberType->getOpcode();
+  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
+      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
+    return;
+  // It's a structure-wrapper around a valid type. Insert a bitcast before the
+  // instruction to keep SPIR-V code valid.
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
 // Insert a bitcast before the function call instruction to keep SPIR-V code
 // valid when there is a type mismatch between actual and expected types of an
 // argument:
@@ -380,6 +413,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                       SPIRV::OpTypeBool))
           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
         break;
+      case SPIRV::OpGroupAsyncCopy:
+        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
+        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+        break;
       case SPIRV::OpGroupWaitEvents:
         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
         validateGroupWaitEventsPtr(STI, MRI, GR, MI);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index fe0d96f2773ec..df11565ca8180 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -8,7 +8,18 @@
 ; CHECK-DAG: %[[#TyStruct:]] = OpTypeStruct %[[#TyEvent]]
 ; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
+; CHECK-DAG: %[[#TyEventPtrGen:]] = OpTypePointer Generic %[[#TyEvent]]
 ; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
+; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrV4_W:]] = OpTypePointer Workgroup %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyV4]]
+
+; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
+
 ; CHECK: OpFunction
 ; CHECK: OpFunctionParameter
 ; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -17,12 +28,13 @@
 ; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
 ; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: OpFunctionEnd
 
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
 
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
-  %var = alloca %"class.sycl::_V1::device_event"
+  %var = alloca %StructEvent
   %dev_event.i.sroa.0 = alloca target("spirv.Event")
   %add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
   %call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +43,39 @@ entry:
 }
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+; Check correct type inference when calling __spirv_GroupAsyncCopy:
+; we expect that the Backend is able to deduce a type of the %_arg_Local
+; given facts that it's possible to deduce a type of the %_arg
+; and %_arg_Local and %_arg are source/destination arguments in OpGroupAsyncCopy
+
+; CHECK: OpFunction
+; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
+; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
+; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
+; CHECK-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
+; CHECK-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
+; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBarCasted2]]
+; CHECK: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
+; CHECK: OpFunctionEnd
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg) {
+entry:
+  %E1 = alloca %StructEvent
+  %srcptr = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg, i64 0
+  %r1 = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %srcptr, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  store target("spirv.Event") %r1, ptr %E1
+  %E.ascast.i = addrspacecast ptr %E1 to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))

``````````

</details>


https://github.com/llvm/llvm-project/pull/96895


More information about the llvm-commits mailing list