[llvm] [SPIRV] Create a new OpSelect selector and fix register types. (PR #152311)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 8 08:47:42 PDT 2025


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/152311

>From aaadb2eecbb94791b5a81ecd9265969d64160e05 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Tue, 5 Aug 2025 13:15:40 -0400
Subject: [PATCH 1/3] [SPIRV] Create a new OpSelect selector and fix register
 types.

fixes #135572

There are two problems. First, register types are copied from older
registers instead of being derived from the SPIR-V types.

Second, because of the way OpSelect is defined in SPIRVInstrInfo.td, we
always default to the integer variant for TernOpTyped. There appear to be
multiple matching entries in the getMatchTable, so when executeMatchTable
runs we don't get the right OpSelect variant.

Correcting the TableGen definitions wasn't straightforward, so instead I
created an emitter for Select that evaluates the register types. This
passes the original llvm/test/CodeGen/SPIRV/instructions/select.ll tests
and the new float tests I'm adding in issue-135572-emit-float-opselect.ll.
---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 67 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 12 ++--
 .../issue-135572-emit-float-opselect.ll       | 38 +++++++++++
 3 files changed, 108 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5259db1ff2dd7..e3d423be85991 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -222,6 +222,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                     bool IsSigned) const;
+  bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType,
+                               MachineInstr &I, bool IsSigned) const;
   bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                   bool IsSigned, unsigned Opcode) const;
   bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
@@ -510,7 +512,18 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
       if (isTypeFoldingSupported(Def->getOpcode()) &&
           Def->getOpcode() != TargetOpcode::G_CONSTANT &&
           Def->getOpcode() != TargetOpcode::G_FCONSTANT) {
-        bool Res = selectImpl(I, *CoverageInfo);
+        bool Res = false;
+        if (Def->getOpcode() == TargetOpcode::G_SELECT) {
+          Register SelectDstReg = Def->getOperand(0).getReg();
+          Res = selectSelect(SelectDstReg, GR.getSPIRVTypeForVReg(SelectDstReg),
+                             *Def, true);
+          GR.invalidateMachineInstr(Def);
+          Def->removeFromParent();
+          MRI->replaceRegWith(DstReg, SelectDstReg);
+          GR.invalidateMachineInstr(&I);
+          I.removeFromParent();
+        } else
+          Res = selectImpl(I, *CoverageInfo);
         LLVM_DEBUG({
           if (!Res && Def->getOpcode() != TargetOpcode::G_CONSTANT) {
             dbgs() << "Unexpected pattern in ASSIGN_TYPE.\nInstruction: ";
@@ -2567,6 +2580,53 @@ bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
                                             const SPIRVType *ResType,
                                             MachineInstr &I,
                                             bool IsSigned) const {
+  bool IsFloatTy =
+      GR.isScalarOrVectorOfType(I.getOperand(2).getReg(), SPIRV::OpTypeFloat) ||
+      GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypeFloat);
+
+  bool IsPtrTy =
+      GR.isScalarOrVectorOfType(I.getOperand(2).getReg(),
+                                SPIRV::OpTypePointer) ||
+      GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypePointer);
+  bool IsVectorTy =
+      GR.getSPIRVTypeForVReg(I.getOperand(2).getReg())->getOpcode() ==
+          SPIRV::OpTypeVector ||
+      GR.getSPIRVTypeForVReg(I.getOperand(3).getReg())->getOpcode() ==
+          SPIRV::OpTypeVector;
+
+  bool IsScalarBool =
+      GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
+  unsigned Opcode;
+  if (IsVectorTy) {
+    if (IsFloatTy) {
+      Opcode = IsScalarBool ? SPIRV::OpSelectVFSCond : SPIRV::OpSelectVFVCond;
+    } else if (IsPtrTy) {
+      Opcode = IsScalarBool ? SPIRV::OpSelectVPSCond : SPIRV::OpSelectVPVCond;
+    } else {
+      Opcode = IsScalarBool ? SPIRV::OpSelectVISCond : SPIRV::OpSelectVIVCond;
+    }
+  } else {
+    if (IsFloatTy) {
+      Opcode = IsScalarBool ? SPIRV::OpSelectSFSCond : SPIRV::OpSelectVFVCond;
+    } else if (IsPtrTy) {
+      Opcode = IsScalarBool ? SPIRV::OpSelectSPSCond : SPIRV::OpSelectVPVCond;
+    } else {
+      Opcode = IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+    }
+  }
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(1).getReg())
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg,
+                                                       const SPIRVType *ResType,
+                                                       MachineInstr &I,
+                                                       bool IsSigned) const {
   // To extend a bool, we need to use OpSelect between constants.
   Register ZeroReg = buildZerosVal(ResType, I);
   Register OneReg = buildOnesVal(IsSigned, ResType, I);
@@ -2574,6 +2634,7 @@ bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
       GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
   unsigned Opcode =
       IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
+
   return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
@@ -2598,7 +2659,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
       TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
     }
     SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF());
-    selectSelect(SrcReg, TmpType, I, false);
+    selectSelectDefaultArgs(SrcReg, TmpType, I, false);
   }
   return selectOpWithSrcs(ResVReg, ResType, I, {SrcReg}, Opcode);
 }
@@ -2608,7 +2669,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg,
                                          MachineInstr &I, bool IsSigned) const {
   Register SrcReg = I.getOperand(1).getReg();
   if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeBool))
-    return selectSelect(ResVReg, ResType, I, IsSigned);
+    return selectSelectDefaultArgs(ResVReg, ResType, I, IsSigned);
 
   SPIRVType *SrcType = GR.getSPIRVTypeForVReg(SrcReg);
   if (SrcType == ResType)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b62db7fd62b2e..569caa96adbdf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -441,13 +441,13 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
   // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
   // present after each auto-folded instruction to take a type reference from.
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
-    MRI.setRegClass(NewReg, RC);
-  } else {
-    auto RegClass = GR->getRegClass(SpvType);
+  auto RegClass = GR->getRegClass(SpvType);
+  MRI.setRegClass(NewReg, RegClass);
+  MRI.setRegClass(Reg, RegClass);
+
+  if (auto *RC = MRI.getRegClassOrNull(Reg); RC != RegClass)
     MRI.setRegClass(NewReg, RegClass);
-    MRI.setRegClass(Reg, RegClass);
-  }
+
   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
   // when processing the actual MI (i.e. not pseudo one).
diff --git a/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
new file mode 100644
index 0000000000000..2502770cb18e1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll
@@ -0,0 +1,38 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %}
+
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#func_type:]] = OpTypeFunction %[[#float_32]] %[[#float_32]] %[[#float_32]]
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec_func_type:]] = OpTypeFunction %[[#vec4_float_32]] %[[#vec4_float_32]] %[[#vec4_float_32]]
+; CHECK-DAG:  %[[#vec_4_bool:]] = OpTypeVector %[[#bool]] 4
+
+define spir_func float @opselect_float_scalar_test(float %x, float %y) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#func_type]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]]
+  ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan  %[[#bool]]  %[[#arg0]] %[[#arg1]]
+  ; CHECK: %[[#fselect:]] = OpSelect %[[#float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+  ; CHECK: OpReturnValue %[[#fselect]]
+  %0 = fcmp ogt float %x, %y
+  %1 = select i1 %0, float %x, float %y
+  ret float %1
+}
+
+define spir_func <4 x float> @opselect_float4_vec_test(<4 x float>  %x, <4 x float>  %y) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#vec_func_type]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan  %[[#vec_4_bool]]  %[[#arg0]] %[[#arg1]]
+  ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]]
+  ; CHECK: OpReturnValue %[[#fselect]]
+  %0 = fcmp ogt <4 x float> %x, %y
+  %1 = select <4 x i1> %0, <4 x float> %x, <4 x float> %y
+  ret <4 x float>  %1
+}

>From 9db9a2b4a0685fb5ba0af07479cc7be8663d48ce Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 6 Aug 2025 14:49:02 -0400
Subject: [PATCH 2/3] fix pr issues

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 10 ++++------
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp        |  5 +----
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e3d423be85991..4a75d7beb06c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -220,8 +220,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectConst(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
-  bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
-                    bool IsSigned) const;
+  bool selectSelect(Register ResVReg, const SPIRVType *ResType,
+                    MachineInstr &I) const;
   bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType,
                                MachineInstr &I, bool IsSigned) const;
   bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
@@ -516,7 +516,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
         if (Def->getOpcode() == TargetOpcode::G_SELECT) {
           Register SelectDstReg = Def->getOperand(0).getReg();
           Res = selectSelect(SelectDstReg, GR.getSPIRVTypeForVReg(SelectDstReg),
-                             *Def, true);
+                             *Def);
           GR.invalidateMachineInstr(Def);
           Def->removeFromParent();
           MRI->replaceRegWith(DstReg, SelectDstReg);
@@ -2578,8 +2578,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
                                             const SPIRVType *ResType,
-                                            MachineInstr &I,
-                                            bool IsSigned) const {
+                                            MachineInstr &I) const {
   bool IsFloatTy =
       GR.isScalarOrVectorOfType(I.getOperand(2).getReg(), SPIRV::OpTypeFloat) ||
       GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypeFloat);
@@ -2634,7 +2633,6 @@ bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg,
       GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
   unsigned Opcode =
       IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond;
-
   return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 569caa96adbdf..1a08c6ac0dcaf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -441,13 +441,10 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
   // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
   // present after each auto-folded instruction to take a type reference from.
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  auto RegClass = GR->getRegClass(SpvType);
+  const auto *RegClass = GR->getRegClass(SpvType);
   MRI.setRegClass(NewReg, RegClass);
   MRI.setRegClass(Reg, RegClass);
 
-  if (auto *RC = MRI.getRegClassOrNull(Reg); RC != RegClass)
-    MRI.setRegClass(NewReg, RegClass);
-
   GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
   // when processing the actual MI (i.e. not pseudo one).

>From 9e577ab9e538a245a044ca9348fbe16e82628fbe Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 8 Aug 2025 11:47:27 -0400
Subject: [PATCH 3/3] address pr comments

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 24 +++++++++----------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 4a75d7beb06c9..98c7709acf938 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2579,19 +2579,17 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
                                             const SPIRVType *ResType,
                                             MachineInstr &I) const {
-  bool IsFloatTy =
-      GR.isScalarOrVectorOfType(I.getOperand(2).getReg(), SPIRV::OpTypeFloat) ||
-      GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypeFloat);
+  Register SelectFirstArg = I.getOperand(2).getReg();
+  Register SelectSecondArg = I.getOperand(3).getReg();
+  assert(ResType == GR.getSPIRVTypeForVReg(SelectFirstArg) &&
+         ResType == GR.getSPIRVTypeForVReg(SelectSecondArg));
 
+  bool IsFloatTy =
+      GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypeFloat);
   bool IsPtrTy =
-      GR.isScalarOrVectorOfType(I.getOperand(2).getReg(),
-                                SPIRV::OpTypePointer) ||
-      GR.isScalarOrVectorOfType(I.getOperand(3).getReg(), SPIRV::OpTypePointer);
-  bool IsVectorTy =
-      GR.getSPIRVTypeForVReg(I.getOperand(2).getReg())->getOpcode() ==
-          SPIRV::OpTypeVector ||
-      GR.getSPIRVTypeForVReg(I.getOperand(3).getReg())->getOpcode() ==
-          SPIRV::OpTypeVector;
+      GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypePointer);
+  bool IsVectorTy = GR.getSPIRVTypeForVReg(SelectFirstArg)->getOpcode() ==
+                    SPIRV::OpTypeVector;
 
   bool IsScalarBool =
       GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
@@ -2617,8 +2615,8 @@ bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addUse(I.getOperand(1).getReg())
-      .addUse(I.getOperand(2).getReg())
-      .addUse(I.getOperand(3).getReg())
+      .addUse(SelectFirstArg)
+      .addUse(SelectSecondArg)
       .constrainAllUses(TII, TRI, RBI);
 }
 



More information about the llvm-commits mailing list