[llvm] [DAG] Add legalization handling for AVGCEIL/AVGFLOOR nodes (PR #92096)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 22 08:56:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>

Always match AVG patterns pre-legalization, and use TargetLowering::expandAVG to expand again during legalization.

I've removed the X86 custom AVGCEILU pattern detection and replaced it with combines that try to convert other AVG nodes to AVGCEILU.

---

Patch is 690.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92096.diff


19 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+5) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+19-4) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+7) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+14) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+10-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+8-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+51-7) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+33-165) 
- (modified) llvm/test/CodeGen/AArch64/arm64-vhadd.ll (+16-16) 
- (modified) llvm/test/CodeGen/AArch64/sve-hadd.ll (+62-46) 
- (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+29-57) 
- (modified) llvm/test/CodeGen/X86/avg.ll (+334-767) 
- (modified) llvm/test/CodeGen/X86/avgceils.ll (+940-3020) 
- (modified) llvm/test/CodeGen/X86/avgceilu.ll (+355-1434) 
- (modified) llvm/test/CodeGen/X86/avgfloors.ll (+968-2554) 
- (modified) llvm/test/CodeGen/X86/avgflooru.ll (+647-1892) 
- (modified) llvm/test/CodeGen/X86/min-legal-vector-width.ll (+4-4) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 50a8c7eb75af5..92c6ecb996c07 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5332,6 +5332,11 @@ class TargetLowering : public TargetLoweringBase {
   /// \returns The expansion result or SDValue() if it fails.
   SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
 
+  /// Expand vector/scalar AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU nodes.
+  /// \param N Node to expand
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandAVG(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
   /// scalar types. Returns SDValue() if expand fails.
   /// \param N Node to expand
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8607b50175359..a224f8916c690 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2578,13 +2578,13 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
   EVT VT = N0.getValueType();
   SDValue A, B;
 
-  if (hasOperation(ISD::AVGCEILU, VT) &&
+  if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
                               m_SpecificInt(1))))) {
     return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
   }
-  if (hasOperation(ISD::AVGCEILS, VT) &&
+  if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
                               m_SpecificInt(1))))) {
@@ -2950,13 +2950,13 @@ SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
   EVT VT = N0.getValueType();
   SDValue A, B;
 
-  if (hasOperation(ISD::AVGFLOORU, VT) &&
+  if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
                               m_SpecificInt(1))))) {
     return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
   }
-  if (hasOperation(ISD::AVGFLOORS, VT) &&
+  if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
                               m_SpecificInt(1))))) {
@@ -5234,6 +5234,21 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     return DAG.getNode(ISD::SRL, DL, VT, X,
                        DAG.getShiftAmountConstant(1, VT, DL));
 
+  // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
+  // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
+  // Check if avgflooru isn't legal/custom but avgceilu is.
+  if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
+      (!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
+    if (DAG.isKnownNeverZero(N0))
+      return DAG.getNode(
+          ISD::AVGCEILU, DL, VT, N1,
+          DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
+    if (DAG.isKnownNeverZero(N1))
+      return DAG.getNode(
+          ISD::AVGCEILU, DL, VT, N0,
+          DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
+  }
+
   return SDValue();
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index bfc2273c9425c..d72ac548e4fb3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3054,6 +3054,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     if ((Tmp1 = TLI.expandABD(Node, DAG)))
       Results.push_back(Tmp1);
     break;
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+    if ((Tmp1 = TLI.expandAVG(Node, DAG)))
+      Results.push_back(Tmp1);
+    break;
   case ISD::CTPOP:
     if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
       Results.push_back(Tmp1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index c64e27fe45634..e64a59dcdca65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -188,6 +188,8 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_SUB:
   case ISD::VP_MUL:      Res = PromoteIntRes_SimpleIntBinOp(N); break;
 
+  case ISD::AVGCEILS:
+  case ISD::AVGFLOORS:
   case ISD::VP_SMIN:
   case ISD::VP_SMAX:
   case ISD::SDIV:
@@ -195,6 +197,8 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_SDIV:
   case ISD::VP_SREM:     Res = PromoteIntRes_SExtIntBinOp(N); break;
 
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORU:
   case ISD::VP_UMIN:
   case ISD::VP_UMAX:
   case ISD::UDIV:
@@ -2775,6 +2779,11 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SSHLSAT:
   case ISD::USHLSAT: ExpandIntRes_SHLSAT(N, Lo, Hi); break;
 
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU: 
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU: ExpandIntRes_AVG(N, Lo, Hi); break;
+
   case ISD::SMULFIX:
   case ISD::SMULFIXSAT:
   case ISD::UMULFIX:
@@ -4077,6 +4086,11 @@ void DAGTypeLegalizer::ExpandIntRes_READCOUNTER(SDNode *N, SDValue &Lo,
   ReplaceValueWith(SDValue(N, 1), R.getValue(2));
 }
 
+void DAGTypeLegalizer::ExpandIntRes_AVG(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDValue Result = TLI.expandAVG(N, DAG);
+  SplitInteger(Result, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_ADDSUBSAT(SDNode *N, SDValue &Lo,
                                               SDValue &Hi) {
   SDValue Result = TLI.expandAddSubSat(N, DAG);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index d925089d5689f..f35dc655a5b36 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -460,6 +460,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_XMULO             (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_AVG               (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_ADDSUBSAT         (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_SHLSAT            (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_MULFIX            (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6acbc044d6731..ab38eac031d7c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -369,6 +369,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::ABS:
   case ISD::ABDS:
   case ISD::ABDU:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
   case ISD::BSWAP:
   case ISD::BITREVERSE:
   case ISD::CTLZ:
@@ -916,6 +920,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
       return;
     }
     break;
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+    if (SDValue Expanded = TLI.expandAVG(Node, DAG)) {
+      Results.push_back(Expanded);
+      return;
+    }
+    break;
   case ISD::BITREVERSE:
     ExpandBITREVERSE(Node, Results);
     return;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ec05135915664..8aeee0248d203 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -125,6 +125,10 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::ADD:
   case ISD::AND:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
   case ISD::FADD:
   case ISD::FCOPYSIGN:
   case ISD::FDIV:
@@ -1171,7 +1175,12 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::MUL: case ISD::VP_MUL:
   case ISD::MULHS:
   case ISD::MULHU:
-  case ISD::FADD: case ISD::VP_FADD:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+  case ISD::FADD:
+  case ISD::VP_FADD:
   case ISD::FSUB: case ISD::VP_FSUB:
   case ISD::FMUL: case ISD::VP_FMUL:
   case ISD::FMINNUM: case ISD::VP_FMINNUM:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b05649c6ce955..f88981a423396 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4588,8 +4588,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
     // SRA X, C -> adds C sign bits.
     if (const APInt *ShAmt =
-            getValidMinimumShiftAmountConstant(Op, DemandedElts))
+            getValidMinimumShiftAmountConstant(Op, DemandedElts)) {
       Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
+    } else {
+      KnownBits KnownAmt =
+          computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+      if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(VTBits))
+        Tmp = std::min<uint64_t>(Tmp + KnownAmt.getConstant().getZExtValue(),
+                                 VTBits);
+    }
     return Tmp;
   case ISD::SHL:
     if (const APInt *ShAmt =
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 87c4c62522c1b..265c4e302e5b0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -951,11 +951,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
 
 // Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
 //      or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
-static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
+static SDValue combineShiftToAVG(SDValue Op,
+                                 TargetLowering::TargetLoweringOpt &TLO,
                                  const TargetLowering &TLI,
                                  const APInt &DemandedBits,
-                                 const APInt &DemandedElts,
-                                 unsigned Depth) {
+                                 const APInt &DemandedElts, unsigned Depth) {
   assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
          "SRL or SRA node is required here!");
   // Is the right shift using an immediate value of 1?
@@ -1006,6 +1006,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
   // If the shift is unsigned (srl):
   //  - Needs >= 1 zero bit for both operands.
   //  - Needs 1 demanded bit zero and >= 2 sign bits.
+  SelectionDAG &DAG = TLO.DAG;
   unsigned ShiftOpc = Op.getOpcode();
   bool IsSigned = false;
   unsigned KnownBits;
@@ -1061,10 +1062,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
   EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
   if (VT.isVector())
     NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
-  if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT)) {
+  if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, NVT)) {
     // If we could not transform, and (both) adds are nuw/nsw, we can use the
     // larger type size to do the transform.
-    if (!TLI.isOperationLegalOrCustom(AVGOpc, VT))
+    if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
       return SDValue();
     if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
                                Add.getOperand(1)) &&
@@ -2017,7 +2018,7 @@ bool TargetLowering::SimplifyDemandedBits(
     }
 
     // Try to match AVG patterns (after shift simplification).
-    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
+    if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);
 
@@ -2130,7 +2131,7 @@ bool TargetLowering::SimplifyDemandedBits(
     }
 
     // Try to match AVG patterns (after shift simplification).
-    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
+    if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);
 
@@ -9231,6 +9232,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
                        DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
 }
 
+SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
+  SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  unsigned Opc = N->getOpcode();
+  bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
+  bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
+  unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
+  assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
+          Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
+         "Unknown AVG node");
+
+  // If the operands are already extended, we can add+shift.
+  bool IsExt =
+      (IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
+       DAG.ComputeNumSignBits(RHS) >= 2) ||
+      (!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
+       DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
+  if (IsExt) {
+    SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
+    if (!IsFloor)
+      Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
+    return DAG.getNode(ShiftOpc, dl, VT, Sum,
+                       DAG.getShiftAmountConstant(1, VT, dl));
+  }
+
+  // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
+  // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
+  // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
+  // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
+  unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
+  unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
+  LHS = DAG.getFreeze(LHS);
+  RHS = DAG.getFreeze(RHS);
+  SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
+  SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
+  SDValue Shift =
+      DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
+  return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
+}
+
 SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 37c591f90f0a3..44b20e90f5760 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2501,6 +2501,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
                        ISD::SRL,
                        ISD::OR,
                        ISD::AND,
+                       ISD::AVGCEILS,
+                       ISD::AVGCEILU,
+                       ISD::AVGFLOORS,
+                       ISD::AVGFLOORU,
                        ISD::BITREVERSE,
                        ISD::ADD,
                        ISD::FADD,
@@ -50497,157 +50501,6 @@ static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
   return SDValue();
 }
 
-/// This function detects the AVG pattern between vectors of unsigned i8/i16,
-/// which is c = (a + b + 1) / 2, and replace this operation with the efficient
-/// ISD::AVGCEILU (AVG) instruction.
-static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
-                                const X86Subtarget &Subtarget,
-                                const SDLoc &DL) {
-  if (!VT.isVector())
-    return SDValue();
-  EVT InVT = In.getValueType();
-  unsigned NumElems = VT.getVectorNumElements();
-
-  EVT ScalarVT = VT.getVectorElementType();
-  if (!((ScalarVT == MVT::i8 || ScalarVT == MVT::i16) && NumElems >= 2))
-    return SDValue();
-
-  // InScalarVT is the intermediate type in AVG pattern and it should be greater
-  // than the original input type (i8/i16).
-  EVT InScalarVT = InVT.getVectorElementType();
-  if (InScalarVT.getFixedSizeInBits() <= ScalarVT.getFixedSizeInBits())
-    return SDValue();
-
-  if (!Subtarget.hasSSE2())
-    return SDValue();
-
-  // Detect the following pattern:
-  //
-  //   %1 = zext <N x i8> %a to <N x i32>
-  //   %2 = zext <N x i8> %b to <N x i32>
-  //   %3 = add nuw nsw <N x i32> %1, <i32 1 x N>
-  //   %4 = add nuw nsw <N x i32> %3, %2
-  //   %5 = lshr <N x i32> %N, <i32 1 x N>
-  //   %6 = trunc <N x i32> %5 to <N x i8>
-  //
-  // In AVX512, the last instruction can also be a trunc store.
-  if (In.getOpcode() != ISD::SRL)
-    return SDValue();
-
-  // A lambda checking the given SDValue is a constant vector and each element
-  // is in the range [Min, Max].
-  auto IsConstVectorInRange = [](SDValue V, unsigned Min, unsigned Max) {
-    return ISD::matchUnaryPredicate(V, [Min, Max](ConstantSDNode *C) {
-      return !(C->getAPIntValue().ult(Min) || C->getAPIntValue().ugt(Max));
-    });
-  };
-
-  auto IsZExtLike = [DAG = &DAG, ScalarVT](SDValue V) {
-    unsigned MaxActiveBits = DAG->computeKnownBits(V).countMaxActiveBits();
-    return MaxActiveBits <= ScalarVT.getSizeInBits();
-  };
-
-  // Check if each element of the vector is right-shifted by one.
-  SDValue LHS = In.getOperand(0);
-  SDValue RHS = In.getOperand(1);
-  if (!IsConstVectorInRange(RHS, 1, 1))
-    return SDValue();
-  if (LHS.getOpcode() != ISD::ADD)
-    return SDValue();
-
-  // Detect a pattern of a + b + 1 where the order doesn't matter.
-  SDValue Operands[3];
-  Operands[0] = LHS.getOperand(0);
-  Operands[1] = LHS.getOperand(1);
-
-  auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
-                       ArrayRef<SDValue> Ops) {
-    return DAG.getNode(ISD::AVGCEILU, DL, Ops[0].getValueType(), Ops);
-  };
-
-  auto AVGSplitter = [&](std::array<SDValue, 2> Ops) {
-    for (SDValue &Op : Ops)
-      if (Op.getValueType() != VT)
-        Op = DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
-    // Pad to a power-of-2 vector, split+apply and extract the original vector.
-    unsigned NumElemsPow2 = PowerOf2Ceil(NumElems);
-    EVT Pow2VT = EVT::getVectorVT(*DAG.getContext(), ScalarVT, NumElemsPow2);
-    if (NumElemsPow2 != NumElems) {
-      for (SDValue &Op : Ops) {
-        SmallVector<SDValue, 32> EltsOfOp(NumElemsPow2, DAG.getUNDEF(ScalarVT));
-        for (unsigned i = 0; i != NumElems; ++i) {
-          SDValue Idx = DAG.getIntPtrConstant(i, DL);
-          EltsOfOp[i] =
-              DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Op, Idx);
-        }
-        Op = DAG.getBuildVector(Pow2VT, DL, EltsOfOp);
-      }
-    }
-    SDValue Res = SplitOpsAndApply(DAG, Subtarget, DL, Pow2VT, Ops, AVGBuilder);
-    if (NumElemsPow2 == NumElems)
-      return Res;
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
-                       DAG.getIntPtrConstant(0, DL));
-  };
-
-  // Take care of the case when one of the operands is a constant vector whose
-  // element is in the range [1, 256].
-  if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
-      IsZExtLike(Operands[0])) {
-    // The pattern is detected. Subtract one from the constant vector, then
-    // demote it and emit X86ISD::AVG instruction.
-    SDValue VecOnes = DAG.getConstant(1, DL, InVT);
-    Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
-    return AVGSplitter({Operands[0], Operands[1]});
-  }
-
-  // Matches 'add like' patterns: add(Op0,Op1) + zext(or(Op0,Op1)).
-  // Match the or case only if its 'add-like' - can be replaced by an add.
-  auto FindAddLike = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
-    if (ISD...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/92096


More information about the llvm-commits mailing list