[llvm] [SPIR-V] Emit Alignment decoration for alloca instructions (PR #118520)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 4 06:53:29 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/118520

>From 066b4ba899225babc34a7c6ab69332c734a8ab02 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 3 Dec 2024 09:36:07 -0800
Subject: [PATCH 1/2] emit Alignment decoration for alloca's

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  4 +--
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  9 +++--
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 35 ++++++++++++-------
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 17b70062e58fa9..1ae3129774e507 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -36,8 +36,8 @@ let TargetPrefix = "spv" in {
   def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
-  def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
-  def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty]>;
+  def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
+  def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty, llvm_i8_ty], [ImmArg<ArgIndex<1>>]>;
   def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
   def int_spv_inline_asm : Intrinsic<[], [llvm_metadata_ty, llvm_metadata_ty, llvm_vararg_ty]>;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index f45bdfc7aacb72..7e8c669e676fb5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1713,9 +1713,12 @@ Instruction *SPIRVEmitIntrinsics::visitAllocaInst(AllocaInst &I) {
   TrackConstants = false;
   Type *PtrTy = I.getType();
   auto *NewI =
-      ArraySize ? B.CreateIntrinsic(Intrinsic::spv_alloca_array,
-                                    {PtrTy, ArraySize->getType()}, {ArraySize})
-                : B.CreateIntrinsic(Intrinsic::spv_alloca, {PtrTy}, {});
+      ArraySize
+          ? B.CreateIntrinsic(Intrinsic::spv_alloca_array,
+                              {PtrTy, ArraySize->getType()},
+                              {ArraySize, B.getInt8(I.getAlign().value())})
+          : B.CreateIntrinsic(Intrinsic::spv_alloca, {PtrTy},
+                              {B.getInt8(I.getAlign().value())});
   replaceAllUsesWithAndErase(B, &I, NewI);
   return NewI;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3547ac66430a87..3a98b74b3d6757 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3298,12 +3298,17 @@ bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
   // there was an allocation size parameter to the allocation instruction
   // that is not 1
   MachineBasicBlock &BB = *I.getParent();
-  return BuildMI(BB, I, I.getDebugLoc(),
-                 TII.get(SPIRV::OpVariableLengthArrayINTEL))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(I.getOperand(2).getReg())
-      .constrainAllUses(TII, TRI, RBI);
+  bool Res = BuildMI(BB, I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpVariableLengthArrayINTEL))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addUse(I.getOperand(2).getReg())
+                 .constrainAllUses(TII, TRI, RBI);
+  if (!STI.isVulkanEnv()) {
+    unsigned Alignment = I.getOperand(3).getImm();
+    buildOpDecorate(ResVReg, I, TII, SPIRV::Decoration::Alignment, {Alignment});
+  }
+  return Res;
 }
 
 bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
@@ -3312,12 +3317,18 @@ bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
   // Change order of instructions if needed: all OpVariable instructions in a
   // function must be the first instructions in the first block
   auto It = getOpVariableMBBIt(I);
-  return BuildMI(*It->getParent(), It, It->getDebugLoc(),
-                 TII.get(SPIRV::OpVariable))
-      .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(ResType))
-      .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
-      .constrainAllUses(TII, TRI, RBI);
+  bool Res = BuildMI(*It->getParent(), It, It->getDebugLoc(),
+                     TII.get(SPIRV::OpVariable))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
+                 .constrainAllUses(TII, TRI, RBI);
+  if (!STI.isVulkanEnv()) {
+    unsigned Alignment = I.getOperand(2).getImm();
+    buildOpDecorate(ResVReg, *It, TII, SPIRV::Decoration::Alignment,
+                    {Alignment});
+  }
+  return Res;
 }
 
 bool SPIRVInstructionSelector::selectBranch(MachineInstr &I) const {

>From d3ed87cc5b58b5b117d41b2baeeeef448c6e9c26 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Dec 2024 06:53:12 -0800
Subject: [PATCH 2/2] type inference: use types parsed from demangled function
 declarations

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  34 ++--
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h         |   3 +
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 151 ++++++++++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   6 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   3 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  10 +-
 .../SPIRV/transcoding/spirv-event-null.ll     |   8 +
 7 files changed, 164 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 45a49674d4ca21..0f5cc14fb8e775 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2630,16 +2630,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   return false;
 }
 
-Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
-                                       unsigned ArgIdx, LLVMContext &Ctx) {
-  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
-  StringRef BuiltinArgs =
-      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
-  BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
-  if (ArgIdx >= BuiltinArgsTypeStrs.size())
-    return nullptr;
-  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
-
+Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx) {
   // Parse strings representing OpenCL builtin types.
   if (hasBuiltinTypePrefix(TypeStr)) {
     // OpenCL builtin types in demangled call strings have the following format:
@@ -2683,6 +2674,29 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
   return BaseType;
 }
 
+bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
+                         const StringRef DemangledCall, LLVMContext &Ctx) {
+  auto Pos1 = DemangledCall.find('(');
+  if (Pos1 == StringRef::npos)
+    return false;
+  auto Pos2 = DemangledCall.find(')');
+  if (Pos2 == StringRef::npos || Pos1 > Pos2)
+    return false;
+  DemangledCall.slice(Pos1 + 1, Pos2)
+      .split(BuiltinArgsTypeStrs, ',', -1, false);
+  return true;
+}
+
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx) {
+  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
+  parseBuiltinTypeStr(BuiltinArgsTypeStrs, DemangledCall, Ctx);
+  if (ArgIdx >= BuiltinArgsTypeStrs.size())
+    return nullptr;
+  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
+  return parseBuiltinCallArgumentType(TypeStr, Ctx);
+}
+
 struct BuiltinType {
   StringRef Name;
   uint32_t Opcode;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index d07fc7c6ca874a..42b452db8b9fb4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -56,6 +56,9 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
 /// \p ArgIdx is the index of the argument to parse.
 Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
                                        unsigned ArgIdx, LLVMContext &Ctx);
+bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
+                         const StringRef DemangledCall, LLVMContext &Ctx);
+Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx);
 
 /// Translates a string representing a SPIR-V or OpenCL builtin type to a
 /// TargetExtType that can be further lowered with lowerBuiltinType().
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7e8c669e676fb5..e7e4eb3bec0521 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -76,6 +76,10 @@ class SPIRVEmitIntrinsics
   DenseSet<Instruction *> AggrStores;
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
+  // map of function declarations to <pointer arg index => element type>
+  DenseMap<const Function *, SmallVector<std::pair<unsigned, Type *>>>
+      FDeclPtrTys;
+
   // a register of Instructions that don't have a complete type definition
   bool CanTodoType = true;
   unsigned TodoTypeSz = 0;
@@ -184,6 +188,10 @@ class SPIRVEmitIntrinsics
   void deduceOperandElementTypeFunctionPointer(
       CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
       Type *&KnownElemTy, bool IsPostprocessing);
+  bool deduceOperandElementTypeFunctionRet(
+      Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
+      const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
+      Type *&KnownElemTy, Value *Op, Function *F);
 
   CallInst *buildSpvPtrcast(Function *F, Value *Op, Type *ElemTy);
   void replaceUsesOfWithSpvPtrcast(Value *Op, Type *ElemTy, Instruction *I,
@@ -205,6 +213,7 @@ class SPIRVEmitIntrinsics
   bool runOnFunction(Function &F);
   bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
+  void parseFunDeclarations(Module &M);
 
 public:
   static char ID;
@@ -957,6 +966,47 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
 
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
+    Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
+    const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
+    Type *&KnownElemTy, Value *Op, Function *F) {
+  KnownElemTy = GR->findDeducedElementType(F);
+  if (KnownElemTy)
+    return false;
+  if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+    GR->addDeducedElementType(F, OpElemTy);
+    GR->addReturnType(
+        F, TypedPointerType::get(OpElemTy,
+                                 getPointerAddressSpace(F->getReturnType())));
+    // non-recursive update of types in function uses
+    DenseSet<std::pair<Value *, Value *>> VisitedSubst{std::make_pair(I, Op)};
+    for (User *U : F->users()) {
+      CallInst *CI = dyn_cast<CallInst>(U);
+      if (!CI || CI->getCalledFunction() != F)
+        continue;
+      if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
+        if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
+          updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+          propagateElemType(CI, PrevElemTy, VisitedSubst);
+        }
+      }
+    }
+    // Non-recursive update of types in the function uncomplete returns.
+    // This may happen just once per a function, the latch is a pair of
+    // findDeducedElementType(F) / addDeducedElementType(F, ...).
+    // With or without the latch it is a non-recursive call due to
+    // UncompleteRets set to nullptr in this call.
+    if (UncompleteRets)
+      for (Instruction *UncompleteRetI : *UncompleteRets)
+        deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
+                                 IsPostprocessing);
+  } else if (UncompleteRets) {
+    UncompleteRets->insert(I);
+  }
+  TypeValidated.insert(I);
+  return true;
+}
+
 // If the Instruction has Pointer operands with unresolved types, this function
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
@@ -1039,46 +1089,15 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
         Ops.push_back(std::make_pair(Op, i));
     }
   } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
-    Type *RetTy = CurrF->getReturnType();
-    if (!isPointerTy(RetTy))
+    if (!isPointerTy(CurrF->getReturnType()))
       return;
     Value *Op = Ref->getReturnValue();
     if (!Op)
       return;
-    if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
-      if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
-        GR->addDeducedElementType(CurrF, OpElemTy);
-        GR->addReturnType(CurrF, TypedPointerType::get(
-                                     OpElemTy, getPointerAddressSpace(RetTy)));
-        // non-recursive update of types in function uses
-        DenseSet<std::pair<Value *, Value *>> VisitedSubst{
-            std::make_pair(I, Op)};
-        for (User *U : CurrF->users()) {
-          CallInst *CI = dyn_cast<CallInst>(U);
-          if (!CI || CI->getCalledFunction() != CurrF)
-            continue;
-          if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
-            if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-              updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
-              propagateElemType(CI, PrevElemTy, VisitedSubst);
-            }
-          }
-        }
-        // Non-recursive update of types in the function uncomplete returns.
-        // This may happen just once per a function, the latch is a pair of
-        // findDeducedElementType(F) / addDeducedElementType(F, ...).
-        // With or without the latch it is a non-recursive call due to
-        // UncompleteRets set to nullptr in this call.
-        if (UncompleteRets)
-          for (Instruction *UncompleteRetI : *UncompleteRets)
-            deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
-                                     IsPostprocessing);
-      } else if (UncompleteRets) {
-        UncompleteRets->insert(I);
-      }
-      TypeValidated.insert(I);
+    if (deduceOperandElementTypeFunctionRet(I, UncompleteRets, AskOps,
+                                            IsPostprocessing, KnownElemTy, Op,
+                                            CurrF))
       return;
-    }
     Uncomplete = isTodoType(CurrF);
     Ops.push_back(std::make_pair(Op, 0));
   } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
@@ -2157,6 +2176,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  DenseMap<Function *, DenseMap<unsigned, Type *>> FDeclPtrTys;
+
   processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,
@@ -2180,6 +2201,31 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   for (auto &I : instructions(Func))
     Worklist.push_back(&I);
 
+  // Apply types parsed from demangled function declarations.
+  for (auto &I : Worklist) {
+    CallInst *CI = dyn_cast<CallInst>(I);
+    if (!CI || !CI->getCalledFunction())
+      continue;
+    auto It = FDeclPtrTys.find(CI->getCalledFunction());
+    if (It == FDeclPtrTys.end())
+      continue;
+    unsigned Sz = CI->arg_size();
+    for (auto [Idx, ElemTy] : It->second)
+      if (Idx < Sz) {
+        Value *Arg = CI->getArgOperand(Idx);
+        GR->addDeducedElementType(Arg, ElemTy);
+        if (CallInst *Ref = dyn_cast<CallInst>(Arg))
+          if (Function *RefF = Ref->getCalledFunction();
+              RefF && isPointerTy(RefF->getReturnType()) &&
+              !GR->findDeducedElementType(RefF)) {
+            GR->addDeducedElementType(RefF, ElemTy);
+            GR->addReturnType(RefF, TypedPointerType::get(
+                                        ElemTy, getPointerAddressSpace(
+                                                    RefF->getReturnType())));
+          }
+      }
+  }
+
   // Pass forward: use operand to deduce instructions result.
   for (auto &I : Worklist) {
     // Don't emit intrinsincs for convergence intrinsics.
@@ -2287,9 +2333,44 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
   return SzTodo > TodoTypeSz;
 }
 
+// Parse and store argument types of function declarations where needed.
+void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
+  for (auto &F : M) {
+    if (!F.isDeclaration() || F.isIntrinsic())
+      continue;
+    // get the demangled name
+    std::string DemangledName = getOclOrSpirvBuiltinDemangledName(F.getName());
+    if (DemangledName.empty())
+      continue;
+    // find pointer arguments
+    SmallVector<unsigned> Idxs;
+    for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx)
+      if (isPointerTy(F.getArg(OpIdx)->getType()))
+        Idxs.push_back(OpIdx);
+    if (!Idxs.size())
+      continue;
+    // parse function arguments
+    LLVMContext &Ctx = F.getContext();
+    SmallVector<StringRef, 10> TypeStrs;
+    SPIRV::parseBuiltinTypeStr(TypeStrs, DemangledName, Ctx);
+    if (!TypeStrs.size())
+      continue;
+    // find type info for pointer arguments
+    for (unsigned Idx : Idxs) {
+      if (Idx >= TypeStrs.size())
+        continue;
+      if (Type *ElemTy =
+              SPIRV::parseBuiltinCallArgumentType(TypeStrs[Idx].trim(), Ctx))
+        FDeclPtrTys[&F].push_back(std::make_pair(Idx, ElemTy));
+    }
+  }
+}
+
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   bool Changed = false;
 
+  parseFunDeclarations(M);
+
   TodoType.clear();
   for (auto &F : M)
     Changed |= runOnFunction(F);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9ac659f6b4f111..91b9cbcf15128c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -325,8 +325,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
 
 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                                MachineIRBuilder &MIRBuilder,
-                                               SPIRVType *SpvType,
-                                               bool EmitIR) {
+                                               SPIRVType *SpvType, bool EmitIR,
+                                               bool ZeroAsNull) {
   assert(SpvType);
   auto &MF = MIRBuilder.getMF();
   const IntegerType *LLVMIntTy =
@@ -348,7 +348,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
     } else {
       Register SpvTypeReg = getSPIRVTypeID(SpvType);
       MachineInstrBuilder MIB;
-      if (Val) {
+      if (Val || !ZeroAsNull) {
         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                   .addDef(Res)
                   .addUse(SpvTypeReg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ff4b0ea8757fa4..df92325ed19802 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -509,7 +509,8 @@ class SPIRVGlobalRegistry {
 
 public:
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                            SPIRVType *SpvType, bool EmitIR = true);
+                            SPIRVType *SpvType, bool EmitIR = true,
+                            bool ZeroAsNull = true);
   Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 7a1914aac8ceb8..1b9d3088756462 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -447,25 +447,31 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
   TypeName.consume_front("atomic_");
   if (TypeName.consume_front("void"))
     return Type::getVoidTy(Ctx);
-  else if (TypeName.consume_front("bool"))
+  else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
     return Type::getIntNTy(Ctx, 1);
   else if (TypeName.consume_front("char") ||
+           TypeName.consume_front("signed char") ||
            TypeName.consume_front("unsigned char") ||
            TypeName.consume_front("uchar"))
     return Type::getInt8Ty(Ctx);
   else if (TypeName.consume_front("short") ||
+           TypeName.consume_front("signed short") ||
            TypeName.consume_front("unsigned short") ||
            TypeName.consume_front("ushort"))
     return Type::getInt16Ty(Ctx);
   else if (TypeName.consume_front("int") ||
+           TypeName.consume_front("signed int") ||
            TypeName.consume_front("unsigned int") ||
            TypeName.consume_front("uint"))
     return Type::getInt32Ty(Ctx);
   else if (TypeName.consume_front("long") ||
+           TypeName.consume_front("signed long") ||
            TypeName.consume_front("unsigned long") ||
            TypeName.consume_front("ulong"))
     return Type::getInt64Ty(Ctx);
-  else if (TypeName.consume_front("half"))
+  else if (TypeName.consume_front("half") ||
+           TypeName.consume_front("_Float16") ||
+           TypeName.consume_front("__fp16"))
     return Type::getHalfTy(Ctx);
   else if (TypeName.consume_front("float"))
     return Type::getFloatTy(Ctx);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index e512f909cfd059..91738634ff233b 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -32,6 +32,14 @@
 
 %StructEvent = type { target("spirv.Event") }
 
+define spir_kernel void @test_half(ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2) {
+entry:
+  %r = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 noundef, ptr addrspace(3) noundef, ptr addrspace(1) noundef, i64 noundef, i64 noundef, target("spirv.Event"))
+
 define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
 entry:
   %var = alloca %StructEvent



More information about the llvm-commits mailing list