[llvm] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for sub+xor form (PR #169199)

via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 23 02:03:43 PST 2025


https://github.com/laurenmchin updated https://github.com/llvm/llvm-project/pull/169199

>From 5f7647ef69938e9e68c33b4077cb8d9286818fd8 Mon Sep 17 00:00:00 2001
From: Lauren Chin <lchin at berkeley.edu>
Date: Sun, 23 Nov 2025 00:51:51 -0500
Subject: [PATCH] [DAG] foldShiftToAvg - Fixes avgceil[su] pattern matching for
 sub+xor form

Fixes a regression where avgceil[su] patterns fail to match when AArch64
canonicalizes `(add (add x, y), 1)` to `(sub x, (xor y, -1))`, causing
SVE/SVE2 test failures.

Addresses the remaining regression in [#147946](https://github.com/llvm/llvm-project/issues/147946).
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 333 ++++++++++++++++++
 1 file changed, 333 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b79dbb46cadc..d9bbed3a3f61e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11968,6 +11968,339 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
   return SDValue();
 }
 
+// // Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
+// SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
+// 	const unsigned Opcode = N->getOpcode();
+//   if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+//     return SDValue();
+
+// 	LLVM_DEBUG(dbgs() << "\n--- foldShiftToAvg() ---\n");
+
+//   EVT VT = N->getValueType(0);
+//   SDValue N0 = N->getOperand(0);
+
+// 	// Bail if N0 is a leaf node - can't be an avg pattern
+//   if (N0.getNumOperands() == 0) {
+//     LLVM_DEBUG(dbgs() << "\n--- foldShiftToAvg(): Return SDValue();\n\t[N0 is
+//     a leaf node]\n\n"); return SDValue();
+//   }
+
+// 	if (!isOnesOrOnesSplat(N->getOperand(1)))
+// 		return SDValue();
+
+// 	if (sd_match(N->getOperand(1), m_AllOnes())) {
+// 		LLVM_DEBUG(dbgs() << "\n+ Matched m_AllOnes()\n\n");
+// 	}
+
+//   if (!sd_match(N->getOperand(1), m_One())) {
+// 		LLVM_DEBUG(dbgs() << "\n--- foldShiftToAvg(): Return
+// SDValue();\n\t[not a shift by 1]\n\n");
+//     return SDValue();
+// 	}
+
+// 	LLVM_DEBUG(dbgs() << "\npassed all splat checks\n");
+//   // [TruncVT]
+//   // result type of the single truncate user fed by this shift node, if
+//   // present. We always use TruncVT to verify whether the target supports
+//   // folding to avgceil[su]. For avgfloor[su], we use TruncVT if present,
+//   // else VT.
+//   //
+//   // [NarrowVT]
+//   // semantic source width of the value(s) being averaged when the ops are
+//   // SExt/SExtInReg.
+//   EVT TruncVT = VT;
+//   SDNode *TruncNode = nullptr;
+
+//   // If this shift has a single truncate user, use it to decide whether
+//   // folding to avg* is legal at the truncated width.
+//   //
+//   // Note that the target may only support the avgceil[su]/avgfloor[su] op
+//   // at the narrower type or at the full-width VT, but we check for legality
+//   // using the truncate node's VT if present, else this shift's VT.
+//   if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+//     TruncNode = *N->user_begin();
+//     TruncVT = TruncNode->getValueType(0);
+//   }
+
+//   EVT NarrowVT = VT;
+//   SDValue N00 = N0.getOperand(0);
+
+// 	// For SRL of SIGN_EXTEND_INREG values, check if the narrow type is
+// legal.
+// 	// If not, bail out to prevent incorrect folding at the wider type.
+// 	// This ensures operations like srhadd are generated at the correct
+// width. 	if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) { 		NarrowVT =
+// cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT(); 		if (Opcode ==
+// ISD::SRL && !TLI.isTypeLegal(NarrowVT)) 			return SDValue();
+// 	}
+
+//   unsigned FloorISD = 0;
+//   unsigned CeilISD = 0;
+//   bool IsUnsigned = false;
+
+//   // Decide whether signed or unsigned.
+//   switch (Opcode) {
+//   case ISD::SRA:
+//     FloorISD = ISD::AVGFLOORS;
+//     break;
+//   case ISD::SRL:
+//     IsUnsigned = true;
+//     // SRL of a widened signed sub feeding a truncate acts like shadd.
+//     if (TruncNode &&
+//         (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
+//         (N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+//          N00.getOpcode() == ISD::SIGN_EXTEND))
+//       IsUnsigned = false;
+//     FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+//     break;
+//   default:
+// 		LLVM_DEBUG(dbgs() << "\n--- foldShiftToAvg(): Return
+// SDValue();\n\t[switch default]\n\n");
+//     return SDValue();
+//   }
+
+//   CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+//   // Bail out if this shift is not truncated and the target doesn't support
+//   // the avg* op at this shift's VT (or TruncVT for avgceil[su]).
+//   if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+//       (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT))) {
+// 		LLVM_DEBUG(dbgs() << "\n--- foldShiftToAvg(): Return
+// SDValue();\n\t[shift isn't truncated && target doesnt support op]\n\n");
+//     return SDValue();
+// 	}
+
+//   SDValue X, Y, Sub, Xor;
+
+//   // (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
+//   if (sd_match(N, m_BinOp(Opcode,
+//                           m_AllOf(m_Value(Sub),
+//                                   m_Sub(m_Value(X),
+//                                         m_AllOf(m_Value(Xor),
+//                                                 m_Xor(m_Value(Y),
+//                                                 m_Value())))),
+//                           m_One()))) {
+// 		APInt SplatVal;
+
+// 		if (isAllOnesOrAllOnesSplat(Xor.getOperand(1)) ||
+// 			(ISD::isConstantSplatVector(Xor.getOperand(1).getNode(),
+// SplatVal) && 			SplatVal.trunc(VT.getScalarSizeInBits()).isAllOnes())) {
+// 				// - Can't fold if either op is
+// sign/zero-extended for SRL, as SRL
+// 				//   is unsigned, and shadd patterns are handled
+// elsewhere.
+// 				//
+// 				// - Large fixed vectors (>128 bits) on AArch64
+// will be type-legalized
+// 				//   into a series of EXTRACT_SUBVECTORs.
+// Folding each subvector does not
+// 				//   necessarily preserve semantics so they
+// cannot be folded here. 				if (TruncNode && VT.isFixedLengthVector()) { 					if
+// (X.getOpcode() == ISD::SIGN_EXTEND || 							X.getOpcode() == ISD::ZERO_EXTEND ||
+// 							Y.getOpcode() ==
+// ISD::SIGN_EXTEND || 							Y.getOpcode() == ISD::ZERO_EXTEND || VT.getSizeInBits() >
+// 128) 							return SDValue();
+// 				}
+
+// 				// If there is no truncate user, ensure the
+// relevant no wrap flag is on
+// 				// the sub so that narrowing the widened result
+// is defined. 				if (Opcode == ISD::SRA && VT == NarrowVT) { 					if (!IsUnsigned &&
+// !Sub->getFlags().hasNoSignedWrap()) 						return SDValue(); 				} else if (IsUnsigned
+// && !Sub->getFlags().hasNoUnsignedWrap()) 					return SDValue();
+
+// 				// Only fold if the target supports avgceil[su]
+// at the truncated type:
+// 				// - if there is a single truncate user, we
+// require support at TruncVT.
+// 				//   We build the avg* at VT (to replace this
+// shift node).
+// 				//   visitTRUNCATE handles the actual folding to
+// avgceils (x, y).
+// 				// - otherwise, we require support at VT
+// (TruncVT == VT).
+// 				//
+// 				// AArch64 canonicalizes (x + y + 1) >> 1 -> sub
+// (x, xor (y, -1)). In
+// 				// order for our fold to be legal, we require
+// support for the VT at the
+// 				// final observable type (TruncVT or VT).
+// 				if (TLI.isOperationLegalOrCustom(CeilISD,
+// TruncVT)) 					return DAG.getNode(CeilISD, DL, VT, Y, X);
+// 			}
+// 	}
+
+//   // Captured values.
+//   SDValue A, B, Add;
+
+//   // Match floor average as it is common to both floor/ceil avgs.
+//   // (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+//   if (!sd_match(N, m_BinOp(Opcode,
+//                            m_AllOf(m_Value(Add), m_Add(m_Value(A),
+//                            m_Value(B))), m_One())))
+//     return SDValue();
+
+//   if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+//     return SDValue();
+
+//   // Can't optimize adds that may wrap.
+//   if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+//       (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+//     return SDValue();
+
+//   EVT TargetVT = TruncNode ? TruncVT : VT;
+//   if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+//     return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+//   return SDValue();
+// }
+
+// // Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
+// SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
+//   const unsigned Opcode = N->getOpcode();
+//   if (Opcode != ISD::SRA && Opcode != ISD::SRL)
+//     return SDValue();
+
+//   EVT VT = N->getValueType(0);
+//   SDValue N0 = N->getOperand(0);
+// 	LLVM_DEBUG(dbgs() << "\n[foldShiftToAvg]\n" << "\n");
+
+//   if (!isOnesOrOnesSplat(N->getOperand(1)))
+//     return SDValue();
+
+//   EVT TruncVT = VT;
+//   SDNode *TruncNode = nullptr;
+
+//   // We need the correct type to check for avgceil/floor support.
+//   if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+//     TruncNode = *N->user_begin();
+//     TruncVT = TruncNode->getValueType(0);
+//   }
+
+//   // NarrowVT tracks whether we're working with sign-extended values.
+//   EVT NarrowVT = VT;
+
+//   // Extract narrow type from SIGN_EXTEND_INREG. For SRL, require the narrow
+//   // type to be legal to ensure correct width avg operations.
+// 	// if (N0.getNumOperands() > 0 && N0.getOperand(0) &&
+// N0.getOperand(0).getOpcode() == ISD::SIGN_EXTEND_INREG) {
+//   //   SDValue N00 = N0.getOperand(0);
+//   //   NarrowVT = cast<VTSDNode>(N00.getOperand(1))->getVT();
+//   //   if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+//   //     return SDValue();
+//   // }
+// 	// Check if N0 itself is SIGN_EXTEND_INREG
+// 	if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+// 		NarrowVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
+// 		if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+// 			return SDValue();
+// 	}
+
+//   unsigned FloorISD = 0;
+//   unsigned CeilISD = 0;
+//   bool IsUnsigned = false;
+
+//   // Decide whether signed or unsigned.
+//   switch (Opcode) {
+//   case ISD::SRA:
+//     FloorISD = ISD::AVGFLOORS;
+//     break;
+//   case ISD::SRL:
+//     IsUnsigned = true;
+//     if (TruncNode &&
+//         (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB)) {
+//       // Use signed avg for SRL of sign-extended values when truncating.
+// 			SDValue N00 = N0.getOperand(0);
+// 			SDValue N01 = N0.getOperand(1);
+// 			if ((N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+// 					N00.getOpcode() == ISD::SIGN_EXTEND) ||
+// 					(N01.getOpcode() ==
+// ISD::SIGN_EXTEND_INREG || 					N01.getOpcode() == ISD::SIGN_EXTEND)) 				IsUnsigned =
+// false;
+
+//     }
+//     FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+//     break;
+//   default:
+//     return SDValue();
+//   }
+
+//   CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+//   // Without truncation, require target support for both averaging
+//   // operations. FloorISD is checked at VT (the generated type) and CeilISD
+//   // at TruncVT (the final type).
+//   if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+//       (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+//     return SDValue();
+
+//   SDValue X, Y, Sub, Xor;
+
+//   // fold (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
+//   if (sd_match(N, m_BinOp(Opcode,
+//                           m_AllOf(m_Value(Sub),
+//                                   m_Sub(m_Value(X),
+//                                         m_AllOf(m_Value(Xor),
+//                                                 m_Xor(m_Value(Y),
+//                                                 m_Value())))),
+//                           m_One()))) {
+// 		LLVM_DEBUG(dbgs() << "pattern 1\n" << "\n");
+
+//     ConstantSDNode *C = isConstOrConstSplat(Xor.getOperand(1),
+//                                             /*AllowUndefs=*/false,
+//                                             /*AllowTruncation=*/true);
+//     if (C && C->getAPIntValue().trunc(VT.getScalarSizeInBits()).isAllOnes())
+//     {
+//       // Don't fold extended inputs with truncation on fixed vectors > 128b
+//       if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+//       {
+//         if (X.getOpcode() == ISD::SIGN_EXTEND ||
+//             X.getOpcode() == ISD::ZERO_EXTEND ||
+//             Y.getOpcode() == ISD::SIGN_EXTEND ||
+//             Y.getOpcode() == ISD::ZERO_EXTEND)
+//           return SDValue();
+//       }
+
+//       if (!TruncNode) {
+//         // Without truncation, require no-wrap flags for safe narrowing.
+//         const SDNodeFlags &Flags = Sub->getFlags();
+//         if ((!IsUnsigned && (Opcode == ISD::SRA && VT == NarrowVT) &&
+//              !Flags.hasNoSignedWrap()) ||
+//             (IsUnsigned && !Flags.hasNoUnsignedWrap()))
+//           return SDValue();
+//       }
+
+//       // Require avgceil[su] support at the final type:
+//       //  - with truncation: build at VT, visitTRUNCATE completes the fold
+//       //  - without truncation: build directly at VT (where TruncVT == VT).
+//       if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+//         return DAG.getNode(CeilISD, DL, VT, Y, X);
+//     }
+//   }
+
+//   // Captured values.
+//   SDValue A, B, Add;
+
+//   // Match floor average as it is common to both floor/ceil avgs.
+//   // fold (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+//   if (!sd_match(N, m_BinOp(Opcode,
+//                            m_AllOf(m_Value(Add), m_Add(m_Value(A),
+//                            m_Value(B))), m_One())))
+//     return SDValue();
+// 	LLVM_DEBUG(dbgs() << "pattern 2\n" << "\n");
+//   if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+//     return SDValue();
+
+//   // Can't optimize adds that may wrap.
+//   if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+//       (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+//     return SDValue();
+
+//   EVT TargetVT = TruncNode ? TruncVT : VT;
+//   if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+//     return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+//   return SDValue();
+// }
+
 SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
   unsigned Opc = N->getOpcode();
   SDValue X, Y, Z;



More information about the llvm-commits mailing list