[llvm] [SPIRV] Create a new OpSelect selector and fix register types. (PR #152311)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 6 11:49:15 PDT 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/152311
>From d7f21b87037f84d1ba4fba433009f1c233373e9b Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Tue, 5 Aug 2025 13:15:40 -0400
Subject: [PATCH 1/2] [SPIRV] Create a new OpSelect selector and fix register
types.
fixes #135572
There are two underlying problems. First, register types are copied
from older registers instead of being derived from the SPIR-V types.
Second, because of the way OpSelect is defined in SPIRVInstrInfo.td, we
always default to the integer variant of TernOpTyped. There appear to
be multiple matching entries in the getMatchTable, so when
executeMatchTable runs we do not get the correct OpSelect.
Correcting the TableGen definitions was not straightforward, so instead
this patch adds an emitter for Select that evaluates the register types.
This passes the existing llvm/test/CodeGen/SPIRV/instructions/select.ll
tests as well as the new float tests added in
issue-135572-emit-float-opselect.ll.
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 67 ++++++++++++++++++-
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 12 ++--
.../issue-135572-emit-float-opselect.ll | 38 +++++++++++
3 files changed, 108 insertions(+), 9 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e9f5ffa23e220..bb0b7e4b03ae1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -222,6 +222,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
bool IsSigned) const;
+ bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I, bool IsSigned) const;
bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
bool IsSigned, unsigned Opcode) const;
bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
@@ -509,7 +511,18 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
if (isTypeFoldingSupported(Def->getOpcode()) &&
Def->getOpcode() != TargetOpcode::G_CONSTANT &&
Def->getOpcode() != TargetOpcode::G_FCONSTANT) {
- bool Res = selectImpl(I, *CoverageInfo);
+ bool Res = false;
+ if (Def->getOpcode() == TargetOpcode::G_SELECT) {
+ Register SelectDstReg = Def->getOperand(0).getReg();
+ Res = selectSelect(SelectDstReg, GR.getSPIRVTypeForVReg(SelectDstReg),
+ *Def, true);
+ GR.invalidateMachineInstr(Def);
+ Def->removeFromParent();
+ MRI->replaceRegWith(DstReg, SelectDstReg);
+ GR.invalidateMachineInstr(&I);
+ I.removeFromParent();
+ } else
+ Res = selectImpl(I, *CoverageInfo);
LLVM_DEBUG({
if (!Res && Def->getOpcode() != TargetOpcode::G_CONSTANT) {
dbgs() << "Unexpected pattern in ASSIGN_TYPE.\nInstruction: ";
@@ -2566,6 +2579,53 @@ bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
bool IsSigned) const {
+ bool IsFloatTy =
+ GR.isScalarOrVectorOfType(I.getOperand(2).getReg(), SPIRV::OpTypeFloat) ||
+ GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypeFloat);
+
+ bool IsPtrTy =
+ GR.isScalarOrVectorOfType(I.getOperand(2).getReg(),
+ SPIRV::OpTypePointer) ||
+ GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypePointer);
+ bool IsVectorTy =
+ GR.getSPIRVTypeForVReg(I.getOperand(2).getReg())->getOpcode() ==
+ SPIRV::OpTypeVector ||
+ GR.getSPIRVTypeForVReg(I.getOperand(3).getReg())->getOpcode() ==
+ SPIRV::OpTypeVector;
+
+ bool IsScalarBool =
+ GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
+ unsigned Opcode;
+ if (IsVectorTy) {
+ if (IsFloatTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVFSCond : SPIRV::OpSelectVFVCond;
+ } else if (IsPtrTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVPSCond : SPIRV::OpSelectVPVCond;
+ } else {
+ Opcode = IsScalarBool ? SPIRV::OpSelectVISCond : SPIRV::OpSelectVIVCond;
+ }
+ } else {
+ if (IsFloatTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSFSCond : SPIRV::OpSelectVFVCond;
+ } else if (IsPtrTy) {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSPSCond : SPIRV::OpSelectVPVCond;
+ } else {
+ Opcode = IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+ }
+ }
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(1).getReg())
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
// To extend a bool, we need to use OpSelect between constants.
Register ZeroReg = buildZerosVal(ResType, I);
Register OneReg = buildOnesVal(IsSigned, ResType, I);
@@ -2573,6 +2633,7 @@ bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
unsigned Opcode =
IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
@@ -2597,7 +2658,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
}
SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF());
- selectSelect(SrcReg, TmpType, I, false);
+ selectSelectDefaultArgs(SrcReg, TmpType, I, false);
}
return selectOpWithSrcs(ResVReg, ResType, I, {SrcReg}, Opcode);
}
@@ -2607,7 +2668,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg,
MachineInstr &I, bool IsSigned) const {
Register SrcReg = I.getOperand(1).getReg();
if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeBool))
- return selectSelect(ResVReg, ResType, I, IsSigned);
+ return selectSelectDefaultArgs(ResVReg, ResType, I, IsSigned);
SPIRVType *SrcType = GR.getSPIRVTypeForVReg(SrcReg);
if (SrcType == ResType)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index f4b4846f70d7d..012d13b08b529 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -440,13 +440,13 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
// Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
// present after each auto-folded instruction to take a type reference from.
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
- if (auto *RC = MRI.getRegClassOrNull(Reg)) {
- MRI.setRegClass(NewReg, RC);
- } else {
- auto RegClass = GR->getRegClass(SpvType);
+ auto RegClass = GR->getRegClass(SpvType);
+ MRI.setRegClass(NewReg, RegClass);
+ MRI.setRegClass(Reg, RegClass);
+
+ if (auto *RC = MRI.getRegClassOrNull(Reg); RC != RegClass)
MRI.setRegClass(NewReg, RegClass);
- MRI.setRegClass(Reg, RegClass);
- }
+
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
// This is to make it convenient for Legalizer to get the SPIRVType
// when processing the actual MI (i.e. not pseudo one).
diff --git a/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
new file mode 100644
index 0000000000000..2502770cb18e1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
@@ -0,0 +1,38 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#func_type:]] = OpTypeFunction %[[#float_32]] %[[#float_32]] %[[#float_32]]
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec_func_type:]] = OpTypeFunction %[[#vec4_float_32]] %[[#vec4_float_32]] %[[#vec4_float_32]]
+; CHECK-DAG: %[[#vec_4_bool:]] = OpTypeVector %[[#bool]] 4
+
+define spir_func float @opselect_float_scalar_test(float %x, float %y) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#func_type]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]]
+ ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#bool]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: %[[#fselect:]] = OpSelect %[[#float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: OpReturnValue %[[#fselect]]
+ %0 = fcmp ogt float %x, %y
+ %1 = select i1 %0, float %x, float %y
+ ret float %1
+}
+
+define spir_func <4 x float> @opselect_float4_vec_test(<4 x float> %x, <4 x float> %y) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#vec_func_type]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#vec_4_bool]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+ ; CHECK: OpReturnValue %[[#fselect]]
+ %0 = fcmp ogt <4 x float> %x, %y
+ %1 = select <4 x i1> %0, <4 x float> %x, <4 x float> %y
+ ret <4 x float> %1
+}
>From e442e6c38de3d735915b3b64f6f6d4920c3b6300 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 6 Aug 2025 14:49:02 -0400
Subject: [PATCH 2/2] Address PR review feedback
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 7 ++-----
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 3 ---
2 files changed, 2 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index bb0b7e4b03ae1..149291ef8054d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -220,8 +220,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectConst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
- bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
- bool IsSigned) const;
+ bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const;
bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsSigned) const;
bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
@@ -2577,8 +2576,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
const SPIRVType *ResType,
- MachineInstr &I,
- bool IsSigned) const {
+ MachineInstr &I) const {
bool IsFloatTy =
GR.isScalarOrVectorOfType(I.getOperand(2).getReg(), SPIRV::OpTypeFloat) ||
GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypeFloat);
@@ -2633,7 +2631,6 @@ bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg,
GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
unsigned Opcode =
IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
-
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 012d13b08b529..5a95d94b9213d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -444,9 +444,6 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
MRI.setRegClass(NewReg, RegClass);
MRI.setRegClass(Reg, RegClass);
- if (auto *RC = MRI.getRegClassOrNull(Reg); RC != RegClass)
- MRI.setRegClass(NewReg, RegClass);
-
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
// This is to make it convenient for Legalizer to get the SPIRVType
// when processing the actual MI (i.e. not pseudo one).
More information about the llvm-commits
mailing list