[llvm] [SelectionDAG] Replace some basic patterns in visitADDLike with SDPatternMatch (PR #84759)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 11 06:13:09 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: XChy (XChy)
<details>
<summary>Changes</summary>
Resolves #84745.
Building on SDPatternMatch, which was introduced in #78654, this patch replaces some of the basic patterns in `visitADDLike` with the corresponding SDPatternMatch patterns.
This patch only rewrites existing folds rather than introducing new ones, so there are no new tests. If new tests are needed, please let me know.
---
Full diff: https://github.com/llvm/llvm-project/pull/84759.diff
1 File Affected:
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+26-30)
``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cdcb7114640471..858cd892f41766 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -38,6 +38,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -79,6 +80,7 @@
#include "MatchContext.h"
using namespace llvm;
+using namespace llvm::SDPatternMatch;
#define DEBUG_TYPE "dagcombine"
@@ -2697,52 +2699,46 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
return SD;
}
+
+ SDValue A, B, C;
+
// fold ((0-A) + B) -> B-A
- if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
- return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
+ if (sd_match(N0, m_Sub(m_Zero(), m_Value(A))))
+ return DAG.getNode(ISD::SUB, DL, VT, N1, A);
// fold (A + (0-B)) -> A-B
- if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
- return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
+ if (sd_match(N1, m_Sub(m_Zero(), m_Value(B))))
+ return DAG.getNode(ISD::SUB, DL, VT, N0, B);
// fold (A+(B-A)) -> B
- if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
- return N1.getOperand(0);
+ if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
+ return B;
// fold ((B-A)+A) -> B
- if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
- return N0.getOperand(0);
+ if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
+ return B;
// fold ((A-B)+(C-A)) -> (C-B)
- if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
- N0.getOperand(0) == N1.getOperand(1))
- return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
- N0.getOperand(1));
+ if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+ sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
+ return DAG.getNode(ISD::SUB, DL, VT, C, B);
// fold ((A-B)+(B-C)) -> (A-C)
- if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
- N0.getOperand(1) == N1.getOperand(0))
- return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
- N1.getOperand(1));
+ if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
+ sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
+ return DAG.getNode(ISD::SUB, DL, VT, A, C);
// fold (A+(B-(A+C))) to (B-C)
- if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
- N0 == N1.getOperand(1).getOperand(0))
- return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
- N1.getOperand(1).getOperand(1));
-
// fold (A+(B-(C+A))) to (B-C)
- if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
- N0 == N1.getOperand(1).getOperand(1))
- return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
- N1.getOperand(1).getOperand(0));
+ if (sd_match(N1, m_Sub(m_Value(B),
+ m_c_BinOp(ISD::ADD, m_Specific(N0), m_Value(C)))))
+ return DAG.getNode(ISD::SUB, DL, VT, B, C);
// fold (A+((B-A)+or-C)) to (B+or-C)
- if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
- N1.getOperand(0).getOpcode() == ISD::SUB &&
- N0 == N1.getOperand(0).getOperand(1))
- return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
- N1.getOperand(1));
+ if (sd_match(N1,
+ m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
+ m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
+ return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
// fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
``````````
</details>
https://github.com/llvm/llvm-project/pull/84759
More information about the llvm-commits mailing list