[llvm] ce73e17 - [SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 01:34:08 PDT 2024


Author: Vyacheslav Levytskyy
Date: 2024-06-03T10:34:05+02:00
New Revision: ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3

URL: https://github.com/llvm/llvm-project/commit/ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3
DIFF: https://github.com/llvm/llvm-project/commit/ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.diff

LOG: [SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)

This PR fixes invalid OpGroupWaitEvents emission: the SPIR-V backend now
inserts a bitcast before OpGroupWaitEvents when the last argument is a
pointer that doesn't point to OpTypeEvent.

Added: 
    llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2bd22bbd63169..5ccbaf12ddee2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   return std::make_pair(0u, RC);
 }
 
+static Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
+  // For an OpFunctionParameter def, query its operand 1 (type); else OpReg.
+  SPIRVType *Def = MRI->getVRegDef(OpReg);
+  bool IsParam = Def && Def->getOpcode() == SPIRV::OpFunctionParameter;
+  return IsParam ? Def->getOperand(1).getReg() : OpReg;
+}
+
+static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+                            SPIRVGlobalRegistry &GR, MachineInstr &I,
+                            Register OpReg, unsigned OpIdx,
+                            SPIRVType *NewPtrType) {
+  // Build "NewReg = OpBitcast NewPtrType, OpReg" in front of I.
+  MachineIRBuilder MIB(I);
+  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  auto Cast = MIB.buildInstr(SPIRV::OpBitcast)
+                  .addDef(NewReg)
+                  .addUse(GR.getSPIRVTypeID(NewPtrType))
+                  .addUse(OpReg);
+  if (!Cast.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                             *STI.getRegBankInfo()))
+    report_fatal_error("insert validation bitcast: cannot constrain all uses");
+  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+  I.getOperand(OpIdx).setReg(NewReg); // I now consumes the bitcast result.
+}
+
+static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                   SPIRVType *OpType, bool ReuseType,
+                                   bool EmitIR, SPIRVType *ResType,
+                                   const Type *ResTy) {
+  // Pointee: ResType when ReuseType is set, else the SPIR-V type for ResTy.
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewBaseType = ResType;
+  if (!ReuseType)
+    NewBaseType = GR.getOrCreateSPIRVType(
+        ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+  auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
+      OpType->getOperand(1).getImm()); // keep OpType's storage class
+  return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+}
+
 // Insert a bitcast before the instruction to keep SPIR-V code valid
 // when there is a type mismatch between results and operand types.
 static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   // Get operand type
   MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
-  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  Register OpTypeReg =
-      TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
-          ? TypeInst->getOperand(1).getReg()
-          : OpReg;
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
     return;
   // There is a type mismatch between results and operand types
   // and we insert a bitcast before the instruction to keep SPIR-V code valid
-  SPIRV::StorageClass::StorageClass SC =
-      static_cast<SPIRV::StorageClass::StorageClass>(
-          OpType->getOperand(1).getImm());
-  MachineIRBuilder MIB(I);
-  SPIRVType *NewBaseType =
-      IsSameMF ? ResType
-               : GR.getOrCreateSPIRVType(
-                     ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
-  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+  SPIRVType *NewPtrType =
+      createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
   if (!GR.isBitcastCompatible(NewPtrType, OpType))
     report_fatal_error(
         "insert validation bitcast: incompatible result and operand types");
-  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
-  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
-                 .addDef(NewReg)
-                 .addUse(GR.getSPIRVTypeID(NewPtrType))
-                 .addUse(OpReg)
-                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
-                                   *STI.getRegBankInfo());
-  if (!Res)
-    report_fatal_error("insert validation bitcast: cannot constrain all uses");
-  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
-  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
-  I.getOperand(OpIdx).setReg(NewReg);
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
+// Insert a bitcast before OpGroupWaitEvents if its last argument (the events
+// list, operand 2) is a pointer that doesn't point to OpTypeEvent.
+static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
+                                       MachineRegisterInfo *MRI,
+                                       SPIRVGlobalRegistry &GR,
+                                       MachineInstr &I) {
+  constexpr unsigned OpIdx = 2;
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
+  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
+    return;
+  // Bitcast to a pointer to spirv.Event with the original storage class.
+  SPIRVType *NewPtrType = createNewPtrType(
+      GR, I, OpType, /*ReuseType=*/false, /*EmitIR=*/true, nullptr,
+      TargetExtType::get(MF->getFunction().getContext(), "spirv.Event"));
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
 
 // Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                       SPIRV::OpTypeBool))
           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
         break;
+      case SPIRV::OpGroupWaitEvents:
+        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
+        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
+        break;
       case SPIRV::OpConstantI: {
         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&

diff  --git a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
new file mode 100644
index 0000000000000..d6fb70bb59a7e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#EventTy:]] = OpTypeEvent
+; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
+; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
+; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
+; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy]]
+; CHECK: OpFunction
+; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
+; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
+; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
+; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
+
+%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+
+define weak_odr dso_local spir_kernel void @foo() {
+entry:
+  %var = alloca %"class.sycl::_V1::device_event"
+  %eventptr = addrspacecast ptr %var to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %eventptr)
+  ret void
+}
+
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))


        


More information about the llvm-commits mailing list