[llvm] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for sub+xor form (PR #169199)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 22 22:17:39 PST 2025


https://github.com/laurenmchin updated https://github.com/llvm/llvm-project/pull/169199

>From b342abfb8773877eb86805ee0aba02b605cf61cd Mon Sep 17 00:00:00 2001
From: Lauren Chin <lchin at berkeley.edu>
Date: Sun, 23 Nov 2025 00:51:51 -0500
Subject: [PATCH] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for
 sub+xor form

Fixes a regression where avgceil[su] patterns fail to match when AArch64
canonicalizes `(add (add x, y), 1)` to `(sub x, (xor y, -1))`, causing
SVE/SVE2 test failures.

Addresses the remaining regression in [#147946](https://github.com/llvm/llvm-project/issues/147946).
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 128 ++++++++++++++++--
 1 file changed, 114 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b79dbb46cadc..e874b4d1e59de 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11943,28 +11943,128 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
     return SDValue();
 
   EVT VT = N->getValueType(0);
-  bool IsUnsigned = Opcode == ISD::SRL;
+  SDValue N0 = N->getOperand(0);
 
-  // Captured values.
-  SDValue A, B, Add;
+  if (!isOnesOrOnesSplat(N->getOperand(1)))
+    return SDValue();
 
-  // Match floor average as it is common to both floor/ceil avgs.
+  EVT TruncVT = VT;
+  SDNode *TruncNode = nullptr;
+
+  // We need the correct type to check for avgceil/floor support.
+  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+    TruncNode = *N->user_begin();
+    TruncVT = TruncNode->getValueType(0);
+  }
+
+  // NarrowVT is used to detect whether we're working with sign-extended values.
+  EVT NarrowVT = VT;
+  SDValue N00 = N0.getOperand(0);
+
+  // Extract narrow type from SIGN_EXTEND_INREG. For SRL, require the narrow
+  // type to be legal to ensure correct width avg operations.
+  if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+    if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+      return SDValue();
+  }
+
+  unsigned FloorISD = 0;
+  unsigned CeilISD = 0;
+  bool IsUnsigned = false;
+
+  // Decide whether signed or unsigned.
+  switch (Opcode) {
+  case ISD::SRA:
+    FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    if (TruncNode &&
+        (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB)) {
+      // Use signed avg for SRL of sign-extended values when truncating.
+      SDValue N01 = N0.getOperand(1);
+      if ((N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+           N00.getOpcode() == ISD::SIGN_EXTEND) ||
+          (N01.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+           N01.getOpcode() == ISD::SIGN_EXTEND))
+        IsUnsigned = false;
+    }
+    FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+    break;
+  default:
+    return SDValue();
+  }
+
+  CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+  // Without truncation, require target support for both averaging operations.
+  // We check FloorISD at VT (generated type), CeilISD at TruncVT (final type).
+  if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+      (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+    return SDValue();
+
+  SDValue X, Y, Sub, Xor;
+
+  // fold (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
   if (sd_match(N, m_BinOp(Opcode,
-                          m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                          m_AllOf(m_Value(Sub),
+                                  m_Sub(m_Value(X),
+                                        m_AllOf(m_Value(Xor),
+                                                m_Xor(m_Value(Y), m_Value())))),
                           m_One()))) {
-    // Decide whether signed or unsigned.
-    unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
-    if (!hasOperation(FloorISD, VT))
-      return SDValue();
 
-    // Can't optimize adds that may wrap.
-    if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
-        (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
-      return SDValue();
+    ConstantSDNode *C = isConstOrConstSplat(Xor.getOperand(1),
+                                            /*AllowUndefs=*/false,
+                                            /*AllowTruncation=*/true);
+    if (C && C->getAPIntValue().trunc(VT.getScalarSizeInBits()).isAllOnes()) {
+      // Don't fold extended inputs with truncation on fixed vectors > 128b
+      if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128) {
+        if (X.getOpcode() == ISD::SIGN_EXTEND ||
+            X.getOpcode() == ISD::ZERO_EXTEND ||
+            Y.getOpcode() == ISD::SIGN_EXTEND ||
+            Y.getOpcode() == ISD::ZERO_EXTEND)
+          return SDValue();
+      }
+
+      if (!TruncNode) {
+        // Without truncation, require no-wrap flags for safe narrowing.
+        const SDNodeFlags &Flags = Sub->getFlags();
+        if ((!IsUnsigned && (Opcode == ISD::SRA && VT == NarrowVT) &&
+             !Flags.hasNoSignedWrap()) ||
+            (IsUnsigned && !Flags.hasNoUnsignedWrap()))
+          return SDValue();
+      }
 
-    return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+      // Require avgceil[su] support at the final type:
+      //  - with truncation: build at VT, visitTRUNCATE completes the fold
+      //  - without truncation: build directly at VT (where TruncVT == VT).
+      if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+        return DAG.getNode(CeilISD, DL, VT, Y, X);
+    }
   }
 
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match floor average as it is common to both floor/ceil avgs.
+  // fold (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+      (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+    return SDValue();
+
+  EVT TargetVT = TruncNode ? TruncVT : VT;
+  if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+    return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
   return SDValue();
 }
 



More information about the llvm-commits mailing list