[llvm] [SelectionDAG] Replace some basic patterns in visitADDLike with SDPatternMatch (PR #84759)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 06:12:40 PDT 2024


https://github.com/XChy created https://github.com/llvm/llvm-project/pull/84759

Resolves #84745.

Based on the SDPatternMatch library introduced by #78654, this patch replaces some of the basic patterns in `visitADDLike` with their SDPatternMatch equivalents.

This patch only rewrites existing folds rather than introducing new ones, so it adds no new tests. Please let me know if new tests are needed.

>From 4bc54ca4d6ee64b5eb8da2162ef3dcfe584b63f2 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Mon, 11 Mar 2024 21:02:18 +0800
Subject: [PATCH] [SelectionDAG] Replace some basic patterns in visitADDLike
 with SDPatternMatch

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 56 +++++++++----------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cdcb7114640471..858cd892f41766 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -38,6 +38,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/RuntimeLibcalls.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -79,6 +80,7 @@
 #include "MatchContext.h"
 
 using namespace llvm;
+using namespace llvm::SDPatternMatch;
 
 #define DEBUG_TYPE "dagcombine"
 
@@ -2697,52 +2699,46 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
             reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
       return SD;
   }
+
+  SDValue A, B, C;
+
   // fold ((0-A) + B) -> B-A
-  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
-    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Zero(), m_Value(A))))
+    return DAG.getNode(ISD::SUB, DL, VT, N1, A);
 
   // fold (A + (0-B)) -> A-B
-  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
-    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
+  if (sd_match(N1, m_Sub(m_Zero(), m_Value(B))))
+    return DAG.getNode(ISD::SUB, DL, VT, N0, B);
 
   // fold (A+(B-A)) -> B
-  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
-    return N1.getOperand(0);
+  if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
+    return B;
 
   // fold ((B-A)+A) -> B
-  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
-    return N0.getOperand(0);
+  if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
+    return B;
 
   // fold ((A-B)+(C-A)) -> (C-B)
-  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
-      N0.getOperand(0) == N1.getOperand(1))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N0.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+      sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
+    return DAG.getNode(ISD::SUB, DL, VT, C, B);
 
   // fold ((A-B)+(B-C)) -> (A-C)
-  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
-      N0.getOperand(1) == N1.getOperand(0))
-    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
-                       N1.getOperand(1));
+  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+      sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
+    return DAG.getNode(ISD::SUB, DL, VT, A, C);
 
   // fold (A+(B-(A+C))) to (B-C)
-  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
-      N0 == N1.getOperand(1).getOperand(0))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N1.getOperand(1).getOperand(1));
-
   // fold (A+(B-(C+A))) to (B-C)
-  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
-      N0 == N1.getOperand(1).getOperand(1))
-    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
-                       N1.getOperand(1).getOperand(0));
+  if (sd_match(N1, m_Sub(m_Value(B),
+                         m_c_BinOp(ISD::ADD, m_Specific(N0), m_Value(C)))))
+    return DAG.getNode(ISD::SUB, DL, VT, B, C);
 
   // fold (A+((B-A)+or-C)) to (B+or-C)
-  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
-      N1.getOperand(0).getOpcode() == ISD::SUB &&
-      N0 == N1.getOperand(0).getOperand(1))
-    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
-                       N1.getOperand(1));
+  if (sd_match(N1,
+               m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
+                       m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
+    return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
 
   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
   if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&



More information about the llvm-commits mailing list