[llvm] [NFC][AMDGPU] Move cmp+select arguments optimization to SIISelLowering. (PR #150929)

Daniil Fukalov via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 28 04:37:34 PDT 2025


https://github.com/dfukalov created https://github.com/llvm/llvm-project/pull/150929

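As requested in #148740: the fold of CMP+SELECT patterns that share a
non-inline constant relies on GCN instruction info (it had to static_cast the
subtarget to GCNSubtarget), so move it from AMDGPUTargetLowering to
SITargetLowering.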

>From 03b089b9b7e1aace9313ed128e7012bba7921fa5 Mon Sep 17 00:00:00 2001
From: Daniil Fukalov <dfukalov at gmail.com>
Date: Mon, 28 Jul 2025 13:34:06 +0200
Subject: [PATCH] [NFC][AMDGPU] Move cmp+select arguments optimization to
 SIISelLowering.

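As requested in #148740: the fold of CMP+SELECT patterns that share a
non-inline constant relies on GCN instruction info (it had to static_cast the
subtarget to GCNSubtarget), so move it from AMDGPUTargetLowering to
SITargetLowering.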
---
 llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 83 -------------------
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp     | 76 +++++++++++++++++
 llvm/lib/Target/AMDGPU/SIISelLowering.h       |  1 +
 3 files changed, 77 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index f25ce8723a2dc..61189337e5233 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4846,94 +4846,11 @@ AMDGPUTargetLowering::foldFreeOpFromSelect(TargetLowering::DAGCombinerInfo &DCI,
   return SDValue();
 }
 
-// Detect when CMP and SELECT use the same constant and fold them to avoid
-// loading the constant twice. Specifically handles patterns like:
-// %cmp = icmp eq i32 %val, 4242
-// %sel = select i1 %cmp, i32 4242, i32 %other
-// It can be optimized to reuse %val instead of 4242 in select.
-static SDValue
-foldCmpSelectWithSharedConstant(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
-                                const AMDGPUSubtarget *ST) {
-  SDValue Cond = N->getOperand(0);
-  SDValue TrueVal = N->getOperand(1);
-  SDValue FalseVal = N->getOperand(2);
-
-  // Check if condition is a comparison.
-  if (Cond.getOpcode() != ISD::SETCC)
-    return SDValue();
-
-  SDValue LHS = Cond.getOperand(0);
-  SDValue RHS = Cond.getOperand(1);
-  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
-
-  bool isFloatingPoint = LHS.getValueType().isFloatingPoint();
-  bool isInteger = LHS.getValueType().isInteger();
-
-  // Handle simple floating-point and integer types only.
-  if (!isFloatingPoint && !isInteger)
-    return SDValue();
-
-  bool isEquality = CC == (isFloatingPoint ? ISD::SETOEQ : ISD::SETEQ);
-  bool isNonEquality = CC == (isFloatingPoint ? ISD::SETONE : ISD::SETNE);
-  if (!isEquality && !isNonEquality)
-    return SDValue();
-
-  SDValue ArgVal, ConstVal;
-  if ((isFloatingPoint && isa<ConstantFPSDNode>(RHS)) ||
-      (isInteger && isa<ConstantSDNode>(RHS))) {
-    ConstVal = RHS;
-    ArgVal = LHS;
-  } else if ((isFloatingPoint && isa<ConstantFPSDNode>(LHS)) ||
-             (isInteger && isa<ConstantSDNode>(LHS))) {
-    ConstVal = LHS;
-    ArgVal = RHS;
-  } else {
-    return SDValue();
-  }
-
-  // Check if constant should not be optimized - early return if not.
-  if (isFloatingPoint) {
-    const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
-    const GCNSubtarget *GCNST = static_cast<const GCNSubtarget *>(ST);
-
-    // Only optimize normal floating-point values (finite, non-zero, and
-    // non-subnormal as per IEEE 754), skip optimization for inlinable
-    // floating-point constants.
-    if (!Val.isNormal() || GCNST->getInstrInfo()->isInlineConstant(Val))
-      return SDValue();
-  } else {
-    int64_t IntVal = cast<ConstantSDNode>(ConstVal)->getSExtValue();
-
-    // Skip optimization for inlinable integer immediates.
-    // Inlinable immediates include: -16 to 64 (inclusive).
-    if (IntVal >= -16 && IntVal <= 64)
-      return SDValue();
-  }
-
-  // For equality and non-equality comparisons, patterns:
-  // select (setcc x, const), const, y -> select (setcc x, const), x, y
-  // select (setccinv x, const), y, const -> select (setccinv x, const), y, x
-  if (!(isEquality && TrueVal == ConstVal) &&
-      !(isNonEquality && FalseVal == ConstVal))
-    return SDValue();
-
-  SDValue SelectLHS = (isEquality && TrueVal == ConstVal) ? ArgVal : TrueVal;
-  SDValue SelectRHS =
-      (isNonEquality && FalseVal == ConstVal) ? ArgVal : FalseVal;
-  return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
-                         SelectLHS, SelectRHS);
-}
-
 SDValue AMDGPUTargetLowering::performSelectCombine(SDNode *N,
                                                    DAGCombinerInfo &DCI) const {
   if (SDValue Folded = foldFreeOpFromSelect(DCI, SDValue(N, 0)))
     return Folded;
 
-  // Try to fold CMP + SELECT patterns with shared constants (both FP and
-  // integer).
-  if (SDValue Folded = foldCmpSelectWithSharedConstant(N, DCI, Subtarget))
-    return Folded;
-
   SDValue Cond = N->getOperand(0);
   if (Cond.getOpcode() != ISD::SETCC)
     return SDValue();
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 8d51ec6dc7f31..0fac43fa78e82 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -15896,6 +15896,78 @@ SDValue SITargetLowering::performClampCombine(SDNode *N,
   return SDValue(CSrc, 0);
 }
 
+// Fold select (setcc x, C), C, y -> select (setcc x, C), x, y, and the
+// inverted form select (setccinv x, C), y, C -> select (...), y, x, so that
+// a constant which is not an inline immediate is not materialized twice:
+//   %cmp = icmp eq i32 %val, 4242
+//   %sel = select i1 %cmp, i32 4242, i32 %other
+// can reuse %val instead of 4242 in the select. Returns the new SELECT, or
+// an empty SDValue when the pattern does not apply.
+SDValue SITargetLowering::performSelectCombine(SDNode *N,
+                                               DAGCombinerInfo &DCI) const {
+  SDValue Cond = N->getOperand(0);
+  SDValue TrueVal = N->getOperand(1);
+  SDValue FalseVal = N->getOperand(2);
+
+  // Check if condition is a comparison.
+  if (Cond.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  SDValue LHS = Cond.getOperand(0);
+  SDValue RHS = Cond.getOperand(1);
+  ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+
+  EVT OpVT = LHS.getValueType();
+  bool IsFP = OpVT.isFloatingPoint();
+  bool IsInt = OpVT.isInteger();
+
+  // Handle simple floating-point and integer types only. Types wider than
+  // 64 bits (i128, f128, ...) have no inline immediates, and the queries
+  // below (e.g. getSExtValue()) assert on constants that exceed 64 bits.
+  if ((!IsFP && !IsInt) || OpVT.getScalarSizeInBits() > 64)
+    return SDValue();
+
+  bool IsEquality = CC == (IsFP ? ISD::SETOEQ : ISD::SETEQ);
+  bool IsNonEquality = CC == (IsFP ? ISD::SETONE : ISD::SETNE);
+  if (!IsEquality && !IsNonEquality)
+    return SDValue();
+
+  // Find the constant operand of the compare; the other operand is the value
+  // that will replace the constant in the select.
+  auto IsConst = [IsFP](SDValue V) {
+    return IsFP ? isa<ConstantFPSDNode>(V) : isa<ConstantSDNode>(V);
+  };
+  bool RHSIsConst = IsConst(RHS);
+  if (!RHSIsConst && !IsConst(LHS))
+    return SDValue();
+  SDValue ConstVal = RHSIsConst ? RHS : LHS;
+  SDValue ArgVal = RHSIsConst ? LHS : RHS;
+
+  // Skip inline immediates: they are free to encode, so there is nothing to
+  // save. Non-normal FP values (zero, subnormal, inf, NaN) are skipped too;
+  // e.g. -0.0 compares equal to +0.0, so x may differ from the constant.
+  if (IsFP) {
+    const APFloat &Val = cast<ConstantFPSDNode>(ConstVal)->getValueAPF();
+    if (!Val.isNormal() || Subtarget->getInstrInfo()->isInlineConstant(Val))
+      return SDValue();
+  } else if (AMDGPU::isInlinableIntLiteral(
+                 cast<ConstantSDNode>(ConstVal)->getSExtValue())) {
+    return SDValue();
+  }
+
+  // For equality and non-equality comparisons, patterns:
+  // select (setcc x, const), const, y -> select (setcc x, const), x, y
+  // select (setccinv x, const), y, const -> select (setccinv x, const), y, x
+  bool FoldTrue = IsEquality && TrueVal == ConstVal;
+  bool FoldFalse = IsNonEquality && FalseVal == ConstVal;
+  if (!FoldTrue && !FoldFalse)
+    return SDValue();
+
+  return DCI.DAG.getNode(ISD::SELECT, SDLoc(N), N->getValueType(0), Cond,
+                         FoldTrue ? ArgVal : TrueVal,
+                         FoldFalse ? ArgVal : FalseVal);
+}
+
 SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
   switch (N->getOpcode()) {
@@ -15944,6 +16016,10 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
     return performFMulCombine(N, DCI);
   case ISD::SETCC:
     return performSetCCCombine(N, DCI);
+  case ISD::SELECT:
+    if (auto Res = performSelectCombine(N, DCI))
+      return Res;
+    break;
   case ISD::FMAXNUM:
   case ISD::FMINNUM:
   case ISD::FMAXNUM_IEEE:
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index acf6158572a4d..dedd9ae170774 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -211,6 +211,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
   SDValue performExtractVectorEltCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performInsertVectorEltCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performFPRoundCombine(SDNode *N, DAGCombinerInfo &DCI) const;
+  SDValue performSelectCombine(SDNode *N, DAGCombinerInfo &DCI) const;
 
   SDValue reassociateScalarOps(SDNode *N, SelectionDAG &DAG) const;
   unsigned getFusedOpcode(const SelectionDAG &DAG,



More information about the llvm-commits mailing list