[llvm] [DAG] ISD::matchUnaryPredicate / matchUnaryFpPredicate / matchBinaryPredicate - add DemandedElts variant (PR #183013)

via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 28 23:32:19 PST 2026


https://github.com/VachanVY updated https://github.com/llvm/llvm-project/pull/183013

>From d6c42e55b05493e93ad492a26c29ad776eba2bb2 Mon Sep 17 00:00:00 2001
From: Vachan V Y <vachanvy05 at gmail.com>
Date: Tue, 24 Feb 2026 04:40:29 +0530
Subject: [PATCH 1/5] [DAG] Add DemandedElts variant to
 `ISD::matchUnaryPredicate` and `ISD::matchUnaryFpPredicate`

Fixes #181658
---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 20 ++++++++++++--
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 27 ++++++-------------
 2 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 3bbafe2d124e7..4eda70f6fbb8d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3461,25 +3461,41 @@ namespace ISD {
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
   template <typename ConstNodeType>
   bool matchUnaryPredicateImpl(SDValue Op,
+                               const APInt &DemandedElts,
                                std::function<bool(ConstNodeType *)> Match,
                                bool AllowUndefs = false,
                                bool AllowTruncation = false);
 
   /// Hook for matching ConstantSDNode predicate
   inline bool matchUnaryPredicate(SDValue Op,
+                                  const APInt &DemandedElts,
                                   std::function<bool(ConstantSDNode *)> Match,
                                   bool AllowUndefs = false,
                                   bool AllowTruncation = false) {
-    return matchUnaryPredicateImpl<ConstantSDNode>(Op, Match, AllowUndefs,
+    return matchUnaryPredicateImpl<ConstantSDNode>(Op, DemandedElts, Match, AllowUndefs,
                                                    AllowTruncation);
   }
 
+  inline bool matchUnaryPredicate(SDValue Op,
+                                  std::function<bool(ConstantSDNode *)> Match,
+                                  bool AllowUndefs = false,
+                                  bool AllowTruncation = false) {
+    return matchUnaryPredicate(Op, APInt(1, 1), Match, AllowUndefs, AllowTruncation);
+  }
+
   /// Hook for matching ConstantFPSDNode predicate
   inline bool
   matchUnaryFpPredicate(SDValue Op,
+                        const APInt &DemandedElts,
                         std::function<bool(ConstantFPSDNode *)> Match,
                         bool AllowUndefs = false) {
-    return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, Match, AllowUndefs);
+    return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, DemandedElts, Match, AllowUndefs);
+  }
+
+  inline bool matchUnaryFpPredicate(SDValue Op,
+                                    std::function<bool(ConstantFPSDNode *)> Match,
+                                    bool AllowUndefs = false) {
+    return matchUnaryFpPredicate(Op, APInt(1, 1), Match, AllowUndefs);
   }
 
   /// Attempt to match a binary predicate against a pair of scalar/splat
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3affb4de2d4b4..caa615b2720fc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -349,6 +349,7 @@ bool ISD::isFreezeUndef(const SDNode *N) {
 
 template <typename ConstNodeType>
 bool ISD::matchUnaryPredicateImpl(SDValue Op,
+                                  const APInt &DemandedElts,
                                   std::function<bool(ConstNodeType *)> Match,
                                   bool AllowUndefs, bool AllowTruncation) {
   // FIXME: Add support for scalar UNDEF cases?
@@ -362,6 +363,10 @@ bool ISD::matchUnaryPredicateImpl(SDValue Op,
 
   EVT SVT = Op.getValueType().getScalarType();
   for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
+    // Skip non-demanded lanes
+    if (!DemandedElts[i])
+      continue;
+
     if (AllowUndefs && Op.getOperand(i).isUndef()) {
       if (!Match(nullptr))
         return false;
@@ -377,9 +382,9 @@ bool ISD::matchUnaryPredicateImpl(SDValue Op,
 }
 // Build used template types.
 template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>(
-    SDValue, std::function<bool(ConstantSDNode *)>, bool, bool);
+    SDValue, const APInt&, std::function<bool(ConstantSDNode *)>, bool, bool);
 template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
-    SDValue, std::function<bool(ConstantFPSDNode *)>, bool, bool);
+    SDValue, const APInt&, std::function<bool(ConstantFPSDNode *)>, bool, bool);
 
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,
@@ -4675,26 +4680,10 @@ bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val,
   };
 
   // Is the constant a known power of 2 or zero?
-  if (ISD::matchUnaryPredicate(Val, IsPowerOfTwoOrZero))
+  if (ISD::matchUnaryPredicate(Val, DemandedElts, IsPowerOfTwoOrZero))
     return true;
 
   switch (Val.getOpcode()) {
-  case ISD::BUILD_VECTOR:
-    // Are all operands of a build vector constant powers of two or zero?
-    if (all_of(enumerate(Val->ops()), [&](auto P) {
-          auto *C = dyn_cast<ConstantSDNode>(P.value());
-          return !DemandedElts[P.index()] || (C && IsPowerOfTwoOrZero(C));
-        }))
-      return true;
-    break;
-
-  case ISD::SPLAT_VECTOR:
-    // Is the operand of a splat vector a constant power of two?
-    if (auto *C = dyn_cast<ConstantSDNode>(Val->getOperand(0)))
-      if (IsPowerOfTwoOrZero(C))
-        return true;
-    break;
-
   case ISD::AND: {
     // Looking for `x & -x` pattern:
     // If x == 0:

>From fdf7f0b74bc9ea6a5e6d8e017c24ea3b6becb8de Mon Sep 17 00:00:00 2001
From: Vachan V Y <vachanvy05 at gmail.com>
Date: Tue, 24 Feb 2026 14:27:16 +0530
Subject: [PATCH 2/5] [DAG] clang-format the DemandedElts matchUnaryPredicate
 changes (NFC)

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 26 +++++++++----------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  8 +++---
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 4eda70f6fbb8d..0f01e1e1f7c4c 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3460,41 +3460,41 @@ namespace ISD {
   /// every element of a constant BUILD_VECTOR.
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
   template <typename ConstNodeType>
-  bool matchUnaryPredicateImpl(SDValue Op,
-                               const APInt &DemandedElts,
+  bool matchUnaryPredicateImpl(SDValue Op, const APInt &DemandedElts,
                                std::function<bool(ConstNodeType *)> Match,
                                bool AllowUndefs = false,
                                bool AllowTruncation = false);
 
   /// Hook for matching ConstantSDNode predicate
-  inline bool matchUnaryPredicate(SDValue Op,
-                                  const APInt &DemandedElts,
+  inline bool matchUnaryPredicate(SDValue Op, const APInt &DemandedElts,
                                   std::function<bool(ConstantSDNode *)> Match,
                                   bool AllowUndefs = false,
                                   bool AllowTruncation = false) {
-    return matchUnaryPredicateImpl<ConstantSDNode>(Op, DemandedElts, Match, AllowUndefs,
-                                                   AllowTruncation);
+    return matchUnaryPredicateImpl<ConstantSDNode>(
+        Op, DemandedElts, Match, AllowUndefs, AllowTruncation);
   }
 
   inline bool matchUnaryPredicate(SDValue Op,
                                   std::function<bool(ConstantSDNode *)> Match,
                                   bool AllowUndefs = false,
                                   bool AllowTruncation = false) {
-    return matchUnaryPredicate(Op, APInt(1, 1), Match, AllowUndefs, AllowTruncation);
+    return matchUnaryPredicate(Op, APInt(1, 1), Match, AllowUndefs,
+                               AllowTruncation);
   }
 
   /// Hook for matching ConstantFPSDNode predicate
   inline bool
-  matchUnaryFpPredicate(SDValue Op,
-                        const APInt &DemandedElts,
+  matchUnaryFpPredicate(SDValue Op, const APInt &DemandedElts,
                         std::function<bool(ConstantFPSDNode *)> Match,
                         bool AllowUndefs = false) {
-    return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, DemandedElts, Match, AllowUndefs);
+    return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, DemandedElts, Match,
+                                                     AllowUndefs);
   }
 
-  inline bool matchUnaryFpPredicate(SDValue Op,
-                                    std::function<bool(ConstantFPSDNode *)> Match,
-                                    bool AllowUndefs = false) {
+  inline bool
+  matchUnaryFpPredicate(SDValue Op,
+                        std::function<bool(ConstantFPSDNode *)> Match,
+                        bool AllowUndefs = false) {
     return matchUnaryFpPredicate(Op, APInt(1, 1), Match, AllowUndefs);
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index caa615b2720fc..6980add78acf2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -348,8 +348,7 @@ bool ISD::isFreezeUndef(const SDNode *N) {
 }
 
 template <typename ConstNodeType>
-bool ISD::matchUnaryPredicateImpl(SDValue Op,
-                                  const APInt &DemandedElts,
+bool ISD::matchUnaryPredicateImpl(SDValue Op, const APInt &DemandedElts,
                                   std::function<bool(ConstNodeType *)> Match,
                                   bool AllowUndefs, bool AllowTruncation) {
   // FIXME: Add support for scalar UNDEF cases?
@@ -382,9 +381,10 @@ bool ISD::matchUnaryPredicateImpl(SDValue Op,
 }
 // Build used template types.
 template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>(
-    SDValue, const APInt&, std::function<bool(ConstantSDNode *)>, bool, bool);
+    SDValue, const APInt &, std::function<bool(ConstantSDNode *)>, bool, bool);
 template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
-    SDValue, const APInt&, std::function<bool(ConstantFPSDNode *)>, bool, bool);
+    SDValue, const APInt &, std::function<bool(ConstantFPSDNode *)>, bool,
+    bool);
 
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,

>From 0f1f4b29e1d6951304a9a806f2710cfdc958bb57 Mon Sep 17 00:00:00 2001
From: Vachan V Y <vachanvy05 at gmail.com>
Date: Tue, 24 Feb 2026 15:00:31 +0530
Subject: [PATCH 3/5] [DAG] Refactor `isKnownNeverZero` to use `DemandedElts`
 in `matchUnaryPredicate`

---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 20 +------------------
 1 file changed, 1 insertion(+), 19 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6980add78acf2..936063143b019 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -362,7 +362,6 @@ bool ISD::matchUnaryPredicateImpl(SDValue Op, const APInt &DemandedElts,
 
   EVT SVT = Op.getValueType().getScalarType();
   for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
-    // Skip non-demanded lanes
     if (!DemandedElts[i])
       continue;
 
@@ -6165,7 +6164,7 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, const APInt &DemandedElts,
     return !V.isZero();
   };
 
-  if (ISD::matchUnaryPredicate(Op, IsNeverZero))
+  if (ISD::matchUnaryPredicate(Op, DemandedElts, IsNeverZero))
     return true;
 
   // TODO: Recognize more cases here. Most of the cases are also incomplete to
@@ -6173,23 +6172,6 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, const APInt &DemandedElts,
   switch (Op.getOpcode()) {
   default:
     break;
-
-  case ISD::BUILD_VECTOR:
-    // Are all operands of a build vector constant non-zero?
-    if (all_of(enumerate(Op->ops()), [&](auto P) {
-          auto *C = dyn_cast<ConstantSDNode>(P.value());
-          return !DemandedElts[P.index()] || (C && IsNeverZero(C));
-        }))
-      return true;
-    break;
-
-  case ISD::SPLAT_VECTOR:
-    // Is the operand of a splat vector a constant non-zero?
-    if (auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(0)))
-      if (IsNeverZero(C))
-        return true;
-    break;
-
   case ISD::OR:
     return isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
            isKnownNeverZero(Op.getOperand(0), Depth + 1);

>From 6b19416dd5d6a1f2fe09c518885d5cd17f38f472 Mon Sep 17 00:00:00 2001
From: Vachan V Y <vachanvy05 at gmail.com>
Date: Tue, 24 Feb 2026 15:36:54 +0530
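Subject: [PATCH 4/5] [DAG] Default DemandedElts to all lanes and add a
 DemandedElts variant of `ISD::matchBinaryPredicate`

The overloads of `matchUnaryPredicate` and `matchUnaryFpPredicate` that
take no DemandedElts previously forwarded `APInt(1, 1)`. That mask is
too narrow to index every operand of a multi-element BUILD_VECTOR.

Instead, build an all-ones mask of the vector's element count for
fixed-length vectors. Keep `APInt(1, 1)` for scalars and scalable
vectors. For SPLAT_VECTOR, whose single operand supplies every lane,
read bit 0 of the mask.

`matchBinaryPredicate` gains the same DemandedElts parameter. A new
overload demands every lane of LHS.

NOTE(review): for a fixed-length SPLAT_VECTOR, only bit 0 of DemandedElts
is consulted. The splat operand is therefore skipped when lane 0 is not
demanded, even if other lanes are. `DemandedElts.isZero()` may be the
intended test for splats; to confirm.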

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 26 ++++++++++++++++---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  8 ++++--
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 0f01e1e1f7c4c..972ba453c2c07 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3478,7 +3478,11 @@ namespace ISD {
                                   std::function<bool(ConstantSDNode *)> Match,
                                   bool AllowUndefs = false,
                                   bool AllowTruncation = false) {
-    return matchUnaryPredicate(Op, APInt(1, 1), Match, AllowUndefs,
+    EVT VT = Op.getValueType();
+    APInt DemandedElts = VT.isFixedLengthVector()
+                             ? APInt::getAllOnes(VT.getVectorNumElements())
+                             : APInt(1, 1);
+    return matchUnaryPredicate(Op, DemandedElts, Match, AllowUndefs,
                                AllowTruncation);
   }
 
@@ -3495,7 +3499,11 @@ namespace ISD {
   matchUnaryFpPredicate(SDValue Op,
                         std::function<bool(ConstantFPSDNode *)> Match,
                         bool AllowUndefs = false) {
-    return matchUnaryFpPredicate(Op, APInt(1, 1), Match, AllowUndefs);
+    EVT VT = Op.getValueType();
+    APInt DemandedElts = VT.isFixedLengthVector()
+                             ? APInt::getAllOnes(VT.getVectorNumElements())
+                             : APInt(1, 1);
+    return matchUnaryFpPredicate(Op, DemandedElts, Match, AllowUndefs);
   }
 
   /// Attempt to match a binary predicate against a pair of scalar/splat
@@ -3503,10 +3511,22 @@ namespace ISD {
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
   /// If AllowTypeMismatch is true then RetType + ArgTypes don't need to match.
   LLVM_ABI bool matchBinaryPredicate(
-      SDValue LHS, SDValue RHS,
+      SDValue LHS, SDValue RHS, const APInt &DemandedElts,
       std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
       bool AllowUndefs = false, bool AllowTypeMismatch = false);
 
+  inline bool matchBinaryPredicate(
+      SDValue LHS, SDValue RHS,
+      std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
+      bool AllowUndefs = false, bool AllowTypeMismatch = false) {
+    EVT VT = LHS.getValueType();
+    APInt DemandedElts = VT.isFixedLengthVector()
+                             ? APInt::getAllOnes(VT.getVectorNumElements())
+                             : APInt(1, 1);
+    return matchBinaryPredicate(LHS, RHS, DemandedElts, Match, AllowUndefs,
+                                AllowTypeMismatch);
+  }
+
   /// Returns true if the specified value is the overflow result from one
   /// of the overflow intrinsic nodes.
   inline bool isOverflowIntrOpRes(SDValue Op) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 936063143b019..8f5c50df860b4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -361,8 +361,9 @@ bool ISD::matchUnaryPredicateImpl(SDValue Op, const APInt &DemandedElts,
     return false;
 
   EVT SVT = Op.getValueType().getScalarType();
+  bool IsSplat = ISD::SPLAT_VECTOR == Op.getOpcode();
   for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
-    if (!DemandedElts[i])
+    if (!DemandedElts[IsSplat ? 0 : i])
       continue;
 
     if (AllowUndefs && Op.getOperand(i).isUndef()) {
@@ -386,7 +387,7 @@ template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
     bool);
 
 bool ISD::matchBinaryPredicate(
-    SDValue LHS, SDValue RHS,
+    SDValue LHS, SDValue RHS, const APInt &DemandedElts,
     std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
     bool AllowUndefs, bool AllowTypeMismatch) {
   if (!AllowTypeMismatch && LHS.getValueType() != RHS.getValueType())
@@ -404,7 +405,10 @@ bool ISD::matchBinaryPredicate(
     return false;
 
   EVT SVT = LHS.getValueType().getScalarType();
+  bool IsSplat = ISD::SPLAT_VECTOR == LHS.getOpcode();
   for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
+    if (!DemandedElts[IsSplat ? 0 : i])
+      continue;
     SDValue LHSOp = LHS.getOperand(i);
     SDValue RHSOp = RHS.getOperand(i);
     bool LHSUndef = AllowUndefs && LHSOp.isUndef();

>From 21b8c4a0a7b1623d46aaebf0201a76f4a6c7910b Mon Sep 17 00:00:00 2001
From: Vachan V Y <vachanvy05 at gmail.com>
Date: Sun, 1 Mar 2026 12:46:05 +0530
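Subject: [PATCH 5/5] [DAG] Pass AllowTruncation=true to matchUnaryPredicate
 in isKnownToBeAPowerOfTwo and isKnownNeverZero

The removed BUILD_VECTOR and SPLAT_VECTOR cases accepted any
ConstantSDNode operand without comparing its type to the vector element
type. Passing `/*AllowUndefs=*/false, /*AllowTruncation=*/true` explicitly
keeps that behaviour now that both functions rely on matchUnaryPredicate.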

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 8f5c50df860b4..990bd664a6745 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4683,7 +4683,8 @@ bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val,
   };
 
   // Is the constant a known power of 2 or zero?
-  if (ISD::matchUnaryPredicate(Val, DemandedElts, IsPowerOfTwoOrZero))
+  if (ISD::matchUnaryPredicate(Val, DemandedElts, IsPowerOfTwoOrZero,
+                               /*AllowUndefs=*/false, /*AllowTruncation=*/true))
     return true;
 
   switch (Val.getOpcode()) {
@@ -6168,7 +6169,8 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, const APInt &DemandedElts,
     return !V.isZero();
   };
 
-  if (ISD::matchUnaryPredicate(Op, DemandedElts, IsNeverZero))
+  if (ISD::matchUnaryPredicate(Op, DemandedElts, IsNeverZero,
+                               /*AllowUndefs=*/false, /*AllowTruncation=*/true))
     return true;
 
   // TODO: Recognize more cases here. Most of the cases are also incomplete to



More information about the llvm-commits mailing list