[llvm] [DAG] foldShiftToAvg - recognize sub(x, xor(y, -1)) >> 1 as avgceil[su] (#147946) (PR #156239)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 31 05:13:25 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: Lauren (laurenmchin)
<details>
<summary>Changes</summary>
- Match avgceil[su] idiom lowered by AArch64: (sr[al] (sub x, (xor y, -1)), 1) -> avgceil[su](x, y)
- Keep floor-average fold, but check legality at TruncVT when truncation is present: (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
- Treat SRL of widened signed add/sub feeding a truncate as signed
- Add visitTRUNCATE combine: trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)
This patch resolves the regression test failure in
llvm/test/CodeGen/AArch64/sve-hadd.ll and addresses PR #147946.
---
Full diff: https://github.com/llvm/llvm-project/pull/156239.diff
1 Files Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+149-14)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bed3c42473e27..a51709a33910e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11802,28 +11802,146 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
return SDValue();
EVT VT = N->getValueType(0);
- bool IsUnsigned = Opcode == ISD::SRL;
+ SDValue N0 = N->getOperand(0);
- // Captured values.
- SDValue A, B, Add;
+ if (!sd_match(N->getOperand(1), m_One()))
+ return SDValue();
- // Match floor average as it is common to both floor/ceil avgs.
+ // [TruncVT]
+ // result type of a single truncate user fed by this shift node (if present).
+ // We always use TruncVT to verify whether the target supports folding to
+ // avgceils. For avgfloor[su], we use TruncVT if present, else VT.
+ //
+ // [NarrowVT]
+ // semantic source width of the value(s) being averaged when the ops are
+ // SExt/SExtInReg.
+ EVT TruncVT = VT;
+ SDNode *TruncNode = nullptr;
+
+ // If this shift has a single truncate user, use it to decide whether folding
+ // to avg* is legal at the truncated width. Note that the target may only
+ // support the avgceil[su]/avgfloor[su] op at the narrower type, or the
+ // full-width VT, but we check for legality using the truncate node's VT if
+ // present, else this shift's VT.
+ if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+ TruncNode = *N->user_begin();
+ TruncVT = TruncNode->getValueType(0);
+ }
+
+ EVT NarrowVT = VT;
+ SDValue N00 = N0.getOperand(0);
+
+ // For SRL of SExt'd values, if (1) the type isnt legal, and (2) there's no
+ // truncate user, bail out, because we can't safely fold.
+ if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+ NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+ if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+ return SDValue();
+ }
+
+ unsigned FloorISD = 0;
+ unsigned CeilISD = 0;
+ bool IsUnsigned = false;
+
+ // Decide whether signed or unsigned.
+ switch (Opcode) {
+ case ISD::SRA:
+ FloorISD = ISD::AVGFLOORS;
+ break;
+ case ISD::SRL:
+ IsUnsigned = true;
+ // SRL of a widened signed sub feeding a truncate acts like shadd.
+ if (TruncNode &&
+ (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
+ (N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+ N00.getOpcode() == ISD::SIGN_EXTEND))
+ IsUnsigned = false;
+ FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+ break;
+ default:
+ return SDValue();
+ }
+
+ CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+ // Bail out if this shift is not truncated and the target doesn't support
+ // the avg* op at this shift's VT (or TruncVT for avgceil[su]).
+ if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+ (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+ return SDValue();
+
+ SDValue X, Y, Sub, Xor;
+
+ // (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
if (sd_match(N, m_BinOp(Opcode,
- m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+ m_AllOf(m_Value(Sub),
+ m_Sub(m_Value(X),
+ m_AllOf(m_Value(Xor),
+ m_Xor(m_Value(Y), m_Value())))),
m_One()))) {
- // Decide whether signed or unsigned.
- unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
- if (!hasOperation(FloorISD, VT))
- return SDValue();
+ APInt SplatVal;
+ if (ISD::isConstantSplatVector(Xor.getOperand(1).getNode(), SplatVal)) {
+ // - Can't fold if either op is sign/zero-extended for SRL, as SRL
+ // is unsigned, and shadd patterns are handled elsewhere.
+ //
+ // - Large fixed vectors (>128 bits) on AArch64 will be type-legalized
+ // into a series of EXTRACT_SUBVECTORs. Folding each subvector does not
+ // necessarily preserve semantics so they cannot be folded here.
+ if (TruncNode && VT.isFixedLengthVector()) {
+ if (X.getOpcode() == ISD::SIGN_EXTEND ||
+ X.getOpcode() == ISD::ZERO_EXTEND ||
+ Y.getOpcode() == ISD::SIGN_EXTEND ||
+ Y.getOpcode() == ISD::ZERO_EXTEND)
+ return SDValue();
+ else if (TruncNode && VT.isFixedLengthVector() &&
+ VT.getSizeInBits() > 128)
+ return SDValue();
+ }
- // Can't optimize adds that may wrap.
- if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
- (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
- return SDValue();
+ // If there is no truncate user, ensure the relevant no wrap flag is on
+ // the sub so that narrowing the widened result is defined.
+ if (Opcode == ISD::SRA && VT == NarrowVT) {
+ if (!IsUnsigned && !Sub->getFlags().hasNoSignedWrap())
+ return SDValue();
+ } else if (IsUnsigned && !Sub->getFlags().hasNoUnsignedWrap())
+ return SDValue();
- return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+ // Only fold if the target supports avgceil[su] at the truncated type:
+ // - if there is a single truncate user, we require support at TruncVT.
+ // We build the avg* at VT (to replace this shift node).
+ // visitTRUNCATE handles the actual folding to avgceils (x, y).
+ // - otherwise, we require support at VT (TruncVT == VT).
+ //
+ // AArch64 canonicalizes (x + y + 1) >> 1 -> sub (x, xor (y, -1)). In
+ // order for our fold to be legal, we require support for the VT at the
+ // final observable type (TruncVT or VT).
+ if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+ return DAG.getNode(CeilISD, DL, VT, Y, X);
+ }
}
+ // Captured values.
+ SDValue A, B, Add;
+
+ // Match floor average as it is common to both floor/ceil avgs.
+ // (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+ if (!sd_match(N, m_BinOp(Opcode,
+ m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+ m_One())))
+ return SDValue();
+
+ if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+ return SDValue();
+
+ // Can't optimize adds that may wrap.
+ if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+ (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+ return SDValue();
+
+ EVT TargetVT = TruncNode ? TruncVT : VT;
+ if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+ return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+
return SDValue();
}
@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}
+ // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
+ if (N0.getOpcode() == ISD::AVGCEILU) {
+ SDValue SExtX = N0.getOperand(0);
+ SDValue SExtY = N0.getOperand(1);
+ if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
+ SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
+ (SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+ SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
+ SDValue X = SExtX.getOperand(0);
+ SDValue Y = SExtY.getOperand(0);
+ if (X.getValueType() == VT &&
+ TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT)) {
+ return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+ }
+ }
+ }
+
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
return NewVSel;
``````````
</details>
https://github.com/llvm/llvm-project/pull/156239
More information about the llvm-commits
mailing list