[llvm] [DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef handling (PR #147044)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 6 19:18:39 PDT 2025
https://github.com/woruyu updated https://github.com/llvm/llvm-project/pull/147044
>From 229ec2822e5c5d959f287eb8f93fbce0007e3b13 Mon Sep 17 00:00:00 2001
From: woruyu <1214539920 at qq.com>
Date: Fri, 4 Jul 2025 19:20:07 +0800
Subject: [PATCH] [DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent
undef handling
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 39 ++++++++++++++++---
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 4 ++
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 ++++++
llvm/lib/Target/X86/X86ISelLowering.cpp | 20 +++++-----
5 files changed, 59 insertions(+), 18 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
return SpecificInt_match(APInt(64, V));
}
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+/// Pattern that succeeds when isZeroOrZeroSplat accepts the candidate value.
+/// The AllowUndefs flag chosen at construction is forwarded unchanged on
+/// every match attempt.
+struct Zero_match {
+  bool AllowUndefs;
+
+  explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  /// The match context is not consulted; only \p Op is inspected.
+  template <typename ContextT>
+  bool match(const ContextT & /*Ctx*/, SDValue Op) const {
+    const bool IsZero = isZeroOrZeroSplat(Op, AllowUndefs);
+    return IsZero;
+  }
+};
+
+/// Pattern that succeeds when isOnesOrOnesSplat accepts the candidate value.
+/// The AllowUndefs flag chosen at construction is forwarded unchanged on
+/// every match attempt.
+struct Ones_match {
+  bool AllowUndefs;
+
+  // explicit: a bare bool must never convert into a matcher implicitly.
+  explicit Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  // const-qualified: matching does not mutate the pattern, so it stays usable
+  // through a const reference or as a member of a const composite pattern.
+  template <typename MatchContext>
+  bool match(const MatchContext &, SDValue N) const {
+    return isOnesOrOnesSplat(N, AllowUndefs);
+  }
+};
struct AllOnes_match {
+ bool AllowUndefs;
- AllOnes_match() = default;
+ // explicit blocks implicit bool -> matcher conversion; the defaulted
+ // argument keeps `AllOnes_match()` constructible as it was before.
+ explicit AllOnes_match(bool AllowUndefs = false)
+ : AllowUndefs(AllowUndefs) {}
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
- return isAllOnesOrAllOnesSplat(N);
+ return isAllOnesOrAllOnesSplat(N, AllowUndefs);
}
};
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+/// Build a Ones_match; \p AllowUndefs (default false) is stored in the matcher.
+inline Ones_match m_One(bool AllowUndefs = false) {
+  return Ones_match{AllowUndefs};
+}
+/// Build a Zero_match; \p AllowUndefs (default false) is stored in the matcher.
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+  return Zero_match{AllowUndefs};
+}
+/// Build an AllOnes_match; \p AllowUndefs (default false) is stored in the
+/// matcher.
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+  return AllOnes_match{AllowUndefs};
+}
/// Match true boolean value based on the information provided by
/// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
/// Match a negate as a sub(0, v)
template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
return m_Sub(m_Zero(), V);
}
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
/// Does not permit build vector implicit truncation.
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d4ad4d3a09381..f94b3a35652fc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
return V;
// (A - B) - 1 -> add (xor B, -1), A
- if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+ if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
// Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
}
+/// Return true if \p N is the integer constant 1 or a constant splat of 1.
+/// \p AllowUndefs is forwarded to isConstOrConstSplat.
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+  // Deliberately no peekThroughBitcasts: unlike all-zeros or all-ones, the
+  // value 1 is not preserved when a bitcast changes the element width (a
+  // v2i64 splat of 1 reinterpreted as v4i32 is no longer a splat of 1), so
+  // looking through bitcasts could report a false match.
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+  return C && C->isOne();
+}
+
+/// Return true if \p N, after stripping bitcasts, is a constant or constant
+/// splat accepted by isConstOrConstSplat whose value is zero.
+/// \p AllowUndefs is forwarded to isConstOrConstSplat.
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+  SDValue Src = peekThroughBitcasts(N);
+  if (ConstantSDNode *Splat =
+          isConstOrConstSplat(Src, AllowUndefs, /*AllowTruncation=*/true))
+    return Splat->isZero();
+  return false;
+}
+
HandleSDNode::~HandleSDNode() {
DropOperands();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index afffe51f23a27..7ec666d0b1658 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57925,22 +57925,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
}
}
+ SDValue X, Y;
+
// add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
// iff X and Y won't overflow.
- if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
- ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
- ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
- if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
- MVT OpVT = Op0.getOperand(1).getSimpleValueType();
- SDValue Sum =
- DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
- return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
- getZeroVector(OpVT, Subtarget, DAG, DL));
- }
+ if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+ sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+ DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+ MVT OpVT = X.getSimpleValueType();
+ SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+ return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+ getZeroVector(OpVT, Subtarget, DAG, DL));
}
if (VT.isVector()) {
- SDValue X, Y;
EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
More information about the llvm-commits
mailing list