[llvm] [SPIR-V] Emit Alignment decoration for alloca instructions and improve type inference (PR #118520)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 4 14:14:57 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/118520

>From 066b4ba899225babc34a7c6ab69332c734a8ab02 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 3 Dec 2024 09:36:07 -0800
Subject: [PATCH 1/5] emit Alignment decoration for alloca's

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  4 +--
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  9 +++--
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 35 ++++++++++++-------
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 17b70062e58fa9..1ae3129774e507 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -36,8 +36,8 @@ let TargetPrefix = "spv" in {
   def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
-  def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
-  def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty]>;
+  def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i32_ty], [ImmArg<ArgIndex<0>>]>;
+  def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
   def int_spv_inline_asm : Intrinsic<[], [llvm_metadata_ty, llvm_metadata_ty, llvm_vararg_ty]>;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index f45bdfc7aacb72..7e8c669e676fb5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1713,9 +1713,12 @@ Instruction *SPIRVEmitIntrinsics::visitAllocaInst(AllocaInst &I) {
   TrackConstants = false;
   Type *PtrTy = I.getType();
   auto *NewI =
-      ArraySize ? B.CreateIntrinsic(Intrinsic::spv_alloca_array,
-                                    {PtrTy, ArraySize->getType()}, {ArraySize})
-                : B.CreateIntrinsic(Intrinsic::spv_alloca, {PtrTy}, {});
+      ArraySize
+          ? B.CreateIntrinsic(Intrinsic::spv_alloca_array,
+                              {PtrTy, ArraySize->getType()},
+                              {ArraySize, B.getInt32(I.getAlign().value())})
+          : B.CreateIntrinsic(Intrinsic::spv_alloca, {PtrTy},
+                              {B.getInt32(I.getAlign().value())});
   replaceAllUsesWithAndErase(B, &I, NewI);
   return NewI;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3547ac66430a87..3a98b74b3d6757 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3298,12 +3298,17 @@ bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
   // there was an allocation size parameter to the allocation instruction
   // that is not 1
   MachineBasicBlock &BB = *I.getParent();
-  return BuildMI(BB, I, I.getDebugLoc(),
-                 TII.get(SPIRV::OpVariableLengthArrayINTEL))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(I.getOperand(2).getReg())
-      .constrainAllUses(TII, TRI, RBI);
+  bool Res = BuildMI(BB, I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpVariableLengthArrayINTEL))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addUse(I.getOperand(2).getReg())
+                 .constrainAllUses(TII, TRI, RBI);
+  if (!STI.isVulkanEnv()) {
+    unsigned Alignment = I.getOperand(3).getImm();
+    buildOpDecorate(ResVReg, I, TII, SPIRV::Decoration::Alignment, {Alignment});
+  }
+  return Res;
 }
 
 bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
@@ -3312,12 +3317,18 @@ bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
   // Change order of instructions if needed: all OpVariable instructions in a
   // function must be the first instructions in the first block
   auto It = getOpVariableMBBIt(I);
-  return BuildMI(*It->getParent(), It, It->getDebugLoc(),
-                 TII.get(SPIRV::OpVariable))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType))
-      .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
-      .constrainAllUses(TII, TRI, RBI);
+  bool Res = BuildMI(*It->getParent(), It, It->getDebugLoc(),
+                     TII.get(SPIRV::OpVariable))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
+                 .constrainAllUses(TII, TRI, RBI);
+  if (!STI.isVulkanEnv()) {
+    unsigned Alignment = I.getOperand(2).getImm();
+    buildOpDecorate(ResVReg, *It, TII, SPIRV::Decoration::Alignment,
+                    {Alignment});
+  }
+  return Res;
 }
 
 bool SPIRVInstructionSelector::selectBranch(MachineInstr &I) const {

>From d3ed87cc5b58b5b117d41b2baeeeef448c6e9c26 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Dec 2024 06:53:12 -0800
Subject: [PATCH 2/5] type inference: use types parsed from demangled function
 declarations

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  34 ++--
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h         |   3 +
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 151 ++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   6 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   3 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  10 +-
 .../SPIRV/transcoding/spirv-event-null.ll     |   8 +
 7 files changed, 164 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 45a49674d4ca21..0f5cc14fb8e775 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2630,16 +2630,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   return false;
 }
 
-Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
-                                       unsigned ArgIdx, LLVMContext &Ctx) {
-  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
-  StringRef BuiltinArgs =
-      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
-  BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
-  if (ArgIdx >= BuiltinArgsTypeStrs.size())
-    return nullptr;
-  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
-
+Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx) {
   // Parse strings representing OpenCL builtin types.
   if (hasBuiltinTypePrefix(TypeStr)) {
     // OpenCL builtin types in demangled call strings have the following format:
@@ -2683,6 +2674,29 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
   return BaseType;
 }
 
+// Collects the comma-separated pieces found between the first '(' and the
+// first ')' of DemangledCall; returns false if that pair is absent.
+bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
+                         const StringRef DemangledCall, LLVMContext &Ctx) {
+  const size_t Open = DemangledCall.find('(');
+  const size_t Close = DemangledCall.find(')');
+  if (Open == StringRef::npos || Close == StringRef::npos || Open > Close)
+    return false;
+  StringRef ArgList = DemangledCall.slice(Open + 1, Close);
+  ArgList.split(BuiltinArgsTypeStrs, ',', -1, false);
+  return true;
+}
+
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx) {
+  // A failed split leaves the list empty, so the bounds check covers it.
+  SmallVector<StringRef, 10> ArgTypeStrs;
+  parseBuiltinTypeStr(ArgTypeStrs, DemangledCall, Ctx);
+  if (ArgIdx < ArgTypeStrs.size())
+    return parseBuiltinCallArgumentType(ArgTypeStrs[ArgIdx].trim(), Ctx);
+  return nullptr;
+}
+
 struct BuiltinType {
   StringRef Name;
   uint32_t Opcode;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index d07fc7c6ca874a..42b452db8b9fb4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -56,6 +56,9 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
 /// \p ArgIdx is the index of the argument to parse.
 Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
                                        unsigned ArgIdx, LLVMContext &Ctx);
+bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
+                         const StringRef DemangledCall, LLVMContext &Ctx);
+Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx);
 
 /// Translates a string representing a SPIR-V or OpenCL builtin type to a
 /// TargetExtType that can be further lowered with lowerBuiltinType().
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7e8c669e676fb5..e7e4eb3bec0521 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -76,6 +76,10 @@ class SPIRVEmitIntrinsics
   DenseSet<Instruction *> AggrStores;
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
+  // map of function declarations to <pointer arg index => element type>
+  DenseMap<const Function *, SmallVector<std::pair<unsigned, Type *>>>
+      FDeclPtrTys;
+
   // a register of Instructions that don't have a complete type definition
   bool CanTodoType = true;
   unsigned TodoTypeSz = 0;
@@ -184,6 +188,10 @@ class SPIRVEmitIntrinsics
   void deduceOperandElementTypeFunctionPointer(
       CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
       Type *&KnownElemTy, bool IsPostprocessing);
+  bool deduceOperandElementTypeFunctionRet(
+      Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
+      const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
+      Type *&KnownElemTy, Value *Op, Function *F);
 
   CallInst *buildSpvPtrcast(Function *F, Value *Op, Type *ElemTy);
   void replaceUsesOfWithSpvPtrcast(Value *Op, Type *ElemTy, Instruction *I,
@@ -205,6 +213,7 @@ class SPIRVEmitIntrinsics
   bool runOnFunction(Function &F);
   bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
+  void parseFunDeclarations(Module &M);
 
 public:
   static char ID;
@@ -957,6 +966,47 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
 
+// Tries to deduce the element type of the pointer returned by F from Op.
+// Returns false (KnownElemTy set) if F is already typed, true otherwise.
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
+    Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
+    const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
+    Type *&KnownElemTy, Value *Op, Function *F) {
+  KnownElemTy = GR->findDeducedElementType(F);
+  if (KnownElemTy)
+    return false; // F is already typed; KnownElemTy holds that type
+  if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+    GR->addDeducedElementType(F, OpElemTy);
+    GR->addReturnType(
+        F, TypedPointerType::get(OpElemTy,
+                                 getPointerAddressSpace(F->getReturnType())));
+    // Non-recursively retype the results of existing direct calls to F.
+    DenseSet<std::pair<Value *, Value *>> VisitedSubst{std::make_pair(I, Op)};
+    for (User *U : F->users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != F)
+        continue;
+      if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
+        if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
+          updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+          propagateElemType(CI, PrevElemTy, VisitedSubst);
+        }
+      }
+    }
+    // Revisit the returns recorded as incomplete. The latch formed above by
+    // findDeducedElementType(F) / addDeducedElementType(F, ...) allows this
+    // only once per function; nullptr UncompleteRets keeps it non-recursive.
+    if (UncompleteRets)
+      for (Instruction *UncompleteRetI : *UncompleteRets)
+        deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
+                                 IsPostprocessing);
+  } else if (UncompleteRets) {
+    UncompleteRets->insert(I); // Op has no deduced type yet; revisit I later
+  }
+  TypeValidated.insert(I);
+  return true; // handled here: the caller returns right after
+}
+
 // If the Instruction has Pointer operands with unresolved types, this function
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
@@ -1039,46 +1089,15 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
         Ops.push_back(std::make_pair(Op, i));
     }
   } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
-    Type *RetTy = CurrF->getReturnType();
-    if (!isPointerTy(RetTy))
+    if (!isPointerTy(CurrF->getReturnType()))
       return;
     Value *Op = Ref->getReturnValue();
     if (!Op)
       return;
-    if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
-      if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
-        GR->addDeducedElementType(CurrF, OpElemTy);
-        GR->addReturnType(CurrF, TypedPointerType::get(
-                                     OpElemTy, getPointerAddressSpace(RetTy)));
-        // non-recursive update of types in function uses
-        DenseSet<std::pair<Value *, Value *>> VisitedSubst{
-            std::make_pair(I, Op)};
-        for (User *U : CurrF->users()) {
-          CallInst *CI = dyn_cast<CallInst>(U);
-          if (!CI || CI->getCalledFunction() != CurrF)
-            continue;
-          if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
-            if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-              updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
-              propagateElemType(CI, PrevElemTy, VisitedSubst);
-            }
-          }
-        }
-        // Non-recursive update of types in the function uncomplete returns.
-        // This may happen just once per a function, the latch is a pair of
-        // findDeducedElementType(F) / addDeducedElementType(F, ...).
-        // With or without the latch it is a non-recursive call due to
-        // UncompleteRets set to nullptr in this call.
-        if (UncompleteRets)
-          for (Instruction *UncompleteRetI : *UncompleteRets)
-            deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
-                                     IsPostprocessing);
-      } else if (UncompleteRets) {
-        UncompleteRets->insert(I);
-      }
-      TypeValidated.insert(I);
+    if (deduceOperandElementTypeFunctionRet(I, UncompleteRets, AskOps,
+                                            IsPostprocessing, KnownElemTy, Op,
+                                            CurrF))
       return;
-    }
     Uncomplete = isTodoType(CurrF);
     Ops.push_back(std::make_pair(Op, 0));
   } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
@@ -2157,6 +2176,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  DenseMap<Function *, DenseMap<unsigned, Type *>> FDeclPtrTys;
+
   processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,
@@ -2180,6 +2201,31 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   for (auto &I : instructions(Func))
     Worklist.push_back(&I);
 
+  // Apply types parsed from demangled function declarations.
+  for (auto &I : Worklist) {
+    CallInst *CI = dyn_cast<CallInst>(I);
+    if (!CI || !CI->getCalledFunction())
+      continue;
+    auto It = FDeclPtrTys.find(CI->getCalledFunction());
+    if (It == FDeclPtrTys.end())
+      continue;
+    unsigned Sz = CI->arg_size();
+    for (auto [Idx, ElemTy] : It->second)
+      if (Idx < Sz) {
+        Value *Arg = CI->getArgOperand(Idx);
+        GR->addDeducedElementType(Arg, ElemTy);
+        if (CallInst *Ref = dyn_cast<CallInst>(Arg))
+          if (Function *RefF = Ref->getCalledFunction();
+              RefF && isPointerTy(RefF->getReturnType()) &&
+              !GR->findDeducedElementType(RefF)) {
+            GR->addDeducedElementType(RefF, ElemTy);
+            GR->addReturnType(RefF, TypedPointerType::get(
+                                        ElemTy, getPointerAddressSpace(
+                                                    RefF->getReturnType())));
+          }
+      }
+  }
+
   // Pass forward: use operand to deduce instructions result.
   for (auto &I : Worklist) {
     // Don't emit intrinsincs for convergence intrinsics.
@@ -2287,9 +2333,44 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
   return SzTodo > TodoTypeSz;
 }
 
+// Parse and store argument types of function declarations where needed.
+void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
+  for (auto &F : M) {
+    if (!F.isDeclaration() || F.isIntrinsic())
+      continue;
+    // get the demangled name
+    std::string DemangledName = getOclOrSpirvBuiltinDemangledName(F.getName());
+    if (DemangledName.empty())
+      continue;
+    // find pointer arguments
+    SmallVector<unsigned> Idxs;
+    for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx)
+      if (isPointerTy(F.getArg(OpIdx)->getType()))
+        Idxs.push_back(OpIdx);
+    if (!Idxs.size())
+      continue;
+    // parse function arguments
+    LLVMContext &Ctx = F.getContext();
+    SmallVector<StringRef, 10> TypeStrs;
+    SPIRV::parseBuiltinTypeStr(TypeStrs, DemangledName, Ctx);
+    if (!TypeStrs.size())
+      continue;
+    // find type info for pointer arguments
+    for (unsigned Idx : Idxs) {
+      if (Idx >= TypeStrs.size())
+        continue;
+      if (Type *ElemTy =
+              SPIRV::parseBuiltinCallArgumentType(TypeStrs[Idx].trim(), Ctx))
+        FDeclPtrTys[&F].push_back(std::make_pair(Idx, ElemTy));
+    }
+  }
+}
+
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   bool Changed = false;
 
+  parseFunDeclarations(M);
+
   TodoType.clear();
   for (auto &F : M)
     Changed |= runOnFunction(F);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9ac659f6b4f111..91b9cbcf15128c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -325,8 +325,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
 
 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                                MachineIRBuilder &MIRBuilder,
-                                               SPIRVType *SpvType,
-                                               bool EmitIR) {
+                                               SPIRVType *SpvType, bool EmitIR,
+                                               bool ZeroAsNull) {
   assert(SpvType);
   auto &MF = MIRBuilder.getMF();
   const IntegerType *LLVMIntTy =
@@ -348,7 +348,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
     } else {
       Register SpvTypeReg = getSPIRVTypeID(SpvType);
       MachineInstrBuilder MIB;
-      if (Val) {
+      if (Val || !ZeroAsNull) {
         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                   .addDef(Res)
                   .addUse(SpvTypeReg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ff4b0ea8757fa4..df92325ed19802 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -509,7 +509,8 @@ class SPIRVGlobalRegistry {
 
 public:
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                            SPIRVType *SpvType, bool EmitIR = true);
+                            SPIRVType *SpvType, bool EmitIR = true,
+                            bool ZeroAsNull = true);
   Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 7a1914aac8ceb8..1b9d3088756462 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -447,25 +447,31 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
   TypeName.consume_front("atomic_");
   if (TypeName.consume_front("void"))
     return Type::getVoidTy(Ctx);
-  else if (TypeName.consume_front("bool"))
+  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
     return Type::getIntNTy(Ctx, 1);
   else if (TypeName.consume_front("char") ||
+           TypeName.consume_front("signed char") ||
            TypeName.consume_front("unsigned char") ||
            TypeName.consume_front("uchar"))
     return Type::getInt8Ty(Ctx);
   else if (TypeName.consume_front("short") ||
+           TypeName.consume_front("signed short") ||
            TypeName.consume_front("unsigned short") ||
            TypeName.consume_front("ushort"))
     return Type::getInt16Ty(Ctx);
   else if (TypeName.consume_front("int") ||
+           TypeName.consume_front("signed int") ||
            TypeName.consume_front("unsigned int") ||
            TypeName.consume_front("uint"))
     return Type::getInt32Ty(Ctx);
   else if (TypeName.consume_front("long") ||
+           TypeName.consume_front("signed long") ||
            TypeName.consume_front("unsigned long") ||
            TypeName.consume_front("ulong"))
     return Type::getInt64Ty(Ctx);
-  else if (TypeName.consume_front("half"))
+  else if (TypeName.consume_front("half") ||
+           TypeName.consume_front("_Float16") ||
+           TypeName.consume_front("__fp16"))
     return Type::getHalfTy(Ctx);
   else if (TypeName.consume_front("float"))
     return Type::getFloatTy(Ctx);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index e512f909cfd059..91738634ff233b 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -32,6 +32,14 @@
 
 %StructEvent = type { target("spirv.Event") }
 
+define spir_kernel void @test_half(ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2) {
+entry:
+  %r = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 noundef, ptr addrspace(3) noundef, ptr addrspace(1) noundef, i64 noundef, i64 noundef, target("spirv.Event"))
+
 define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
   %var = alloca %StructEvent

>From 4cc9bffcbdd697efad4d9420ab823b394abdb4f0 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Dec 2024 09:05:38 -0800
Subject: [PATCH 3/5] fix applying of deduced types

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 46 +++++++++----------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e7e4eb3bec0521..b6526b891977cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -77,8 +77,7 @@ class SPIRVEmitIntrinsics
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // map of function declarations to <pointer arg index => element type>
-  DenseMap<const Function *, SmallVector<std::pair<unsigned, Type *>>>
-      FDeclPtrTys;
+  DenseMap<Function *, SmallVector<std::pair<unsigned, Type *>>> FDeclPtrTys;
 
   // a register of Instructions that don't have a complete type definition
   bool CanTodoType = true;
@@ -2176,8 +2175,6 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
-  DenseMap<Function *, DenseMap<unsigned, Type *>> FDeclPtrTys;
-
   processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,
@@ -2202,28 +2199,31 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     Worklist.push_back(&I);
 
   // Apply types parsed from demangled function declarations.
-  for (auto &I : Worklist) {
-    CallInst *CI = dyn_cast<CallInst>(I);
-    if (!CI || !CI->getCalledFunction())
-      continue;
-    auto It = FDeclPtrTys.find(CI->getCalledFunction());
-    if (It == FDeclPtrTys.end())
-      continue;
-    unsigned Sz = CI->arg_size();
-    for (auto [Idx, ElemTy] : It->second)
-      if (Idx < Sz) {
+  for (auto It : FDeclPtrTys) {
+    Function *F = It.first;
+    for (auto *U : F->users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != F)
+        continue;
+      unsigned Sz = CI->arg_size();
+      for (auto [Idx, ElemTy] : It.second) {
+        if (Idx >= Sz)
+          continue;
         Value *Arg = CI->getArgOperand(Idx);
         GR->addDeducedElementType(Arg, ElemTy);
-        if (CallInst *Ref = dyn_cast<CallInst>(Arg))
-          if (Function *RefF = Ref->getCalledFunction();
-              RefF && isPointerTy(RefF->getReturnType()) &&
-              !GR->findDeducedElementType(RefF)) {
-            GR->addDeducedElementType(RefF, ElemTy);
-            GR->addReturnType(RefF, TypedPointerType::get(
-                                        ElemTy, getPointerAddressSpace(
-                                                    RefF->getReturnType())));
-          }
+        CallInst *Ref = dyn_cast<CallInst>(Arg);
+        if (!Ref)
+          continue;
+        Function *RefF = Ref->getCalledFunction();
+        if (!RefF || !isPointerTy(RefF->getReturnType()) ||
+            GR->findDeducedElementType(RefF))
+          continue;
+        GR->addDeducedElementType(RefF, ElemTy);
+        GR->addReturnType(
+            RefF, TypedPointerType::get(
+                      ElemTy, getPointerAddressSpace(RefF->getReturnType())));
       }
+    }
   }
 
   // Pass forward: use operand to deduce instructions result.

>From 7087945dbad3cdf169c3c00e0cca907480ad044c Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Dec 2024 13:49:47 -0800
Subject: [PATCH 4/5] fixes and test cases

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 87 ++++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  2 +-
 llvm/test/CodeGen/SPIRV/opencl/vload2.ll      | 29 ++++---
 .../SPIRV/transcoding/spirv-event-null.ll     | 47 +++++-----
 5 files changed, 97 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e8e853c5c758a6..fa37313f8247c4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -316,7 +316,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 
       if (Arg.hasName())
         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
-      if (isPointerTy(Arg.getType())) {
+      if (isPointerTyOrWrapper(Arg.getType())) {
         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
         if (DerefBytes != 0)
           buildOpDecorate(VRegs[i][0], MIRBuilder,
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index b6526b891977cd..8608eaa79e5077 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -209,6 +209,8 @@ class SPIRVEmitIntrinsics
   void replaceAllUsesWithAndErase(IRBuilder<> &B, Instruction *Src,
                                   Instruction *Dest, bool DeleteOld = true);
 
+  void applyDemangledPtrArgTypes(IRBuilder<> &B);
+
   bool runOnFunction(Function &F);
   bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
@@ -2156,6 +2158,53 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
   return true;
 }
 
+// Apply the pointee types parsed from demangled builtin declarations (see
+// FDeclPtrTys) to the pointer operands of every direct call to such a builtin.
+void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
+  // Bind by reference so the per-declaration SmallVector is not copied.
+  for (const auto &It : FDeclPtrTys) {
+    Function *F = It.first;
+    for (auto *U : F->users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != F)
+        continue;
+      unsigned Sz = CI->arg_size();
+      for (auto [Idx, ElemTy] : It.second) {
+        if (Idx >= Sz)
+          continue;
+        Value *Param = CI->getArgOperand(Idx);
+        if (GR->findDeducedElementType(Param) || isa<GlobalValue>(Param))
+          continue;
+        if (Argument *Arg = dyn_cast<Argument>(Param)) {
+          if (!hasPointeeTypeAttr(Arg)) {
+            B.SetInsertPointPastAllocas(Arg->getParent());
+            B.SetCurrentDebugLocation(DebugLoc());
+            buildAssignPtr(B, ElemTy, Arg);
+          }
+        } else if (isa<Instruction>(Param)) {
+          GR->addDeducedElementType(Param, ElemTy);
+          // insertAssignTypeIntrs() will complete buildAssignPtr()
+        } else {
+          BasicBlock &Entry = CI->getParent()->getParent()->getEntryBlock();
+          B.SetInsertPoint(Entry.getFirstNonPHIOrDbgOrAlloca());
+          buildAssignPtr(B, ElemTy, Param);
+        }
+        CallInst *Ref = dyn_cast<CallInst>(Param);
+        if (!Ref)
+          continue;
+        Function *RefF = Ref->getCalledFunction();
+        if (!RefF || !isPointerTy(RefF->getReturnType()) ||
+            GR->findDeducedElementType(RefF))
+          continue;
+        GR->addDeducedElementType(RefF, ElemTy);
+        GR->addReturnType(
+            RefF, TypedPointerType::get(
+                      ElemTy, getPointerAddressSpace(RefF->getReturnType())));
+      }
+    }
+  }
+}
+
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
@@ -2198,33 +2247,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   for (auto &I : instructions(Func))
     Worklist.push_back(&I);
 
-  // Apply types parsed from demangled function declarations.
-  for (auto It : FDeclPtrTys) {
-    Function *F = It.first;
-    for (auto *U : F->users()) {
-      CallInst *CI = dyn_cast<CallInst>(U);
-      if (!CI || CI->getCalledFunction() != F)
-        continue;
-      unsigned Sz = CI->arg_size();
-      for (auto [Idx, ElemTy] : It.second) {
-        if (Idx >= Sz)
-          continue;
-        Value *Arg = CI->getArgOperand(Idx);
-        GR->addDeducedElementType(Arg, ElemTy);
-        CallInst *Ref = dyn_cast<CallInst>(Arg);
-        if (!Ref)
-          continue;
-        Function *RefF = Ref->getCalledFunction();
-        if (!RefF || !isPointerTy(RefF->getReturnType()) ||
-            GR->findDeducedElementType(RefF))
-          continue;
-        GR->addDeducedElementType(RefF, ElemTy);
-        GR->addReturnType(
-            RefF, TypedPointerType::get(
-                      ElemTy, getPointerAddressSpace(RefF->getReturnType())));
-      }
-    }
-  }
+  applyDemangledPtrArgTypes(B);
 
   // Pass forward: use operand to deduce instructions result.
   for (auto &I : Worklist) {
@@ -2344,9 +2367,11 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
       continue;
     // find pointer arguments
     SmallVector<unsigned> Idxs;
-    for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx)
-      if (isPointerTy(F.getArg(OpIdx)->getType()))
+    for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx) {
+      Argument *Arg = F.getArg(OpIdx);
+      if (isPointerTy(Arg->getType()) && !hasPointeeTypeAttr(Arg))
         Idxs.push_back(OpIdx);
+    }
     if (!Idxs.size())
       continue;
     // parse function arguments
@@ -2361,7 +2386,9 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
         continue;
       if (Type *ElemTy =
               SPIRV::parseBuiltinCallArgumentType(TypeStrs[Idx].trim(), Ctx))
-        FDeclPtrTys[&F].push_back(std::make_pair(Idx, ElemTy));
+        if (TypedPointerType::isValidElementType(ElemTy) &&
+            !ElemTy->isTargetExtTy())
+          FDeclPtrTys[&F].push_back(std::make_pair(Idx, ElemTy));
     }
   }
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 91b9cbcf15128c..fabedb5e06c1d5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1022,7 +1022,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-  if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
+  if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
     return nullptr;
   TypesInProcessing.insert(Ty);
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
diff --git a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
index 592de33d4d3938..1a1b6c484e74ff 100644
--- a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
+++ b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
@@ -1,9 +1,14 @@
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-; This test only intends to check the vloadn builtin name resolution.
-; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
 
+; CHECK-DAG: OpName %[[#CALL1:]] "call1"
+; CHECK-DAG: OpName %[[#CALL2:]] "call2"
+; CHECK-DAG: OpName %[[#CALL3:]] "call3"
+; CHECK-DAG: OpName %[[#CALL4:]] "call4"
+; CHECK-DAG: OpName %[[#CALL5:]] "call5"
+
 ; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
 ; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
@@ -23,20 +28,20 @@
 ; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
 
 define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
-; CHECK: %[[#CASTorPARAMofPTRI8:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT8]]{{.*}}
-; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI8]] 2
+; CHECK-DAG: %[[#CASTorPARAMofPTRI8:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT8]]{{.*}}
+; CHECK-DAG: %[[#CALL1]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI8]] 2
   %call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
-; CHECK: %[[#CASTorPARAMofPTRI16:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT16]]{{.*}}
-; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI16]] 2
+; CHECK-DAG: %[[#CASTorPARAMofPTRI16:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT16]]{{.*}}
+; CHECK-DAG: %[[#CALL2]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI16]] 2
   %call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
-; CHECK: %[[#CASTorPARAMofPTRI32:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT32]]{{.*}}
-; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI32]] 2
+; CHECK-DAG: %[[#CASTorPARAMofPTRI32:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT32]]{{.*}}
+; CHECK-DAG: %[[#CALL3]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI32]] 2
   %call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
-; CHECK: %[[#CASTorPARAMofPTRI64:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT64]]{{.*}}
-; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI64]] 2
+; CHECK-DAG: %[[#CASTorPARAMofPTRI64:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRINT64]]{{.*}}
+; CHECK-DAG: %[[#CALL4]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRI64]] 2
   %call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
-; CHECK: %[[#CASTorPARAMofPTRFLOAT:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRFLOAT]]{{.*}}
-; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRFLOAT]] 2
+; CHECK-DAG: %[[#CASTorPARAMofPTRFLOAT:]] = {{OpBitcast|OpFunctionParameter}}{{.*}}%[[#PTRFLOAT]]{{.*}}
+; CHECK-DAG: %[[#CALL5]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#CASTorPARAMofPTRFLOAT]] 2
   %call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
   ret void
 }
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index 91738634ff233b..b5330fad9016c1 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -13,21 +13,20 @@
 ; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
 ; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
-; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
 ; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
 ; CHECK-DAG: %[[#TyPtrV4_W:]] = OpTypePointer Workgroup %[[#TyV4]]
 ; CHECK-DAG: %[[#TyPtrV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyV4]]
+; CHECK-DAG: %[[#TyHalf:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#TyHalfV2:]] = OpTypeVector %[[#TyHalf]] 2
+; CHECK-DAG: %[[#TyHalfV2_W:]] = OpTypePointer Workgroup %[[#TyHalfV2]]
+; CHECK-DAG: %[[#TyHalfV2_CW:]] = OpTypePointer CrossWorkgroup %[[#TyHalfV2]]
 
 ; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
 
 ; CHECK: OpFunction
-; CHECK: OpFunctionParameter
-; CHECK: %[[#Src:]] = OpFunctionParameter
-; CHECK: OpVariable %[[#TyStructPtr]] Function
-; CHECK: %[[#EventVar:]] = OpVariable %[[#TyEventPtr]] Function
-; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
-; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
-; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: %[[#HalfA1:]] = OpFunctionParameter %[[#TyHalfV2_W]]
+; CHECK: %[[#HalfA2:]] = OpFunctionParameter %[[#TyHalfV2_CW]]
+; CHECK: OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#HalfA1]] %[[#HalfA2]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: OpFunctionEnd
 
 %StructEvent = type { target("spirv.Event") }
@@ -40,6 +39,16 @@ entry:
 
 declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 noundef, ptr addrspace(3) noundef, ptr addrspace(1) noundef, i64 noundef, i64 noundef, target("spirv.Event"))
 
+; CHECK: OpFunction
+; CHECK: OpFunctionParameter
+; CHECK: %[[#Src:]] = OpFunctionParameter
+; CHECK: OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#EventVar:]] = OpVariable %[[#TyEventPtr]] Function
+; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
+; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: OpFunctionEnd
+
 define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
   %var = alloca %StructEvent
@@ -58,33 +67,19 @@ declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU
 ; and %_arg_Local and %_arg are source/destination arguments in OpGroupAsyncCopy
 
 ; CHECK: OpFunction
-; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
+; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrV4_W]]
 ; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
 ; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
 ; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
-; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
-; CHECK-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
-; CHECK-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
-; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: %[[#BarArg2Casted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#BarArg2]]
+; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrV4_CW]] %[[#BarArg2Casted]] %[[#]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1]] %[[#SrcBar]] %[[#]] %[[#]] %[[#ConstEvent]]
 ; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
 ; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
 ; CHECK: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBarCasted2]]
 ; CHECK: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
 ; CHECK: OpFunctionEnd
 
-; CHECK2: OpFunction
-; CHECK2: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
-; CHECK2: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
-; CHECK2: %[[#EventVarBar:]] = OpVariable %[[#TyEventPtr]] Function
-; CHECK2: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
-; CHECK2-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
-; CHECK2-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
-; CHECK2: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
-; CHECK2: OpStore %[[#EventVarBar]] %[[#ResBar]]
-; CHECK2: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBar]]
-; CHECK2: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
-; CHECK2: OpFunctionEnd
-
 %Vec4 = type { <4 x i8> }
 
 define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg) {

>From 5c14dc3c9424b59589e62a40a6b791f55026d178 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Dec 2024 14:14:41 -0800
Subject: [PATCH 5/5] restrict changes to the single use case of
 OpGroupAsyncCopy

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 8608eaa79e5077..2b623136e602e5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -2365,6 +2365,11 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
     std::string DemangledName = getOclOrSpirvBuiltinDemangledName(F.getName());
     if (DemangledName.empty())
       continue;
+    // allow only OpGroupAsyncCopy use case at the moment
+    auto [Grp, Opcode, ExtNo] =
+        SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+    if (Opcode != SPIRV::OpGroupAsyncCopy)
+      continue;
     // find pointer arguments
     SmallVector<unsigned> Idxs;
     for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx) {



More information about the llvm-commits mailing list