[llvm] [SPIR-V] Add validation to the test case with get_image_array_size/get_image_dim calls (PR #94467)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 5 09:15:24 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/94467

>From 9a75467d7e4cd185c7604007751192daea7ec0e5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 5 Jun 2024 06:22:01 -0700
Subject: [PATCH 1/3] improve type inference

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  20 ++-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 150 +++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  16 +-
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |   6 +-
 .../test/CodeGen/SPIRV/event-wait-ptr-type.ll |  16 +-
 ...Intrinsics-no-duplicate-spv_assign_type.ll |   4 +-
 6 files changed, 142 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 424087f361a6a..9b9b8f7cbc089 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -24,6 +24,13 @@
 #define DEBUG_TYPE "spirv-builtins"
 
 namespace llvm {
+
+//  Defined in SPIRVPreLegalizer.cpp.
+extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIB,
+                                  MachineRegisterInfo &MRI);
+
 namespace SPIRV {
 #define GET_BuiltinGroup_DECL
 #include "SPIRVGenTables.inc"
@@ -1451,11 +1458,22 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
         Component == 3 ? NumActualRetComponents - 1 : Component;
     assert(ExtractedComposite < NumActualRetComponents &&
            "Invalid composite index!");
+    Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+    SPIRVType *NewType = nullptr;
+    if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
+      Register NewTypeReg = QueryResultType->getOperand(1).getReg();
+      if (TypeReg != NewTypeReg &&
+          (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)) != nullptr)
+        TypeReg = NewTypeReg;
+    }
     MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
         .addDef(Call->ReturnRegister)
-        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(TypeReg)
         .addUse(QueryResult)
         .addImm(ExtractedComposite);
+    if (NewType != nullptr)
+      insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
+                        MIRBuilder.getMF().getRegInfo());
   } else {
     // More than 1 component is expected, fill a new vector.
     auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5ef0be1cab722..696706258ec40 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -50,6 +50,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
 namespace {
+
 class SPIRVEmitIntrinsics
     : public ModulePass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
@@ -61,9 +62,6 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
-  // a registry of created Intrinsic::spv_assign_ptr_type instructions
-  DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
-
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
   Type *deduceElementTypeHelper(Value *I);
@@ -98,14 +96,16 @@ class SPIRVEmitIntrinsics
     return B.CreateIntrinsic(IntrID, {Types}, Args);
   }
 
+  void buildAssignType(IRBuilder<> &B, Type *ElemTy, Value *Arg);
   void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+  void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);
 
   void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
   void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
   void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
   void insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B);
-  void insertAssignTypeInstrForTargetExtTypes(TargetExtType *AssignedType,
-                                              Value *V, IRBuilder<> &B);
+  void insertAssignPtrTypeTargetExt(TargetExtType *AssignedType, Value *V,
+                                    IRBuilder<> &B);
   void replacePointerOperandWithPtrCast(Instruction *I, Value *Pointer,
                                         Type *ExpectedElementType,
                                         unsigned OperandToReplace,
@@ -218,15 +218,39 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
                        false);
 }
 
+void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
+                                          Value *Arg) {
+  Value *OfType = PoisonValue::get(Ty);
+  CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
+                                       {Arg->getType()}, OfType, Arg, {}, B);
+  GR->addAssignPtrTypeInstr(Arg, AssignCI);
+}
+
 void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
                                          Value *Arg) {
-  CallInst *AssignPtrTyCI =
-      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-                      Constant::getNullValue(ElemTy), Arg,
-                      {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+  Value *OfType = PoisonValue::get(ElemTy);
+  CallInst *AssignPtrTyCI = buildIntrWithMD(
+      Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
+      {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
   GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
   GR->addDeducedElementType(Arg, ElemTy);
-  AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
+  GR->addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
+}
+
+void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
+                                           Value *OfType) {
+  LLVMContext &Ctx = Arg->getContext();
+  AssignCI->setArgOperand(
+      1, MetadataAsValue::get(
+             Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OfType))));
+  if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
+      Intrinsic::spv_assign_ptr_type)
+    return;
+
+  // update association with the pointee type
+  Type *ElemTy = OfType->getType();
+  GR->addDeducedElementType(AssignCI, ElemTy);
+  GR->addDeducedElementType(Arg, ElemTy);
 }
 
 // Set element pointer type to the given value of ValueTy and tries to
@@ -513,19 +537,16 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
     if (!Ty) {
       GR->addDeducedElementType(Op, KnownElemTy);
       // check if there is existing Intrinsic::spv_assign_ptr_type instruction
-      auto It = AssignPtrTypeInstr.find(Op);
-      if (It == AssignPtrTypeInstr.end()) {
+      CallInst *AssignCI = GR->findAssignPtrTypeInstr(Op);
+      if (AssignCI == nullptr) {
         Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get());
         setInsertPointSkippingPhis(B, User ? User->getNextNode() : I);
         CallInst *CI =
             buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
                             {B.getInt32(getPointerAddressSpace(OpTy))}, B);
-        AssignPtrTypeInstr[Op] = CI;
+        GR->addAssignPtrTypeInstr(Op, CI);
       } else {
-        It->second->setArgOperand(
-            1,
-            MetadataAsValue::get(
-                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+        updateAssignType(AssignCI, Op, OpTyVal);
       }
     } else {
       if (auto *OpI = dyn_cast<Instruction>(Op)) {
@@ -559,7 +580,9 @@ void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
     if (isAssignTypeInstr(U)) {
       B.SetInsertPoint(U);
       SmallVector<Value *, 2> Args = {New, U->getOperand(1)};
-      B.CreateIntrinsic(Intrinsic::spv_assign_type, {New->getType()}, Args);
+      CallInst *AssignCI =
+          B.CreateIntrinsic(Intrinsic::spv_assign_type, {New->getType()}, Args);
+      GR->addAssignPtrTypeInstr(New, AssignCI);
       U->eraseFromParent();
     } else if (isMemInstrToReplace(U) || isa<ReturnInst>(U) ||
                isa<CallInst>(U)) {
@@ -751,33 +774,39 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   return NewI;
 }
 
-void SPIRVEmitIntrinsics::insertAssignTypeInstrForTargetExtTypes(
+void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
     TargetExtType *AssignedType, Value *V, IRBuilder<> &B) {
-  // Do not emit spv_assign_type if the V is of the AssignedType already.
-  if (V->getType() == AssignedType)
-    return;
+  Type *VTy = V->getType();
 
-  // Do not emit spv_assign_type if there is one already targetting V. If the
-  // found spv_assign_type assigns a type different than AssignedType, report an
-  // error. Builtin types cannot be redeclared or casted.
-  for (auto User : V->users()) {
-    auto *II = dyn_cast<IntrinsicInst>(User);
-    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_type)
-      continue;
+  // A couple of sanity checks.
+  assert(isPointerTy(VTy) && "Expect a pointer type!");
+  if (auto PType = dyn_cast<TypedPointerType>(VTy))
+    if (PType->getElementType() != AssignedType)
+      report_fatal_error("Unexpected pointer element type!");
 
-    MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
-    Type *BuiltinType =
-        dyn_cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
-    if (BuiltinType != AssignedType)
-      report_fatal_error("Type mismatch " + BuiltinType->getTargetExtName() +
-                             "/" + AssignedType->getTargetExtName() +
-                             " for value " + V->getName(),
-                         false);
+  CallInst *AssignCI = GR->findAssignPtrTypeInstr(V);
+  if (!AssignCI) {
+    buildAssignType(B, AssignedType, V);
     return;
   }
 
-  Constant *Const = UndefValue::get(AssignedType);
-  buildIntrWithMD(Intrinsic::spv_assign_type, {V->getType()}, Const, V, {}, B);
+  Type *CurrentType =
+      dyn_cast<ConstantAsMetadata>(
+          cast<MetadataAsValue>(AssignCI->getOperand(1))->getMetadata())
+          ->getType();
+  if (CurrentType == AssignedType)
+    return;
+
+  // Builtin types cannot be redeclared or cast.
+  if (CurrentType->isTargetExtTy())
+    report_fatal_error("Type mismatch " + CurrentType->getTargetExtName() +
+                           "/" + AssignedType->getTargetExtName() +
+                           " for value " + V->getName(),
+                       false);
+
+  // Our previous guess about the type seems to be wrong; update the
+  // inferred type according to the new, more precise type information.
+  updateAssignType(AssignCI, V, PoisonValue::get(AssignedType));
 }
 
 void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -850,7 +879,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
     GR->addDeducedElementType(CI, ExpectedElementType);
     GR->addDeducedElementType(Pointer, ExpectedElementType);
-    AssignPtrTypeInstr[Pointer] = CI;
+    GR->addAssignPtrTypeInstr(Pointer, CI);
     return;
   }
 
@@ -929,8 +958,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
 
   for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
     Value *ArgOperand = CI->getArgOperand(OpIdx);
-    if (!isa<PointerType>(ArgOperand->getType()) &&
-        !isa<TypedPointerType>(ArgOperand->getType()))
+    if (!isPointerTy(ArgOperand->getType()))
       continue;
 
     // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
@@ -952,8 +980,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       continue;
 
     if (ExpectedType->isTargetExtTy())
-      insertAssignTypeInstrForTargetExtTypes(cast<TargetExtType>(ExpectedType),
-                                             ArgOperand, B);
+      insertAssignPtrTypeTargetExt(cast<TargetExtType>(ExpectedType),
+                                   ArgOperand, B);
     else
       replacePointerOperandWithPtrCast(CI, ArgOperand, ExpectedType, OpIdx, B);
   }
@@ -1145,7 +1173,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
   GR->addDeducedElementType(CI, ElemTy);
-  AssignPtrTypeInstr[I] = CI;
+  GR->addAssignPtrTypeInstr(I, CI);
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1164,20 +1192,32 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
         TypeToAssign = It->second;
       }
     }
-    Constant *Const = UndefValue::get(TypeToAssign);
-    buildIntrWithMD(Intrinsic::spv_assign_type, {Ty}, Const, I, {}, B);
+    buildAssignType(B, TypeToAssign, I);
   }
   for (const auto &Op : I->operands()) {
     if (isa<ConstantPointerNull>(Op) || isa<UndefValue>(Op) ||
         // Check GetElementPtrConstantExpr case.
         (isa<ConstantExpr>(Op) && isa<GEPOperator>(Op))) {
       setInsertPointSkippingPhis(B, I);
-      if (isa<UndefValue>(Op) && Op->getType()->isAggregateType())
-        buildIntrWithMD(Intrinsic::spv_assign_type, {B.getInt32Ty()}, Op,
-                        UndefValue::get(B.getInt32Ty()), {}, B);
-      else if (!isa<Instruction>(Op))
-        buildIntrWithMD(Intrinsic::spv_assign_type, {Op->getType()}, Op, Op, {},
-                        B);
+      Type *OpTy = Op->getType();
+      if (isa<UndefValue>(Op) && OpTy->isAggregateType()) {
+        CallInst *AssignCI =
+            buildIntrWithMD(Intrinsic::spv_assign_type, {B.getInt32Ty()}, Op,
+                            UndefValue::get(B.getInt32Ty()), {}, B);
+        GR->addAssignPtrTypeInstr(Op, AssignCI);
+      } else if (!isa<Instruction>(Op)) {
+        Type *OpTy = Op->getType();
+        if (auto PType = dyn_cast<TypedPointerType>(OpTy)) {
+          buildAssignPtr(B, PType->getElementType(), Op);
+        } else if (isPointerTy(OpTy)) {
+          Type *ElemTy = GR->findDeducedElementType(Op);
+          buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op), Op);
+        } else {
+          CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
+                                               {OpTy}, Op, Op, {}, B);
+          GR->addAssignPtrTypeInstr(Op, AssignCI);
+        }
+      }
     }
   }
 }
@@ -1368,14 +1408,12 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
 
     insertAssignPtrTypeIntrs(I, B);
+    deduceOperandElementType(I);
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
     insertSpirvDecorations(I, B);
   }
 
-  for (auto &I : instructions(Func))
-    deduceOperandElementType(&I);
-
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..0103fb8214341 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -72,8 +72,11 @@ class SPIRVGlobalRegistry {
   // untyped pointers.
   DenseMap<Value *, Type *> DeducedElTys;
   // Maps composite values to deduced types where untyped pointers are replaced
-  // with typed ones
+  // with typed ones.
   DenseMap<Value *, Type *> DeducedNestedTys;
+  // Maps values to "assign type" calls, thus being a registry of created
+  // Intrinsic::spv_assign_type and spv_assign_ptr_type instructions.
+  DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
@@ -148,6 +151,17 @@ class SPIRVGlobalRegistry {
     return It == FunResPointerTypes.end() ? nullptr : It->second;
   }
 
+  // A registry of "assign type" records:
+  // - Add a record.
+  void addAssignPtrTypeInstr(Value *Val, CallInst *AssignPtrTyCI) {
+    AssignPtrTypeInstr[Val] = AssignPtrTyCI;
+  }
+  // - Find a record.
+  CallInst *findAssignPtrTypeInstr(const Value *Val) {
+    auto It = AssignPtrTypeInstr.find(Val);
+    return It == AssignPtrTypeInstr.end() ? nullptr : It->second;
+  }
+
   // Deduced element types of untyped pointers and composites:
   // - Add a record to the map of deduced element types.
   void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 3d536085b78aa..a0a253c23b1e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -417,7 +417,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         // G_GLOBAL_VALUE already has type info.
-        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
+        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
+            Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
                             MF.getRegInfo());
         ToErase.push_back(&MI);
@@ -427,7 +428,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         // G_GLOBAL_VALUE already has type info.
-        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
+        if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE &&
+            Def->getOpcode() != SPIRV::ASSIGN_TYPE)
           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
         ToErase.push_back(&MI);
       } else if (MIOp == TargetOpcode::G_CONSTANT ||
diff --git a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
index d6fb70bb59a7e..ec9afc789944d 100644
--- a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
+++ b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
@@ -4,16 +4,16 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: %[[#EventTy:]] = OpTypeEvent
-; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
-; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
-; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
-; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy:]]
+; CHECK-DAG: %[[#EventTy:]] = OpTypeEvent
+; CHECK-DAG: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
+; CHECK-DAG: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
+; CHECK-DAG: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy]]
+; CHECK-DAG: %[[#FunPtrEventTy:]] = OpTypePointer Function %[[#EventTy]]
 ; CHECK: OpFunction
 ; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
-; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
-; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
-; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
+; CHECK-NEXT: %[[#FunEvent:]] = OpBitcast %[[#FunPtrEventTy]] %[[#Var]]
+; CHECK-NEXT: %[[#GenEvent:]] = OpPtrCastToGeneric %[[#GenPtrEventTy]] %[[#FunEvent]]
+; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#GenEvent]]
 
 %"class.sycl::_V1::device_event" = type { target("spirv.Event") }
 
diff --git a/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll b/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll
index 7056b9cb1230d..9db4f26a27d4f 100644
--- a/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll
+++ b/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll
@@ -3,9 +3,9 @@
 ; CHECK: *** IR Dump After SPIRV emit intrinsics (emit-intrinsics) ***
 
 define spir_kernel void @test(ptr addrspace(1) %srcimg) {
-; CHECK: call void @llvm.spv.assign.type.p1(ptr addrspace(1) %srcimg, metadata target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 0) undef)
+; CHECK: call void @llvm.spv.assign.type.p1(ptr addrspace(1) %srcimg, metadata target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 0) poison)
   %call1 = call spir_func <2 x i32> @_Z13get_image_dim14ocl_image2d_ro(ptr addrspace(1) %srcimg)
-; CHECK-NOT: call void @llvm.spv.assign.type.p1(ptr addrspace(1) %srcimg, metadata target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 0) undef)
+; CHECK-NOT: call void @llvm.spv.assign.type.p1(ptr addrspace(1) %srcimg, metadata target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 0) poison)
   %call2 = call spir_func <2 x i32> @_Z13get_image_dim14ocl_image2d_ro(ptr addrspace(1) %srcimg)
   ret void
 ; CHECK: }

>From 4f40eb8c3eec26cb6905471858ecb33facb36522 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 5 Jun 2024 06:30:17 -0700
Subject: [PATCH 2/3] add validation

---
 llvm/test/CodeGen/SPIRV/transcoding/check_ro_qualifier.ll | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/SPIRV/transcoding/check_ro_qualifier.ll b/llvm/test/CodeGen/SPIRV/transcoding/check_ro_qualifier.ll
index 824ca1b2d6924..6f61aba23a46f 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/check_ro_qualifier.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/check_ro_qualifier.ll
@@ -1,5 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
-; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV: %[[#IMAGE_TYPE:]] = OpTypeImage
 ; CHECK-SPIRV: %[[#IMAGE_ARG:]] = OpFunctionParameter %[[#IMAGE_TYPE]]

>From 31b490628ce2ab7fc5d6653f7ec717a93ebf615d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 5 Jun 2024 09:11:53 -0700
Subject: [PATCH 3/3] add a test case

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 40 +++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        |  1 +
 .../transcoding/OpGroupAsyncCopy-strided.ll   | 36 +++++++++++++++++
 3 files changed, 61 insertions(+), 16 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/OpGroupAsyncCopy-strided.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 9b9b8f7cbc089..93209e1728008 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -25,12 +25,6 @@
 
 namespace llvm {
 
-//  Defined in SPIRVPreLegalizer.cpp.
-extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
-                                  SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIB,
-                                  MachineRegisterInfo &MRI);
-
 namespace SPIRV {
 #define GET_BuiltinGroup_DECL
 #include "SPIRVGenTables.inc"
@@ -2073,16 +2067,30 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
   auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
 
   switch (Opcode) {
-  case SPIRV::OpGroupAsyncCopy:
-    return MIRBuilder.buildInstr(Opcode)
-        .addDef(Call->ReturnRegister)
-        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
-        .addUse(Scope)
-        .addUse(Call->Arguments[0])
-        .addUse(Call->Arguments[1])
-        .addUse(Call->Arguments[2])
-        .addUse(buildConstantIntReg(1, MIRBuilder, GR))
-        .addUse(Call->Arguments[3]);
+  case SPIRV::OpGroupAsyncCopy: {
+    SPIRVType *NewType =
+        Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
+            ? nullptr
+            : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder);
+    Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
+    unsigned NumArgs = Call->Arguments.size();
+    Register EventReg = Call->Arguments[NumArgs - 1];
+    bool Res = MIRBuilder.buildInstr(Opcode)
+                   .addDef(Call->ReturnRegister)
+                   .addUse(TypeReg)
+                   .addUse(Scope)
+                   .addUse(Call->Arguments[0])
+                   .addUse(Call->Arguments[1])
+                   .addUse(Call->Arguments[2])
+                   .addUse(Call->Arguments.size() > 4
+                               ? Call->Arguments[3]
+                               : buildConstantIntReg(1, MIRBuilder, GR))
+                   .addUse(EventReg);
+    if (NewType != nullptr)
+      insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
+                        MIRBuilder.getMF().getRegInfo());
+    return Res;
+  }
   case SPIRV::OpGroupWaitEvents:
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Scope)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 692234c405ab6..da547cbab4e98 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -586,6 +586,7 @@ defm : DemangledNativeBuiltin<"__spirv_SpecConstantComposite", OpenCL_std, SpecC
 
 // Async Copy and Prefetch builtin records:
 defm : DemangledNativeBuiltin<"async_work_group_copy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>;
+defm : DemangledNativeBuiltin<"async_work_group_strided_copy", OpenCL_std, AsyncCopy, 5, 5, OpGroupAsyncCopy>;
 defm : DemangledNativeBuiltin<"__spirv_GroupAsyncCopy", OpenCL_std, AsyncCopy, 6, 6, OpGroupAsyncCopy>;
 defm : DemangledNativeBuiltin<"wait_group_events", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>;
 defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, 3, 3, OpGroupWaitEvents>;
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAsyncCopy-strided.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAsyncCopy-strided.ll
new file mode 100644
index 0000000000000..96d6016083f06
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAsyncCopy-strided.ll
@@ -0,0 +1,36 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#LongTy:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#IntTy:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#EventTy:]] = OpTypeEvent
+; CHECK-SPIRV-DAG: %[[#Scope:]] = OpConstant %[[#IntTy]] 2
+; CHECK-SPIRV-DAG: %[[#Num:]] = OpConstant %[[#LongTy]] 123
+; CHECK-SPIRV-DAG: %[[#Null:]] = OpConstantNull
+; CHECK-SPIRV-DAG: %[[#Stride:]] = OpConstant %[[#LongTy]] 1
+; CHECK-SPIRV-DAG: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy]]
+; CHECK-SPIRV-DAG: %[[#FunPtrEventTy:]] = OpTypePointer Function %[[#EventTy]]
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Var:]] = OpVariable %[[#]] Function
+; CHECK-SPIRV: %[[#ResEvent:]] = OpGroupAsyncCopy %[[#EventTy]] %[[#Scope]] %[[#Null]] %[[#Null]] %[[#Num]] %[[#Stride]] %[[#Null]]
+; CHECK-SPIRV: %[[#VarPtrEvent:]] = OpBitcast %[[#FunPtrEventTy]] %[[#Var]]
+; CHECK-SPIRV: OpStore %[[#VarPtrEvent]] %[[#ResEvent]]
+; CHECK-SPIRV: %[[#VarPtrEvent2:]] = OpBitcast %[[#FunPtrEventTy]] %[[#Var]]
+; CHECK-SPIRV: %[[#PtrEventGen:]] = OpPtrCastToGeneric %[[#]] %[[#VarPtrEvent2]]
+; CHECK-SPIRV: OpGroupWaitEvents %[[#Scope]] %[[#Num]] %[[#PtrEventGen]]
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo() {
+  %event = alloca ptr, align 8
+  %call = call spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr null, ptr null, i64 123, i64 1, ptr null)
+  store ptr %call, ptr %event, align 8
+  %event.ascast = addrspacecast ptr %event to ptr addrspace(4)
+  call spir_func void @_Z17wait_group_eventsiPU3AS49ocl_event(i64 123, ptr addrspace(4) %event.ascast)
+  ret void
+}
+
+declare spir_func ptr @_Z29async_work_group_strided_copyPU3AS3hPU3AS1Khmm9ocl_event(ptr, ptr, i64, i64, ptr)
+declare spir_func void @_Z17wait_group_eventsiPU3AS49ocl_event(i64, ptr addrspace(4))



More information about the llvm-commits mailing list