[llvm] fcf945f - [DAG] Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB) (#90860)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 14:11:22 PDT 2024


Author: David Green
Date: 2024-05-08T22:11:18+01:00
New Revision: fcf945f4edbad1f2d82df067c2826baa6165dd3e

URL: https://github.com/llvm/llvm-project/commit/fcf945f4edbad1f2d82df067c2826baa6165dd3e
DIFF: https://github.com/llvm/llvm-project/commit/fcf945f4edbad1f2d82df067c2826baa6165dd3e.diff

LOG: [DAG] Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB) (#90860)

This is useful when the inner add has multiple uses, and so cannot be
canonicalized by pushing the constants down through the mul. This patch
adds patterns for both `add(mul(add(A, CA), CM), CB)` and the variant with
an extra add, `add(add(mul(add(A, CA), CM), B), CB)`, as the second can
come up when lowering geps.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/addimm-mulimm.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e835bd950a7b..4589d201d620 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2838,6 +2838,66 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
     return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
   }
 
+  // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
+  // This can help if the inner add has multiple uses.
+  APInt CM, CA;
+  if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
+    if (VT.getScalarSizeInBits() <= 64) {
+      if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                      m_ConstInt(CM)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              N0.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Mul,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+      // Also look in case there is an intermediate add.
+      if (sd_match(N0, m_OneUse(m_Add(
+                           m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
+                                          m_ConstInt(CM))),
+                           m_Value(B)))) &&
+          TLI.isLegalAddImmediate(
+              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
+        SDNodeFlags Flags;
+        // If all the inputs are nuw, the outputs can be nuw. If all the input
+        // are _also_ nsw the outputs can be too.
+        SDValue OMul =
+            N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
+        if (N->getFlags().hasNoUnsignedWrap() &&
+            N0->getFlags().hasNoUnsignedWrap() &&
+            OMul->getFlags().hasNoUnsignedWrap() &&
+            OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
+          Flags.setNoUnsignedWrap(true);
+          if (N->getFlags().hasNoSignedWrap() &&
+              N0->getFlags().hasNoSignedWrap() &&
+              OMul->getFlags().hasNoSignedWrap() &&
+              OMul.getOperand(0)->getFlags().hasNoSignedWrap())
+            Flags.setNoSignedWrap(true);
+        }
+        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
+                                  DAG.getConstant(CM, DL, VT), Flags);
+        SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
+        return DAG.getNode(
+            ISD::ADD, DL, VT, Add,
+            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
+      }
+    }
+  }
+
   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
     return Combined;
 

diff  --git a/llvm/test/CodeGen/AArch64/addimm-mulimm.ll b/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
index 3618b14aa921..6636813eb250 100644
--- a/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
+++ b/llvm/test/CodeGen/AArch64/addimm-mulimm.ll
@@ -166,9 +166,9 @@ define signext i32 @addmuladd_multiuse(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
+; CHECK-NEXT:    mov w9, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w9
 ; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w10, #4 // =0x4
-; CHECK-NEXT:    madd w8, w9, w8, w10
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -198,11 +198,10 @@ define signext i32 @addmuladd_multiuse2(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -233,8 +232,8 @@ define signext i32 @addaddmuladd_multiuse(i32 signext %a, i32 %b) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
 ; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    madd w8, w9, w8, w1
-; CHECK-NEXT:    add w8, w8, #4
+; CHECK-NEXT:    madd w8, w0, w8, w1
+; CHECK-NEXT:    add w8, w8, #1300
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -249,12 +248,11 @@ define signext i32 @addaddmuladd_multiuse2(i32 signext %a, i32 %b) {
 ; CHECK-LABEL: addaddmuladd_multiuse2:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w10, #162 // =0xa2
-; CHECK-NEXT:    madd w8, w9, w8, w1
-; CHECK-NEXT:    madd w9, w9, w10, w1
-; CHECK-NEXT:    add w8, w8, #4
-; CHECK-NEXT:    add w9, w9, #4
+; CHECK-NEXT:    mov w9, #162 // =0xa2
+; CHECK-NEXT:    madd w8, w0, w8, w1
+; CHECK-NEXT:    madd w9, w0, w9, w1
+; CHECK-NEXT:    add w8, w8, #1300
+; CHECK-NEXT:    add w9, w9, #652
 ; CHECK-NEXT:    eor w0, w9, w8
 ; CHECK-NEXT:    ret
   %tmp0 = add i32 %a, 4
@@ -319,17 +317,17 @@ define void @addmuladd_gep(ptr %p, i64 %a) {
 define i32 @addmuladd_gep2(ptr %p, i32 %a) {
 ; CHECK-LABEL: addmuladd_gep2:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #3240 // =0xca8
 ; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
-; CHECK-NEXT:    sxtw x8, w1
-; CHECK-NEXT:    mov w9, #3240 // =0xca8
-; CHECK-NEXT:    add x8, x8, #1
-; CHECK-NEXT:    madd x9, x8, x9, x0
-; CHECK-NEXT:    ldr w9, [x9, #20]
-; CHECK-NEXT:    tbnz w9, #31, .LBB22_2
+; CHECK-NEXT:    smaddl x8, w1, w8, x0
+; CHECK-NEXT:    ldr w8, [x8, #3260]
+; CHECK-NEXT:    tbnz w8, #31, .LBB22_2
 ; CHECK-NEXT:  // %bb.1:
 ; CHECK-NEXT:    mov w0, wzr
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB22_2: // %then
+; CHECK-NEXT:    sxtw x8, w1
+; CHECK-NEXT:    add x8, x8, #1
 ; CHECK-NEXT:    str x8, [x0]
 ; CHECK-NEXT:    mov w0, #1 // =0x1
 ; CHECK-NEXT:    ret
@@ -351,11 +349,10 @@ define signext i32 @addmuladd_multiuse2_nsw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nsw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nsw i32 %a, 4
@@ -371,11 +368,10 @@ define signext i32 @addmuladd_multiuse2_nuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nuw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nuw i32 %a, 4
@@ -391,11 +387,10 @@ define signext i32 @addmuladd_multiuse2_nswnuw(i32 signext %a) {
 ; CHECK-LABEL: addmuladd_multiuse2_nswnuw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #324 // =0x144
-; CHECK-NEXT:    add w9, w0, #4
-; CHECK-NEXT:    mov w11, #4 // =0x4
-; CHECK-NEXT:    lsl w10, w9, #2
-; CHECK-NEXT:    madd w8, w9, w8, w11
-; CHECK-NEXT:    add w9, w10, #4
+; CHECK-NEXT:    lsl w9, w0, #2
+; CHECK-NEXT:    mov w10, #1300 // =0x514
+; CHECK-NEXT:    madd w8, w0, w8, w10
+; CHECK-NEXT:    add w9, w9, #20
 ; CHECK-NEXT:    eor w0, w8, w9
 ; CHECK-NEXT:    ret
   %tmp0 = add nsw nuw i32 %a, 4


        


More information about the llvm-commits mailing list