[llvm] 1ca8ad2 - [SPIRV] Create a new OpSelect selector and fix register types. (#152311)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 12 14:43:33 PDT 2025
Author: Farzon Lotfi
Date: 2025-08-12T17:43:30-04:00
New Revision: 1ca8ad29dbbe4255cb19fb1193a88040dda515a9
URL: https://github.com/llvm/llvm-project/commit/1ca8ad29dbbe4255cb19fb1193a88040dda515a9
DIFF: https://github.com/llvm/llvm-project/commit/1ca8ad29dbbe4255cb19fb1193a88040dda515a9.diff
LOG: [SPIRV] Create a new OpSelect selector and fix register types. (#152311)
fixes #135572
There are two underlying problems. First, register types are copied
from older registers instead of being derived from the SPIR-V types.
Second, because of the way OpSelect is defined in SPIRVInstrInfo.td, we
always default to the integer variant for TernOpTyped: getMatchTable
appears to contain multiple matching entries, so when executeMatchTable
runs it does not pick the correct OpSelect.
Correcting the TableGen definitions was not straightforward, so instead
I created an emitter for Select that inspects the register types. This
passes the original llvm/test/CodeGen/SPIRV/instructions/select.ll tests
and the new float tests added in issue-135572-emit-float-opselect.ll.
Added:
llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5259db1ff2dd7..98c7709acf938 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -220,8 +220,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectConst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
- bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
- bool IsSigned) const;
+ bool selectSelect(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+ bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, bool IsSigned) const;
bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
bool IsSigned, unsigned Opcode) const;
bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
@@ -510,7 +512,18 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
if (isTypeFoldingSupported(Def->getOpcode()) &&
Def->getOpcode() != TargetOpcode::G_CONSTANT &&
Def->getOpcode() != TargetOpcode::G_FCONSTANT) {
- bool Res = selectImpl(I, *CoverageInfo);
+ bool Res = false;
+ if (Def->getOpcode() == TargetOpcode::G_SELECT) {
+ Register SelectDstReg = Def->getOperand(0).getReg();
+ Res = selectSelect(SelectDstReg, GR.getSPIRVTypeForVReg(SelectDstReg),
+ *Def);
+ GR.invalidateMachineInstr(Def);
+ Def->removeFromParent();
+ MRI->replaceRegWith(DstReg, SelectDstReg);
+ GR.invalidateMachineInstr(&I);
+ I.removeFromParent();
+ } else
+ Res = selectImpl(I, *CoverageInfo);
LLVM_DEBUG({
if (!Res && Def->getOpcode() != TargetOpcode::G_CONSTANT) {
dbgs() << "Unexpected pattern in ASSIGN_TYPE.\nInstruction: ";
@@ -2565,8 +2578,52 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
const SPIRVType *ResType,
- MachineInstr &I,
- bool IsSigned) const {
+ MachineInstr &I) const {
+ Register SelectFirstArg = I.getOperand(2).getReg();
+ Register SelectSecondArg = I.getOperand(3).getReg();
+ assert(ResType == GR.getSPIRVTypeForVReg(SelectFirstArg) &&
+ ResType == GR.getSPIRVTypeForVReg(SelectSecondArg));
+
+ bool IsFloatTy =
+ GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypeFloat);
+ bool IsPtrTy =
+ GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypePointer);
+ bool IsVectorTy = GR.getSPIRVTypeForVReg(SelectFirstArg)->getOpcode() ==
+ SPIRV::OpTypeVector;
+
+ bool IsScalarBool =
+ GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
+ unsigned Opcode;
+ if (IsVectorTy) {
+ if (IsFloatTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVFSCond : SPIRV::OpSelectVFVCond;
+ } else if (IsPtrTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVPSCond : SPIRV::OpSelectVPVCond;
+ } else {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVISCond : SPIRV::OpSelectVIVCond;
+ }
+ } else {
+ if (IsFloatTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSFSCond : SPIRV::OpSelectVFVCond;
+ } else if (IsPtrTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSPSCond : SPIRV::OpSelectVPVCond;
+ } else {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+ }
+ }
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(1).getReg())
+ .addUse(SelectFirstArg)
+ .addUse(SelectSecondArg)
+ .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
// To extend a bool, we need to use OpSelect between constants.
Register ZeroReg = buildZerosVal(ResType, I);
Register OneReg = buildOnesVal(IsSigned, ResType, I);
@@ -2598,7 +2655,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
}
SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF());
- selectSelect(SrcReg, TmpType, I, false);
+ selectSelectDefaultArgs(SrcReg, TmpType, I, false);
}
return selectOpWithSrcs(ResVReg, ResType, I, {SrcReg}, Opcode);
}
@@ -2608,7 +2665,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg,
MachineInstr &I, bool IsSigned) const {
Register SrcReg = I.getOperand(1).getReg();
if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeBool))
- return selectSelect(ResVReg, ResType, I, IsSigned);
+ return selectSelectDefaultArgs(ResVReg, ResType, I, IsSigned);
SPIRVType *SrcType = GR.getSPIRVTypeForVReg(SrcReg);
if (SrcType == ResType)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b62db7fd62b2e..1a08c6ac0dcaf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -441,13 +441,10 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
// Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
// present after each auto-folded instruction to take a type reference from.
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
- if (auto *RC = MRI.getRegClassOrNull(Reg)) {
- MRI.setRegClass(NewReg, RC);
- } else {
- auto RegClass = GR->getRegClass(SpvType);
- MRI.setRegClass(NewReg, RegClass);
- MRI.setRegClass(Reg, RegClass);
- }
+ const auto *RegClass = GR->getRegClass(SpvType);
+ MRI.setRegClass(NewReg, RegClass);
+ MRI.setRegClass(Reg, RegClass);
+
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
// This is to make it convenient for Legalizer to get the SPIRVType
// when processing the actual MI (i.e. not pseudo one).
diff --git a/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
new file mode 100644
index 0000000000000..69ea054de1e4f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
@@ -0,0 +1,53 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#func_type:]] = OpTypeFunction %[[#float_32]] %[[#float_32]] %[[#float_32]]
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec_func_type:]] = OpTypeFunction %[[#vec4_float_32]] %[[#vec4_float_32]] %[[#vec4_float_32]]
+; CHECK-DAG: %[[#vec_4_bool:]] = OpTypeVector %[[#bool]] 4
+
+define spir_func float @opselect_float_scalar_test(float %x, float %y) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#func_type]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#bool]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: %[[#fselect:]] = OpSelect %[[#float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: OpReturnValue %[[#fselect]]
+ %0 = fcmp ogt float %x, %y
+ %1 = select i1 %0, float %x, float %y
+ ret float %1
+}
+
+define spir_func <4 x float> @opselect_float4_vec_test(<4 x float> %x, <4 x float> %y) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#vec_func_type]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#vec_4_bool]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: OpReturnValue %[[#fselect]]
+ %0 = fcmp ogt <4 x float> %x, %y
+ %1 = select <4 x i1> %0, <4 x float> %x, <4 x float> %y
+ ret <4 x float> %1
+}
+
+define spir_func <4 x float> @opselect_scalar_bool_float4_vec_test(float %a, float %b, <4 x float> %x, <4 x float> %y) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg3:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#bool]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg2]] %[[#arg3]]
+ ; CHECK: OpReturnValue %[[#fselect]]
+ %0 = fcmp ogt float %a, %b
+ %1 = select i1 %0, <4 x float> %x, <4 x float> %y
+ ret <4 x float> %1
+}
\ No newline at end of file
More information about the llvm-commits
mailing list