[llvm] 9f5c8de - [DAG] visitAVG - rewrite "fold (avgfloor x, 0) -> x >> 1" to use SDPatternMatch

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun May 19 03:36:06 PDT 2024


Author: Simon Pilgrim
Date: 2024-05-19T11:30:20+01:00
New Revision: 9f5c8de3864b0be27a8b36cd891c5a28a3acfd27

URL: https://github.com/llvm/llvm-project/commit/9f5c8de3864b0be27a8b36cd891c5a28a3acfd27
DIFF: https://github.com/llvm/llvm-project/commit/9f5c8de3864b0be27a8b36cd891c5a28a3acfd27.diff

LOG: [DAG] visitAVG - rewrite "fold (avgfloor x, 0) -> x >> 1" to use SDPatternMatch

No need for this to be vector-specific, and it's more likely that scalar cases will appear after #92096

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2b1dec8205b73..bf85212e6a92e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5211,30 +5211,28 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
 
-  if (VT.isVector()) {
+  if (VT.isVector())
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
       return FoldedVOp;
 
-    // fold (avgfloor x, 0) -> x >> 1
-    if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
-      if (Opcode == ISD::AVGFLOORS)
-        return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
-      if (Opcode == ISD::AVGFLOORU)
-        return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
-    }
-  }
-
   // fold (avg x, undef) -> x
   if (N0.isUndef())
     return N1;
   if (N1.isUndef())
     return N0;
 
-  // Fold (avg x, x) --> x
+  // fold (avg x, x) --> x
   if (N0 == N1 && Level >= AfterLegalizeTypes)
     return N0;
 
-  // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
+  // fold (avgfloor x, 0) -> x >> 1
+  SDValue X;
+  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORS, m_Value(X), m_Zero())))
+    return DAG.getNode(ISD::SRA, DL, VT, X,
+                       DAG.getShiftAmountConstant(1, VT, DL));
+  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORU, m_Value(X), m_Zero())))
+    return DAG.getNode(ISD::SRL, DL, VT, X,
+                       DAG.getShiftAmountConstant(1, VT, DL));
 
   return SDValue();
 }


        


More information about the llvm-commits mailing list