[llvm] [DAG] SDPatternMatch - allow m_BinOp / m_c_BinOp to take an optional SDNodeFlags required for matching (PR #178435)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 28 06:28:37 PST 2026


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/178435

BinaryOpc_match already supports matching against a set of required SDNodeFlags; this patch lets m_BinOp/m_c_BinOp take those required flags directly as an optional argument.

Updated the foldShiftToAvg fold to make use of this.

>From 2d0b0eac7cbc20a2027a3e2de458d44087694419 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Wed, 28 Jan 2026 14:27:57 +0000
Subject: [PATCH] [DAG] SDPatternMatch - allow m_BinOp / m_c_BinOp to take an
 optional SDNodeFlags required for matching

BinaryOpc_match already supports matching against a set of required SDNodeFlags; this patch lets m_BinOp/m_c_BinOp take those required flags directly as an optional argument.

Updated the foldShiftToAvg fold to make use of this.
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 12 ++++++-----
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 ++++++++-----------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  8 ++++++++
 3 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index ba8471ab4f947..fc672d23785a9 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -764,13 +764,15 @@ struct umin_pred_ty {
 
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
-                                         const RHS &R) {
-  return BinaryOpc_match<LHS, RHS>(Opc, L, R);
+                                         const RHS &R,
+                                         SDNodeFlags Flgs = SDNodeFlags()) {
+  return BinaryOpc_match<LHS, RHS>(Opc, L, R, Flgs);
 }
 template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L,
-                                                 const RHS &R) {
-  return BinaryOpc_match<LHS, RHS, true>(Opc, L, R);
+inline BinaryOpc_match<LHS, RHS, true>
+m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R,
+          SDNodeFlags Flgs = SDNodeFlags()) {
+  return BinaryOpc_match<LHS, RHS, true>(Opc, L, R, Flgs);
 }
 
 template <typename LHS, typename RHS>
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 02ca8197161a5..6ea4864ca98a8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12166,23 +12166,19 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
   bool IsUnsigned = Opcode == ISD::SRL;
 
   // Captured values.
-  SDValue A, B, Add;
+  SDValue A, B;
 
-  // Match floor average as it is common to both floor/ceil avgs.
+  // Match floor average as it is common to both floor/ceil avgs, ensure the add
+  // doesn't wrap.
+  SDNodeFlags Flags =
+      IsUnsigned ? SDNodeFlags::NoUnsignedWrap : SDNodeFlags::NoSignedWrap;
   if (sd_match(N, m_BinOp(Opcode,
-                          m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                          m_c_BinOp(ISD::ADD, m_Value(A), m_Value(B), Flags),
                           m_One()))) {
     // Decide whether signed or unsigned.
     unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
-    if (!hasOperation(FloorISD, VT))
-      return SDValue();
-
-    // Can't optimize adds that may wrap.
-    if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
-        (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
-      return SDValue();
-
-    return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+    if (hasOperation(FloorISD, VT))
+      return DAG.getNode(FloorISD, DL, VT, {A, B});
   }
 
   return SDValue();
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 4d955d9b4ef0e..ad2cd9decd9f8 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -319,11 +319,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Or, m_BitwiseLogic(m_Value(), m_Value())));
   EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(
+      Or, m_BinOp(ISD::OR, m_Value(), m_Value(), SDNodeFlags::Disjoint)));
+  EXPECT_FALSE(sd_match(
+      Or, m_c_BinOp(ISD::OR, m_Value(), m_Value(), SDNodeFlags::Disjoint)));
 
   EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
   EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      DisOr, m_BinOp(ISD::OR, m_Value(), m_Value(), SDNodeFlags::Disjoint)));
+  EXPECT_TRUE(sd_match(
+      DisOr, m_c_BinOp(ISD::OR, m_Value(), m_Value(), SDNodeFlags::Disjoint)));
 
   EXPECT_TRUE(sd_match(Rotl, m_Rotl(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Rotr, m_Rotr(m_Value(), m_Value())));



More information about the llvm-commits mailing list