[llvm] [SelectionDAG] Replace some basic patterns in visitADDLike with SDPatternMatch (PR #84759)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 06:15:22 PDT 2024


================
@@ -2697,52 +2699,46 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
             reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
       return SD;
   }
+
+  SDValue A, B, C;
+
   // fold ((0-A) + B) -> B-A
-  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
-    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Zero(), m_Value(A))))
+    return DAG.getNode(ISD::SUB, DL, VT, N1, A);
 
   // fold (A + (0-B)) -> A-B
-  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
-    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
+  if (sd_match(N1, m_Sub(m_Zero(), m_Value(B))))
+    return DAG.getNode(ISD::SUB, DL, VT, N0, B);
 
   // fold (A+(B-A)) -> B
-  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
-    return N1.getOperand(0);
+  if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
+    return B;
 
   // fold ((B-A)+A) -> B
-  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
-    return N0.getOperand(0);
+  if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
+    return B;
 
   // fold ((A-B)+(C-A)) -> (C-B)
-  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
-      N0.getOperand(0) == N1.getOperand(1))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N0.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+      sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
+    return DAG.getNode(ISD::SUB, DL, VT, C, B);
 
   // fold ((A-B)+(B-C)) -> (A-C)
-  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
-      N0.getOperand(1) == N1.getOperand(0))
-    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
-                       N1.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+      sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
+    return DAG.getNode(ISD::SUB, DL, VT, A, C);
 
   // fold (A+(B-(A+C))) to (B-C)
-  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
-      N0 == N1.getOperand(1).getOperand(0))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N1.getOperand(1).getOperand(1));
-
   // fold (A+(B-(C+A))) to (B-C)
-  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
-      N0 == N1.getOperand(1).getOperand(1))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N1.getOperand(1).getOperand(0));
+  if (sd_match(N1, m_Sub(m_Value(B),
+                         m_c_BinOp(ISD::ADD, m_Specific(N0), m_Value(C)))))
+    return DAG.getNode(ISD::SUB, DL, VT, B, C);
 
   // fold (A+((B-A)+or-C)) to (B+or-C)
-  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
-      N1.getOperand(0).getOpcode() == ISD::SUB &&
-      N0 == N1.getOperand(0).getOperand(1))
-    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
-                       N1.getOperand(1));
+  if (sd_match(N1,
+               m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
+                       m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
+    return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
----------------
XChy wrote:

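I think we need an `m_BinOp(LHS, RHS)` matcher that does not constrain the opcode, so this fold can be expressed more cleanly. (N1's opcode would still need a separate ADD/SUB check, since the result reuses `N1.getOpcode()`.)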

https://github.com/llvm/llvm-project/pull/84759


More information about the llvm-commits mailing list