[llvm] [SDAG] Make Select-with-Identity-Fold More Flexible; NFC (PR #136554)

Marius Kamp via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 21 21:02:06 PDT 2025


https://github.com/mskamp updated https://github.com/llvm/llvm-project/pull/136554

>From 1c2be64493798a299c534f4ffb6fd2576e048676 Mon Sep 17 00:00:00 2001
From: Marius Kamp <msk at posteo.org>
Date: Fri, 18 Apr 2025 11:04:22 +0200
Subject: [PATCH] [SDAG] Make Select-with-Identity-Fold More Flexible; NFC

This change adds new parameters to the method
`shouldFoldSelectWithIdentityConstant()`. In addition to the binop opcode
and the value type, the method now takes the opcode of the select node,
the operand of the binop that is not the select (`X`), and the
non-identity operand of the select node. To gain access to these
arguments, the call to `shouldFoldSelectWithIdentityConstant()` is moved
after all other checks have been performed. Moreover, this change relaxes
the precondition of the fold so that it can work for `SELECT` nodes in
addition to `VSELECT` nodes.

No functional change is intended because all implementations of
`shouldFoldSelectWithIdentityConstant()` are adjusted so that they
restrict the fold to `VSELECT` nodes, which is the same restriction as
before.

The rationale for this change is to allow the backends to make
finer-grained decisions about when to revert the InstCombine
canonicalization of `(select c (binop x y) y)` to
`(binop (select c x idc) y)`.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  6 ++--
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 32 +++++++++++--------
 .../Target/AArch64/AArch64ISelLowering.cpp    |  6 ++--
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  6 ++--
 llvm/lib/Target/ARM/ARMISelLowering.cpp       |  8 +++--
 llvm/lib/Target/ARM/ARMISelLowering.h         |  6 ++--
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  8 +++--
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  6 ++--
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  7 ++--
 llvm/lib/Target/X86/X86ISelLowering.h         |  6 ++--
 10 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 00c36266a069f..f71fac0df137b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3353,8 +3353,10 @@ class TargetLoweringBase {
   /// Return true if pulling a binary operation into a select with an identity
   /// constant is profitable. This is the inverse of an IR transform.
   /// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X
-  virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                                    EVT VT) const {
+  virtual bool
+  shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+                                       unsigned SelectOpcode, SDValue X,
+                                       SDValue NonIdConstNode) const {
     return false;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b175e35385ec6..7c8619aa29346 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2425,8 +2425,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
   if (ShouldCommuteOperands)
     std::swap(N0, N1);
 
-  // TODO: Should this apply to scalar select too?
-  if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
+  unsigned SelOpcode = N1.getOpcode();
+  if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
+      !N1.hasOneUse())
     return SDValue();
 
   // We can't hoist all instructions because of immediate UB (not speculatable).
@@ -2439,17 +2440,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
   SDValue Cond = N1.getOperand(0);
   SDValue TVal = N1.getOperand(1);
   SDValue FVal = N1.getOperand(2);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   // This transform increases uses of N0, so freeze it to be safe.
   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
   unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
-  if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
+  if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
+      TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
+                                               FVal)) {
     SDValue F0 = DAG.getFreeze(N0);
     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
   }
   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
-  if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
+  if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
+      TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
+                                               TVal)) {
     SDValue F0 = DAG.getFreeze(N0);
     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
@@ -2459,26 +2465,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
 }
 
 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
          "Unexpected binary operator");
 
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  auto BinOpcode = BO->getOpcode();
-  EVT VT = BO->getValueType(0);
-  if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
-    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
-      return Sel;
+  if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
+    return Sel;
 
-    if (TLI.isCommutativeBinOp(BO->getOpcode()))
-      if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
-        return Sel;
-  }
+  if (TLI.isCommutativeBinOp(BO->getOpcode()))
+    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
+      return Sel;
 
   // Don't do this unless the old select is going away. We want to eliminate the
   // binary operator, not replace a binop with a select.
   // TODO: Handle ISD::SELECT_CC.
   unsigned SelOpNo = 0;
   SDValue Sel = BO->getOperand(0);
+  auto BinOpcode = BO->getOpcode();
   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
     SelOpNo = 1;
     Sel = BO->getOperand(1);
@@ -2526,6 +2529,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
 
   SDLoc DL(Sel);
   SDValue NewCT, NewCF;
+  EVT VT = BO->getValueType(0);
 
   if (CanFoldNonConst) {
     // If CBO is an opaque constant, we can't rely on getNode to constant fold.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 771eee1b3fecf..9d254afa524df 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18040,8 +18040,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
 }
 
 bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
-    unsigned BinOpcode, EVT VT) const {
-  return VT.isScalableVector() && isTypeLegal(VT);
+    unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+    SDValue NonIdConstNode) const {
+  return VT.isScalableVector() && isTypeLegal(VT) &&
+         SelectOpcode == ISD::VSELECT;
 }
 
 bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 0d51ef2be8631..ff3f7fcd2d7c1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -786,8 +786,10 @@ class AArch64TargetLowering : public TargetLowering {
   bool shouldFoldConstantShiftPairToMask(const SDNode *N,
                                          CombineLevel Level) const override;
 
-  bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                            EVT VT) const override;
+  bool
+  shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+                                       unsigned SelectOpcode, SDValue X,
+                                       SDValue NonIdConstNode) const override;
 
   /// Returns true if it is beneficial to convert a load of a constant
   /// to just the constant itself.
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 2290ac2728c6d..9900f6e958498 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -13960,9 +13960,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
   return false;
 }
 
-bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                                             EVT VT) const {
-  return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT);
+bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(
+    unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+    SDValue NonIdConstNode) const {
+  return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) &&
+         SelectOpcode == ISD::VSELECT;
 }
 
 bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index 9fad056edd3f1..dd366eca80461 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -758,8 +758,10 @@ class VectorType;
     bool shouldFoldConstantShiftPairToMask(const SDNode *N,
                                            CombineLevel Level) const override;
 
-    bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                              EVT VT) const override;
+    bool
+    shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+                                         unsigned SelectOpcode, SDValue X,
+                                         SDValue NonIdConstNode) const override;
 
     bool preferIncOfAddToSubOfNot(EVT VT) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 98fba9e86e88a..4d8c8a578f30e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2090,8 +2090,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
   return C && C->getAPIntValue().ule(10);
 }
 
-bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
-                                                               EVT VT) const {
+bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(
+    unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+    SDValue NonIdConstNode) const {
+  if (SelectOpcode != ISD::VSELECT)
+    return false;
+
   // Only enable for rvv.
   if (!VT.isVector() || !Subtarget.hasVInstructions())
     return false;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index baf1b2e4d8e6e..ed6e3dbbee797 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -585,8 +585,10 @@ class RISCVTargetLowering : public TargetLowering {
                                                 unsigned &NumIntermediates,
                                                 MVT &RegisterVT) const override;
 
-  bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                            EVT VT) const override;
+  bool
+  shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+                                       unsigned SelectOpcode, SDValue X,
+                                       SDValue NonIdConstNode) const override;
 
   /// Return true if the given shuffle mask can be codegen'd directly, or if it
   /// should be stack expanded.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 993118c52564e..4f35cef3acf0b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -35383,8 +35383,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
   return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
 }
 
-bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
-                                                             EVT VT) const {
+bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(
+    unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+    SDValue NonIdConstNode) const {
+  if (SelectOpcode != ISD::VSELECT)
+    return false;
   // TODO: This is too general. There are cases where pre-AVX512 codegen would
   //       benefit. The transform may also be profitable for scalar code.
   if (!Subtarget.hasAVX512())
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 4a2b35e9efe7c..89b817e02ade7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1460,8 +1460,10 @@ namespace llvm {
     /// from i32 to i16.
     bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
 
-    bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
-                                              EVT VT) const override;
+    bool
+    shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+                                         unsigned SelectOpcode, SDValue X,
+                                         SDValue NonIdConstNode) const override;
 
     /// Given an intrinsic, checks if on the target the intrinsic will need to map
     /// to a MemIntrinsicNode (touches memory). If this is the case, it returns



More information about the llvm-commits mailing list