[llvm] 1064768 - [SDAG] Make Select-with-Identity-Fold More Flexible; NFC (#136554)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 00:46:43 PDT 2025
Author: Marius Kamp
Date: 2025-05-29T09:46:39+02:00
New Revision: 10647685ca3cad0107a2f754b21a078405d30359
URL: https://github.com/llvm/llvm-project/commit/10647685ca3cad0107a2f754b21a078405d30359
DIFF: https://github.com/llvm/llvm-project/commit/10647685ca3cad0107a2f754b21a078405d30359.diff
LOG: [SDAG] Make Select-with-Identity-Fold More Flexible; NFC (#136554)
This change adds new parameters to the method
`shouldFoldSelectWithIdentityConstant()`. In addition to the binary
opcode and the value type, the method now takes the opcode of the select
node, the other operand of the binary operation, and the non-identity
operand of the select node. To give the method access to these
arguments, the call of `shouldFoldSelectWithIdentityConstant()` is moved
after all other checks of the fold have been performed. Moreover, this
change relaxes the precondition of the fold so that it accepts `SELECT`
nodes in addition to `VSELECT` nodes.
No functional change is intended, because every implementation of
`shouldFoldSelectWithIdentityConstant()` is adjusted to restrict the
fold to `VSELECT` nodes, which is the same restriction as before.
The rationale for this change is to let backends make finer-grained
decisions about when to revert the InstCombine canonicalization of
`(select c (binop x y) y)` to `(binop (select c x idc) y)`.
Added:
Modified:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/lib/Target/ARM/ARMISelLowering.h
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/lib/Target/X86/X86ISelLowering.h
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2f189f27e6daa..b818f4768c2c3 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3390,8 +3390,10 @@ class LLVM_ABI TargetLoweringBase {
/// Return true if pulling a binary operation into a select with an identity
/// constant is profitable. This is the inverse of an IR transform.
/// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X
- virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const {
+ virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+ unsigned SelectOpcode,
+ SDValue X,
+ SDValue Y) const {
return false;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 9e418329d15be..12f5c3ff4eaad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2433,8 +2433,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
if (ShouldCommuteOperands)
std::swap(N0, N1);
- // TODO: Should this apply to scalar select too?
- if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
+ unsigned SelOpcode = N1.getOpcode();
+ if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
+ !N1.hasOneUse())
return SDValue();
// We can't hoist all instructions because of immediate UB (not speculatable).
@@ -2447,17 +2448,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N1.getOperand(0);
SDValue TVal = N1.getOperand(1);
SDValue FVal = N1.getOperand(2);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
// This transform increases uses of N0, so freeze it to be safe.
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
- if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
+ if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
+ TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
+ FVal)) {
SDValue F0 = DAG.getFreeze(N0);
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
}
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
- if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
+ if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
+ TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
+ TVal)) {
SDValue F0 = DAG.getFreeze(N0);
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
@@ -2467,26 +2473,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
}
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
"Unexpected binary operator");
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- auto BinOpcode = BO->getOpcode();
- EVT VT = BO->getValueType(0);
- if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
- if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
- return Sel;
+ if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
+ return Sel;
- if (TLI.isCommutativeBinOp(BO->getOpcode()))
- if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
- return Sel;
- }
+ if (TLI.isCommutativeBinOp(BO->getOpcode()))
+ if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
+ return Sel;
// Don't do this unless the old select is going away. We want to eliminate the
// binary operator, not replace a binop with a select.
// TODO: Handle ISD::SELECT_CC.
unsigned SelOpNo = 0;
SDValue Sel = BO->getOperand(0);
+ auto BinOpcode = BO->getOpcode();
if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
SelOpNo = 1;
Sel = BO->getOperand(1);
@@ -2534,6 +2537,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
SDLoc DL(Sel);
SDValue NewCT, NewCF;
+ EVT VT = BO->getValueType(0);
if (CanFoldNonConst) {
// If CBO is an opaque constant, we can't rely on getNode to constant fold.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a817ed5f0e917..b9882729a26b9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17685,8 +17685,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
}
bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
- unsigned BinOpcode, EVT VT) const {
- return VT.isScalableVector() && isTypeLegal(VT);
+ unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+ SDValue Y) const {
+ return VT.isScalableVector() && isTypeLegal(VT) &&
+ SelectOpcode == ISD::VSELECT;
}
bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1924d20f67f49..450e2efd7d430 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -281,8 +281,9 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
CombineLevel Level) const override;
- bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const override;
+ bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+ unsigned SelectOpcode, SDValue X,
+ SDValue Y) const override;
/// Returns true if it is beneficial to convert a load of a constant
/// to just the constant itself.
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index e0b7bacaa5729..b169adc1389d8 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -13957,9 +13957,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
return false;
}
-bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const {
- return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT);
+bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(
+ unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+ SDValue Y) const {
+ return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) &&
+ SelectOpcode == ISD::VSELECT;
}
bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index 9fad056edd3f1..87710ee29a249 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -758,8 +758,9 @@ class VectorType;
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
CombineLevel Level) const override;
- bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const override;
+ bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+ unsigned SelectOpcode, SDValue X,
+ SDValue Y) const override;
bool preferIncOfAddToSubOfNot(EVT VT) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0a849f49116ee..43c81b97a0e05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2145,8 +2145,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
return C && C->getAPIntValue().ule(10);
}
-bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
- EVT VT) const {
+bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(
+ unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+ SDValue Y) const {
+ if (SelectOpcode != ISD::VSELECT)
+ return false;
+
// Only enable for rvv.
if (!VT.isVector() || !Subtarget.hasVInstructions())
return false;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 78f2044ba83a7..1fcb25c8cd729 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -95,8 +95,9 @@ class RISCVTargetLowering : public TargetLowering {
unsigned &NumIntermediates,
MVT &RegisterVT) const override;
- bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const override;
+ bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+ unsigned SelectOpcode, SDValue X,
+ SDValue Y) const override;
/// Return true if the given shuffle mask can be codegen'd directly, or if it
/// should be stack expanded.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 99a82cab384aa..6b71f49165c60 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -35549,8 +35549,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
}
-bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
- EVT VT) const {
+bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(
+ unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
+ SDValue Y) const {
+ if (SelectOpcode != ISD::VSELECT)
+ return false;
// TODO: This is too general. There are cases where pre-AVX512 codegen would
// benefit. The transform may also be profitable for scalar code.
if (!Subtarget.hasAVX512())
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 359f24768b3da..5cb6b3e493a32 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1467,8 +1467,9 @@ namespace llvm {
/// from i32 to i16.
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
- bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
- EVT VT) const override;
+ bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
+ unsigned SelectOpcode, SDValue X,
+ SDValue Y) const override;
/// Given an intrinsic, checks if on the target the intrinsic will need to map
/// to a MemIntrinsicNode (touches memory). If this is the case, it returns
More information about the llvm-commits
mailing list