[llvm] [DAG] foldShiftToAvg - recognize sub(x, xor(y, -1)) >> 1 as avgceil[su] (#147946) (PR #156239)

via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 31 05:13:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Lauren (laurenmchin)

<details>
<summary>Changes</summary>

- Match avgceil[su] idiom lowered by AArch64: (sr[al] (sub x, (xor y, -1)), 1) -> avgceil[su](x, y)
- Keep floor-average fold, but check legality at TruncVT when truncation is present: (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
- Treat SRL of widened signed add/sub feeding a truncate as signed
- Add visitTRUNCATE combine: trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)

This patch resolves the regression test failure in
llvm/test/CodeGen/AArch64/sve-hadd.ll and addresses PR #147946.

---
Full diff: https://github.com/llvm/llvm-project/pull/156239.diff


1 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+149-14) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bed3c42473e27..a51709a33910e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11802,28 +11802,146 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
     return SDValue();
 
   EVT VT = N->getValueType(0);
-  bool IsUnsigned = Opcode == ISD::SRL;
+  SDValue N0 = N->getOperand(0);
 
-  // Captured values.
-  SDValue A, B, Add;
+  if (!sd_match(N->getOperand(1), m_One()))
+    return SDValue();
 
-  // Match floor average as it is common to both floor/ceil avgs.
+  // [TruncVT]
+  // result type of a single truncate user fed by this shift node (if present).
+  // We always use TruncVT to verify whether the target supports folding to
+  // avgceils. For avgfloor[su], we use TruncVT if present, else VT.
+  //
+  // [NarrowVT]
+  // semantic source width of the value(s) being averaged when the ops are
+  // SExt/SExtInReg.
+  EVT TruncVT = VT;
+  SDNode *TruncNode = nullptr;
+
+  // If this shift has a single truncate user, use it to decide whether folding
+  // to avg* is legal at the truncated width. Note that the target may only
+  // support the avgceil[su]/avgfloor[su] op at the narrower type, or the
+  // full-width VT, but we check for legality using the truncate node's VT if
+  // present, else this shift's VT.
+  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+    TruncNode = *N->user_begin();
+    TruncVT = TruncNode->getValueType(0);
+  }
+
+  EVT NarrowVT = VT;
+  SDValue N00 = N0.getOperand(0);
+
+  // For SRL of SExt'd values, if (1) the type isn't legal, and (2) there's no
+  // truncate user, bail out, because we can't safely fold.
+  if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+    if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+      return SDValue();
+  }
+
+  unsigned FloorISD = 0;
+  unsigned CeilISD = 0;
+  bool IsUnsigned = false;
+
+  // Decide whether signed or unsigned.
+  switch (Opcode) {
+  case ISD::SRA:
+    FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    // SRL of a widened signed sub feeding a truncate acts like shadd.
+    if (TruncNode &&
+        (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
+        (N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+         N00.getOpcode() == ISD::SIGN_EXTEND))
+      IsUnsigned = false;
+    FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+    break;
+  default:
+    return SDValue();
+  }
+
+  CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+  // Bail out if this shift is not truncated and the target doesn't support
+  // the avg* op at this shift's VT (or TruncVT for avgceil[su]).
+  if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+      (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+    return SDValue();
+
+  SDValue X, Y, Sub, Xor;
+
+  // (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
   if (sd_match(N, m_BinOp(Opcode,
-                          m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                          m_AllOf(m_Value(Sub),
+                                  m_Sub(m_Value(X),
+                                        m_AllOf(m_Value(Xor),
+                                                m_Xor(m_Value(Y), m_Value())))),
                           m_One()))) {
-    // Decide whether signed or unsigned.
-    unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
-    if (!hasOperation(FloorISD, VT))
-      return SDValue();
+    APInt SplatVal;
+    if (ISD::isConstantSplatVector(Xor.getOperand(1).getNode(), SplatVal)) {
+      // - Can't fold if either op is sign/zero-extended for SRL, as SRL
+      //   is unsigned, and shadd patterns are handled elsewhere.
+      //
+      // - Large fixed vectors (>128 bits) on AArch64 will be type-legalized
+      //   into a series of EXTRACT_SUBVECTORs. Folding each subvector does not
+      //   necessarily preserve semantics so they cannot be folded here.
+      if (TruncNode && VT.isFixedLengthVector()) {
+        if (X.getOpcode() == ISD::SIGN_EXTEND ||
+            X.getOpcode() == ISD::ZERO_EXTEND ||
+            Y.getOpcode() == ISD::SIGN_EXTEND ||
+            Y.getOpcode() == ISD::ZERO_EXTEND)
+          return SDValue();
+        else if (TruncNode && VT.isFixedLengthVector() &&
+                 VT.getSizeInBits() > 128)
+          return SDValue();
+      }
 
-    // Can't optimize adds that may wrap.
-    if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
-        (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
-      return SDValue();
+      // If there is no truncate user, ensure the relevant no wrap flag is on
+      // the sub so that narrowing the widened result is defined.
+      if (Opcode == ISD::SRA && VT == NarrowVT) {
+        if (!IsUnsigned && !Sub->getFlags().hasNoSignedWrap())
+          return SDValue();
+      } else if (IsUnsigned && !Sub->getFlags().hasNoUnsignedWrap())
+        return SDValue();
 
-    return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+      // Only fold if the target supports avgceil[su] at the truncated type:
+      // - if there is a single truncate user, we require support at TruncVT.
+      //   We build the avg* at VT (to replace this shift node).
+      //   visitTRUNCATE handles the actual folding to avgceils (x, y).
+      // - otherwise, we require support at VT (TruncVT == VT).
+      //
+      // AArch64 canonicalizes (x + y + 1) >> 1 -> sub (x, xor (y, -1)). In
+      // order for our fold to be legal, we require support for the VT at the
+      // final observable type (TruncVT or VT).
+      if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+        return DAG.getNode(CeilISD, DL, VT, Y, X);
+    }
   }
 
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match floor average as it is common to both floor/ceil avgs.
+  // (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+      (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+    return SDValue();
+
+  EVT TargetVT = TruncNode ? TruncVT : VT;
+  if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+    return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+
   return SDValue();
 }
 
@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
     }
   }
 
+  // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
+  if (N0.getOpcode() == ISD::AVGCEILU) {
+    SDValue SExtX = N0.getOperand(0);
+    SDValue SExtY = N0.getOperand(1);
+    if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
+        (SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
+      SDValue X = SExtX.getOperand(0);
+      SDValue Y = SExtY.getOperand(0);
+      if (X.getValueType() == VT &&
+          TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT)) {
+        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+      }
+    }
+  }
+
   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
     return NewVSel;
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/156239


More information about the llvm-commits mailing list