[llvm] a7af53e - [DAG] visitSUB - convert some folds to use SDPatternMatch

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 05:00:43 PDT 2024


Author: Simon Pilgrim
Date: 2024-03-13T12:00:24Z
New Revision: a7af53e99bb1fc92f45c14df2acf2da8f849af2f

URL: https://github.com/llvm/llvm-project/commit/a7af53e99bb1fc92f45c14df2acf2da8f849af2f
DIFF: https://github.com/llvm/llvm-project/commit/a7af53e99bb1fc92f45c14df2acf2da8f849af2f.diff

LOG: [DAG] visitSUB - convert some folds to use SDPatternMatch

General cleanup; this also allows us to handle several commutable matches with a single pattern.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SDPatternMatch.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/xor.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 92b478cec0a077..a86c7400fa0945 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -663,6 +663,7 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
 }
 
 inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
+inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
 inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); }
 
 /// Match true boolean value based on the information provided by

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 735cec8ecc0627..87033e824aa593 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3789,63 +3789,34 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
   }
 
-  // fold ((A+(B+or-C))-B) -> A+or-C
-  if (N0.getOpcode() == ISD::ADD &&
-      (N0.getOperand(1).getOpcode() == ISD::SUB ||
-       N0.getOperand(1).getOpcode() == ISD::ADD) &&
-      N0.getOperand(1).getOperand(0) == N1)
-    return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
-                       N0.getOperand(1).getOperand(1));
-
-  // fold ((A+(C+B))-B) -> A+C
-  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
-      N0.getOperand(1).getOperand(1) == N1)
-    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
-                       N0.getOperand(1).getOperand(0));
+  SDValue A, B, C;
+
+  // fold ((A+(B+C))-B) -> A+C
+  if (sd_match(N0, m_Add(m_Value(A), m_Add(m_Specific(N1), m_Value(C)))))
+    return DAG.getNode(ISD::ADD, DL, VT, A, C);
+
+  // fold ((A+(B-C))-B) -> A-C
+  if (sd_match(N0, m_Add(m_Value(A), m_Sub(m_Specific(N1), m_Value(C)))))
+    return DAG.getNode(ISD::SUB, DL, VT, A, C);
 
   // fold ((A-(B-C))-C) -> A-B
-  if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
-      N0.getOperand(1).getOperand(1) == N1)
-    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
-                       N0.getOperand(1).getOperand(0));
+  if (sd_match(N0, m_Sub(m_Value(A), m_Sub(m_Value(B), m_Specific(N1)))))
+    return DAG.getNode(ISD::SUB, DL, VT, A, B);
 
   // fold (A-(B-C)) -> A+(C-B)
-  if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
+  if (sd_match(N1, m_OneUse(m_Sub(m_Value(B), m_Value(C)))))
     return DAG.getNode(ISD::ADD, DL, VT, N0,
-                       DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
-                                   N1.getOperand(0)));
+                       DAG.getNode(ISD::SUB, DL, VT, C, B));
 
   // A - (A & B)  ->  A & (~B)
-  if (N1.getOpcode() == ISD::AND) {
-    SDValue A = N1.getOperand(0);
-    SDValue B = N1.getOperand(1);
-    if (A != N0)
-      std::swap(A, B);
-    if (A == N0 &&
-        (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
-      SDValue InvB =
-          DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
-      return DAG.getNode(ISD::AND, DL, VT, A, InvB);
-    }
-  }
+  if (sd_match(N1, m_And(m_Specific(N0), m_Value(B))) &&
+      (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true)))
+    return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getNOT(DL, B, VT));
 
-  // fold (X - (-Y * Z)) -> (X + (Y * Z))
-  if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
-    if (N1.getOperand(0).getOpcode() == ISD::SUB &&
-        isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
-      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
-                                N1.getOperand(0).getOperand(1),
-                                N1.getOperand(1));
-      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
-    }
-    if (N1.getOperand(1).getOpcode() == ISD::SUB &&
-        isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
-      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
-                                N1.getOperand(0),
-                                N1.getOperand(1).getOperand(1));
-      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
-    }
-  }
+  // fold (A - (-B * C)) -> (A + (B * C))
+  if (sd_match(N1, m_OneUse(m_Mul(m_Sub(m_Zero(), m_Value(B)), m_Value(C)))))
+    return DAG.getNode(ISD::ADD, DL, VT, N0,
+                       DAG.getNode(ISD::MUL, DL, VT, B, C));
 
   // If either operand of a sub is undef, the result is undef
   if (N0.isUndef())
@@ -3865,12 +3836,9 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
   if (SDValue V = foldSubToUSubSat(VT, N))
     return V;
 
-  // (x - y) - 1  ->  add (xor y, -1), x
-  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isOneOrOneSplat(N1)) {
-    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
-                              DAG.getAllOnesConstant(DL, VT));
-    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
-  }
+  // (A - B) - 1  ->  add (xor B, -1), A
+  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+    return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
 
   // Look for:
   //   sub y, (xor x, -1)

diff  --git a/llvm/test/CodeGen/AArch64/xor.ll b/llvm/test/CodeGen/AArch64/xor.ll
index d92402cf43b33e..7d7f7bfdbf38e8 100644
--- a/llvm/test/CodeGen/AArch64/xor.ll
+++ b/llvm/test/CodeGen/AArch64/xor.ll
@@ -51,7 +51,7 @@ define <4 x i32> @vec_add_of_not_decrement(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: vec_add_of_not_decrement:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mvn v1.16b, v1.16b
-; CHECK-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    add v0.4s, v0.4s, v1.4s
 ; CHECK-NEXT:    ret
   %t0 = sub <4 x i32> %x, %y
   %r = sub <4 x i32> %t0, <i32 1, i32 1, i32 1, i32 1>


        


More information about the llvm-commits mailing list