[llvm] ce73e17 - [SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 3 01:34:08 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-06-03T10:34:05+02:00
New Revision: ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3
URL: https://github.com/llvm/llvm-project/commit/ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3
DIFF: https://github.com/llvm/llvm-project/commit/ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.diff
LOG: [SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)
This PR fixes invalid OpGroupWaitEvents emission by ensuring that the SPIR-V
backend inserts a bitcast before OpGroupWaitEvents if the last argument
is a pointer that doesn't point to OpTypeEvent.
Added:
llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2bd22bbd63169..5ccbaf12ddee2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0u, RC);
}
+// Return the register to query for OpReg's SPIR-V type: when OpReg is defined
+// by an OpFunctionParameter, use operand 1 of that instruction (presumably its
+// result-type id -- confirm against SPIRVGlobalRegistry); otherwise use OpReg.
+// Declared `static` (like the other helpers here) rather than `inline`: a
+// non-static inline definition has external linkage and could clash (ODR)
+// with a same-named helper in another translation unit.
+static Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
+  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+  return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+             ? TypeInst->getOperand(1).getReg()
+             : OpReg;
+}
+
+// Emit `CastReg = OpBitcast NewPtrType OpReg` immediately before I, register
+// NewPtrType as the SPIR-V type of CastReg, and rewrite operand OpIdx of I to
+// use CastReg. Aborts via report_fatal_error if the new instruction's operands
+// cannot be constrained to register classes.
+static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+                            SPIRVGlobalRegistry &GR, MachineInstr &I,
+                            Register OpReg, unsigned OpIdx,
+                            SPIRVType *NewPtrType) {
+  Register CastReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  MachineIRBuilder Builder(I); // insertion point: right before I
+  auto Cast = Builder.buildInstr(SPIRV::OpBitcast)
+                  .addDef(CastReg)
+                  .addUse(GR.getSPIRVTypeID(NewPtrType))
+                  .addUse(OpReg);
+  const bool Constrained =
+      Cast.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                            *STI.getRegBankInfo());
+  if (!Constrained)
+    report_fatal_error("insert validation bitcast: cannot constrain all uses");
+  MRI->setRegClass(CastReg, &SPIRV::IDRegClass);
+  GR.assignSPIRVTypeToVReg(NewPtrType, CastReg, Builder.getMF());
+  // Redirect the original use to the casted value.
+  I.getOperand(OpIdx).setReg(CastReg);
+}
+
+// Get or create a SPIR-V pointer type with the same storage class as OpType
+// (read from OpType's operand 1; callers only pass OpTypePointer types) whose
+// pointee is either ResType itself (ReuseType == true) or the SPIR-V type
+// obtained for the LLVM type ResTy with ReadWrite access qualifier and the
+// given EmitIR flag (ReuseType == false).
+static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                   SPIRVType *OpType, bool ReuseType,
+                                   bool EmitIR, SPIRVType *ResType,
+                                   const Type *ResTy) {
+  auto StorageCls = static_cast<SPIRV::StorageClass::StorageClass>(
+      OpType->getOperand(1).getImm());
+  MachineIRBuilder Builder(I);
+  SPIRVType *Pointee = ResType;
+  if (!ReuseType)
+    Pointee = GR.getOrCreateSPIRVType(
+        ResTy, Builder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+  return GR.getOrCreateSPIRVPointerType(Pointee, Builder, StorageCls);
+}
+
// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
// Get operand type
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
- SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- Register OpTypeReg =
- TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
- ? TypeInst->getOperand(1).getReg()
- : OpReg;
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
- SPIRV::StorageClass::StorageClass SC =
- static_cast<SPIRV::StorageClass::StorageClass>(
- OpType->getOperand(1).getImm());
- MachineIRBuilder MIB(I);
- SPIRVType *NewBaseType =
- IsSameMF ? ResType
- : GR.getOrCreateSPIRVType(
- ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+ SPIRVType *NewPtrType =
+ createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
- bool Res = MIB.buildInstr(SPIRV::OpBitcast)
- .addDef(NewReg)
- .addUse(GR.getSPIRVTypeID(NewPtrType))
- .addUse(OpReg)
- .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
- *STI.getRegBankInfo());
- if (!Res)
- report_fatal_error("insert validation bitcast: cannot constrain all uses");
- MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
- GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
- I.getOperand(OpIdx).setReg(NewReg);
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
+// Insert a bitcast before OpGroupWaitEvents if its last operand (the events
+// list) is a pointer whose pointee is not OpTypeEvent. The replacement pointer
+// type keeps the original storage class and points to the SPIR-V type created
+// for the LLVM target extension type "spirv.Event".
+static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
+                                       MachineRegisterInfo *MRI,
+                                       SPIRVGlobalRegistry &GR,
+                                       MachineInstr &I) {
+  // OpGroupWaitEvents <scope>, <num events>, <events list>: no result operand,
+  // so the events-list pointer is operand 2.
+  constexpr unsigned OpIdx = 2;
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  // Only pointer operands can be fixed up with a bitcast.
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  // Operand 2 of OpTypePointer is the pointee type; nothing to do when it is
+  // unknown or already OpTypeEvent.
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
+    return;
+  // Insert a bitcast before the instruction to keep SPIR-V code valid.
+  // The context is taken from the function directly (same LLVMContext as its
+  // module) instead of going through MachineModuleInfo.
+  LLVMContext &Context = MF->getFunction().getContext();
+  SPIRVType *NewPtrType =
+      createNewPtrType(GR, I, OpType, /*ReuseType=*/false, /*EmitIR=*/true,
+                       /*ResType=*/nullptr,
+                       TargetExtType::get(Context, "spirv.Event"));
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
// Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpGroupWaitEvents:
+ // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
+ validateGroupWaitEventsPtr(STI, MRI, GR, MI);
+ break;
case SPIRV::OpConstantI: {
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
diff --git a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
new file mode 100644
index 0000000000000..d6fb70bb59a7e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; The last operand of OpGroupWaitEvents must point to OpTypeEvent. Here it
+; points to a struct wrapping the event, so the backend must insert an
+; OpBitcast to a Generic pointer to OpTypeEvent right before the call.
+
+; CHECK: %[[#EventTy:]] = OpTypeEvent
+; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
+; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
+; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
+; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy]]
+; CHECK: OpFunction
+; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
+; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
+; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
+; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
+
+%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+
+define weak_odr dso_local spir_kernel void @foo() {
+entry:
+  %var = alloca %"class.sycl::_V1::device_event"
+  %eventptr = addrspacecast ptr %var to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %eventptr)
+  ret void
+}
+
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))
More information about the llvm-commits mailing list