[llvm] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for sub+xor form (PR #169199)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Nov 22 21:56:07 PST 2025
https://github.com/laurenmchin created https://github.com/llvm/llvm-project/pull/169199
Fixes a regression where avgceil[su] patterns fail to match when AArch64 canonicalizes `(add (add x, y), 1)` to `(sub x, (xor y, -1))`, causing SVE/SVE2 test failures related to DAG topological sorting.
Addresses the remaining regression in [#147946](https://github.com/llvm/llvm-project/issues/147946).
>From 1d55e58ee33019acc319b7e990b0f8cacdcddc93 Mon Sep 17 00:00:00 2001
From: Lauren Chin <lchin at berkeley.edu>
Date: Sun, 23 Nov 2025 00:51:51 -0500
Subject: [PATCH] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for
sub+xor form
Fixes a regression where avgceil[su] patterns fail to match when AArch64
canonicalizes `(add (add x, y), 1)` to `(sub x, (xor y, -1))`, causing
SVE/SVE2 test failures.
Addresses the remaining regression in [#147946](https://github.com/llvm/llvm-project/issues/147946).
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 128 ++++++++++++++++--
1 file changed, 114 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b79dbb46cadc..e874b4d1e59de 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11943,28 +11943,128 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
return SDValue();
EVT VT = N->getValueType(0);
- bool IsUnsigned = Opcode == ISD::SRL;
+ SDValue N0 = N->getOperand(0);
- // Captured values.
- SDValue A, B, Add;
+ if (!isOnesOrOnesSplat(N->getOperand(1)))
+ return SDValue();
- // Match floor average as it is common to both floor/ceil avgs.
+ EVT TruncVT = VT;
+ SDNode *TruncNode = nullptr;
+
+ // We need the correct type to check for avgceil/floor support.
+ if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+ TruncNode = *N->user_begin();
+ TruncVT = TruncNode->getValueType(0);
+ }
+
+ // NarrowVT is used to detect whether we're working with sign-extended values.
+ EVT NarrowVT = VT;
+ SDValue N00 = N0.getOperand(0);
+
+ // Extract narrow type from SIGN_EXTEND_INREG. For SRL, require the narrow
+ // type to be legal to ensure correct width avg operations.
+ if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+ NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+ if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+ return SDValue();
+ }
+
+ unsigned FloorISD = 0;
+ unsigned CeilISD = 0;
+ bool IsUnsigned = false;
+
+ // Decide whether signed or unsigned.
+ switch (Opcode) {
+ case ISD::SRA:
+ FloorISD = ISD::AVGFLOORS;
+ break;
+ case ISD::SRL:
+ IsUnsigned = true;
+ if (TruncNode &&
+ (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB)) {
+ // Use signed avg for SRL of sign-extended values when truncating.
+ SDValue N01 = N0.getOperand(1);
+ if ((N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+ N00.getOpcode() == ISD::SIGN_EXTEND) ||
+ (N01.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+ N01.getOpcode() == ISD::SIGN_EXTEND))
+ IsUnsigned = false;
+ }
+ FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+ break;
+ default:
+ return SDValue();
+ }
+
+ CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+ // Without truncation, require target support for both averaging operations.
+ // We check FloorISD at VT (generated type), CeilISD at TruncVT (final type).
+ if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+ (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+ return SDValue();
+
+ SDValue X, Y, Sub, Xor;
+
+ // fold (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
if (sd_match(N, m_BinOp(Opcode,
- m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+ m_AllOf(m_Value(Sub),
+ m_Sub(m_Value(X),
+ m_AllOf(m_Value(Xor),
+ m_Xor(m_Value(Y), m_Value())))),
m_One()))) {
- // Decide whether signed or unsigned.
- unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
- if (!hasOperation(FloorISD, VT))
- return SDValue();
- // Can't optimize adds that may wrap.
- if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
- (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
- return SDValue();
+ ConstantSDNode *C = isConstOrConstSplat(Xor.getOperand(1),
+ /*AllowUndefs=*/false,
+ /*AllowTruncation=*/true);
+ if (C && C->getAPIntValue().trunc(VT.getScalarSizeInBits()).isAllOnes()) {
+ // Don't fold extended inputs with truncation on fixed vectors > 128b
+ if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128) {
+ if (X.getOpcode() == ISD::SIGN_EXTEND ||
+ X.getOpcode() == ISD::ZERO_EXTEND ||
+ Y.getOpcode() == ISD::SIGN_EXTEND ||
+ Y.getOpcode() == ISD::ZERO_EXTEND)
+ return SDValue();
+ }
+
+ if (!TruncNode) {
+ // Without truncation, require no-wrap flags for safe narrowing.
+ const SDNodeFlags &Flags = Sub->getFlags();
+ if ((!IsUnsigned && (Opcode == ISD::SRA && VT == NarrowVT) &&
+ !Flags.hasNoSignedWrap()) ||
+ (IsUnsigned && !Flags.hasNoUnsignedWrap()))
+ return SDValue();
+ }
- return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+ // Require avgceil[su] support at the final type:
+ // - with truncation: build at VT, visitTRUNCATE completes the fold
+ // - without truncation: build directly at VT (where TruncVT == VT).
+ if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+ return DAG.getNode(CeilISD, DL, VT, Y, X);
+ }
}
+ // Captured values.
+ SDValue A, B, Add;
+
+ // Match floor average as it is common to both floor/ceil avgs.
+ // fold (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+ if (!sd_match(N, m_BinOp(Opcode,
+ m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+ m_One())))
+ return SDValue();
+
+ if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+ return SDValue();
+
+ // Can't optimize adds that may wrap.
+ if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+ (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+ return SDValue();
+
+ EVT TargetVT = TruncNode ? TruncVT : VT;
+ if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+ return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
return SDValue();
}
More information about the llvm-commits
mailing list