[llvm] [DAG] Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB) (PR #90860)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 12:41:12 PDT 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/90860

>From c8b0a3f21cf21637b6f1fb90f3d14d6e2ec88891 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 8 May 2024 20:41:00 +0100
Subject: [PATCH] [DAG] Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM),
 CM*CA+CB)

This is useful when the inner add has multiple uses, and so cannot be
canonicalized by pushing the constants down through the mul. I have added
patterns both for `add(mul(add(A, CA), CM), CB)` and for the variant with an
extra add, `add(add(mul(add(A, CA), CM), B), CB)`, as the second form can come
up when lowering GEPs.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 60 +++++++++++++++++
 llvm/test/CodeGen/AArch64/addimm-mulimm.ll    | 67 +++++++++----------
 2 files changed, 91 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e835bd950a7be..4589d201d6203 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2838,6 +2838,66 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
     return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
   }
 
+  // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
+  // This can help if the inner add has multiple uses.
+  APInt CM, CA;
+  if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
+    if (VT.getScalarSizeInBits() <= 64) {
+      if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                      m_ConstInt(CM)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              N0.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Mul,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+      // Also look in case there is an intermediate add.
+      if (sd_match(N0, m_OneUse(m_Add(
+                           m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                          m_ConstInt(CM))),
+                           m_Value(B)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
+        SDValue OMul =
+            N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            OMul->getFlags().hasNoUnsignedWrap() &&
+            OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              OMul->getFlags().hasNoSignedWrap() &&
+              OMul.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Add,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+    }
+  }
+
   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
     return Combined;
 
diff --git a/llvm/test/CodeGen/AArch64/addimm-mulimm.ll b/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
index 3618b14aa9212..6636813eb2504 100644
--- a/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
+++ b/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
@@ -166,9 +166,9 @@ define signext i32 @addmuladd_multiuse(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
+; CHECK-NEXT:    mov w9, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w9
 ; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w10, #4 // =0x4
-; CHECK-NEXT:    madd w8, w9, w8, w10
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -198,11 +198,10 @@ define signext i32 @addmuladd_multiuse2(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -233,8 +232,8 @@ define signext i32 @addaddmuladd_multiuse(i32 signext %a, i32 %b) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
 ; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    madd w8, w9, w8, w1
-; CHECK-NEXT:    add w8, w8, #4
+; CHECK-NEXT:    madd w8, w0, w8, w1
+; CHECK-NEXT:    add w8, w8, #1300
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -249,12 +248,11 @@ define signext i32 @addaddmuladd_multiuse2(i32 signext %a, i32 %b) {
 ; CHECK-LABEL: addaddmuladd_multiuse2:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w10, #162 // =0xa2
-; CHECK-NEXT:    madd w8, w9, w8, w1
-; CHECK-NEXT:    madd w9, w9, w10, w1
-; CHECK-NEXT:    add w8, w8, #4
-; CHECK-NEXT:    add w9, w9, #4
+; CHECK-NEXT:    mov w9, #162 // =0xa2
+; CHECK-NEXT:    madd w8, w0, w8, w1
+; CHECK-NEXT:    madd w9, w0, w9, w1
+; CHECK-NEXT:    add w8, w8, #1300
+; CHECK-NEXT:    add w9, w9, #652
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -319,17 +317,17 @@ define void @addmuladd_gep(ptr %p, i64 %a) {
 define i32 @addmuladd_gep2(ptr %p, i32 %a) {
 ; CHECK-LABEL: addmuladd_gep2:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #3240 // =0xca8
 ; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
-; CHECK-NEXT:    sxtw x8, w1
-; CHECK-NEXT:    mov w9, #3240 // =0xca8
-; CHECK-NEXT:    add x8, x8, #1
-; CHECK-NEXT:    madd x9, x8, x9, x0
-; CHECK-NEXT:    ldr w9, [x9, #20]
-; CHECK-NEXT:    tbnz w9, #31, .LBB22_2
+; CHECK-NEXT:    smaddl x8, w1, w8, x0
+; CHECK-NEXT:    ldr w8, [x8, #3260]
+; CHECK-NEXT:    tbnz w8, #31, .LBB22_2
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    mov w0, wzr
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB22_2: // %then
+; CHECK-NEXT:    sxtw x8, w1
+; CHECK-NEXT:    add x8, x8, #1
 ; CHECK-NEXT:    str x8, [x0]
 ; CHECK-NEXT:    mov w0, #1 // =0x1
 ; CHECK-NEXT:    ret
@@ -351,11 +349,10 @@ define signext i32 @addmuladd_multiuse2_nsw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nsw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nsw i32 %a, 4
@@ -371,11 +368,10 @@ define signext i32 @addmuladd_multiuse2_nuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nuw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nuw i32 %a, 4
@@ -391,11 +387,10 @@ define signext i32 @addmuladd_multiuse2_nswnuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nswnuw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nsw nuw i32 %a, 4



More information about the llvm-commits mailing list