[llvm] [AArch64][SelectionDAG] Lower multiplication by a constant to shl+add+shl+add (PR #89532)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 20:17:41 PDT 2024


https://github.com/vfdff updated https://github.com/llvm/llvm-project/pull/89532

>From 3489aa220d795dee6ece9d10576544bb27b1d06f Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Sun, 21 Apr 2024 01:22:13 -0400
Subject: [PATCH 1/2] [AArch64][SelectionDAG] Lower multiplication by a
 constant to shl+add+shl+add

Change the cost model to lower a = b * C, where C = (1 + 2^m) * 2^n + 1, to
          add   w8, w0, w0, lsl #m
          add   w0, w0, w8, lsl #n
Note: The latency can vary depending on the shift amount.
Fix part of https://github.com/llvm/llvm-project/issues/89430
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 23 +++++++++++++++++++
 llvm/test/CodeGen/AArch64/mul_pow2.ll         | 21 +++++++++++++++--
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3d1453e3beb9a1..e4d552dcf4f0f1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17602,12 +17602,31 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     return false;
   };
 
+  // Can the const C be decomposed into ((2^M + 1) * 2^N) + 1? E.g.:
+  // C = 11 is equal to (1+4)*2+1; we don't decompose it into (1+2)*4-1 as
+  // the (2^N - 1) factor can't be executed via a single instruction.
+  auto isPowPlusPlusOneConst = [](APInt C, APInt &M, APInt &N) {
+    APInt CVMinus1 = C - 1;
+    if (CVMinus1.isNegative())
+      return false;
+    unsigned TrailingZeroes = CVMinus1.countr_zero();
+    APInt SCVMinus1 = CVMinus1.ashr(TrailingZeroes) - 1;
+    if (SCVMinus1.isPowerOf2()) {
+      M = SCVMinus1.logBase2();
+      N = TrailingZeroes;
+      return true;
+    }
+    return false;
+  };
+
   if (ConstValue.isNonNegative()) {
     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
     // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
     // (mul x, (2^M + 1) * (2^N + 1))
     //     => MV = (add (shl x, M), x); (add (shl MV, N), MV)
+    // (mul x, (2^M + 1) * 2^N + 1)
+    //     => MV = (add (shl x, M), x); (add (shl MV, N), x)
     APInt SCVMinus1 = ShiftedConstValue - 1;
     APInt SCVPlus1 = ShiftedConstValue + 1;
     APInt CVPlus1 = ConstValue + 1;
@@ -17632,6 +17651,10 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
         SDValue MVal = Add(Shl(N0, ShiftM1), N0);
         return Add(Shl(MVal, ShiftN1), MVal);
       }
+    } else if (Subtarget->hasALULSLFast() &&
+               isPowPlusPlusOneConst(ConstValue, CVM, CVN)) {
+      SDValue MVal = Add(Shl(N0, CVM.getZExtValue()), N0);
+      return Add(Shl(MVal, CVN.getZExtValue()), N0);
     }
   } else {
     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 90e560af4465a9..6f49f38bf41a5c 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -410,6 +410,23 @@ define i32 @test11(i32 %x) {
   ret i32 %mul
 }
 
+define i32 @test11_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
+; CHECK-LABEL: test11_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add w8, w0, w0
+; CHECK-NEXT:    add w0, w0, w8, lsl #1
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test11_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #11 // =0xb
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 11
+  ret i32 %mul
+}
+
 define i32 @test12(i32 %x) {
 ; CHECK-LABEL: test12:
 ; CHECK:       // %bb.0:
@@ -858,9 +875,9 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 ;
 ; GISEL-LABEL: muladd_demand_commute:
 ; GISEL:       // %bb.0:
-; GISEL-NEXT:    adrp x8, .LCPI49_0
+; GISEL-NEXT:    adrp x8, .LCPI50_0
 ; GISEL-NEXT:    movi v3.4s, #1, msl #16
-; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI49_0]
+; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI50_0]
 ; GISEL-NEXT:    mla v1.4s, v0.4s, v2.4s
 ; GISEL-NEXT:    and v0.16b, v1.16b, v3.16b
 ; GISEL-NEXT:    ret

>From bffd3bb1826a5f135a3e9b30192bffd57628c3a4 Mon Sep 17 00:00:00 2001
From: zhongyunde 00443407 <zhongyunde at huawei.com>
Date: Tue, 23 Apr 2024 08:31:37 -0400
Subject: [PATCH 2/2] Fix comment

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 16 ++++---
 llvm/test/CodeGen/AArch64/mul_pow2.ll         | 42 +++++++++++++++++--
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e4d552dcf4f0f1..d4faec4c4fbbc8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17612,8 +17612,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     unsigned TrailingZeroes = CVMinus1.countr_zero();
     APInt SCVMinus1 = CVMinus1.ashr(TrailingZeroes) - 1;
     if (SCVMinus1.isPowerOf2()) {
-      M = SCVMinus1.logBase2();
-      N = TrailingZeroes;
+      unsigned BitWidth = SCVMinus1.getBitWidth();
+      M = APInt(BitWidth, SCVMinus1.logBase2());
+      N = APInt(BitWidth, TrailingZeroes);
       return true;
     }
     return false;
@@ -17646,15 +17647,20 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
       APInt CVNMinus1 = CVN - 1;
       unsigned ShiftM1 = CVMMinus1.logBase2();
       unsigned ShiftN1 = CVNMinus1.logBase2();
-      // LSLFast implicate that Shifts <= 3 places are fast
+      // ALULSLFast implies that shifts of <= 3 places are fast
       if (ShiftM1 <= 3 && ShiftN1 <= 3) {
         SDValue MVal = Add(Shl(N0, ShiftM1), N0);
         return Add(Shl(MVal, ShiftN1), MVal);
       }
     } else if (Subtarget->hasALULSLFast() &&
                isPowPlusPlusOneConst(ConstValue, CVM, CVN)) {
-      SDValue MVal = Add(Shl(N0, CVM.getZExtValue()), N0);
-      return Add(Shl(MVal, CVN.getZExtValue()), N0);
+      unsigned ShiftM = CVM.getZExtValue();
+      unsigned ShiftN = CVN.getZExtValue();
+      // ALULSLFast implies that shifts of <= 3 places are fast
+      if (ShiftM <= 3 && ShiftN <= 3) {
+        SDValue MVal = Add(Shl(N0, CVM.getZExtValue()), N0);
+        return Add(Shl(MVal, CVN.getZExtValue()), N0);
+      }
     }
   } else {
     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 6f49f38bf41a5c..d866f41caa3084 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -413,7 +413,7 @@ define i32 @test11(i32 %x) {
 define i32 @test11_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
 ; CHECK-LABEL: test11_fast_shift:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, w0
+; CHECK-NEXT:    add w8, w0, w0, lsl #2
 ; CHECK-NEXT:    add w0, w0, w8, lsl #1
 ; CHECK-NEXT:    ret
 ;
@@ -527,6 +527,24 @@ define i32 @test25_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
   ret i32 %mul
 }
 
+; Negative: 35 = (((1<<4) + 1) << 1) + 1; the shift amount 4 is out of bounds
+define i32 @test35_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
+; CHECK-LABEL: test35_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #35 // =0x23
+; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test35_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #35 // =0x23
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 35
+  ret i32 %mul
+}
+
 define i32 @test45_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
 ; CHECK-LABEL: test45_fast_shift:
 ; CHECK:       // %bb.0:
@@ -562,6 +580,24 @@ define i32 @test45(i32 %x) {
   ret i32 %mul
 }
 
+; Negative: 49 = (((1<<1) + 1) << 4) + 1; the shift amount 4 is out of bounds
+define i32 @test49_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
+; CHECK-LABEL: test49_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #49 // =0x31
+; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test49_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #49 // =0x31
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 49
+  ret i32 %mul
+}
+
 ; Negative test: The shift amount 4 larger than 3
 define i32 @test85_fast_shift(i32 %x) "target-features"="+alu-lsl-fast" {
 ; CHECK-LABEL: test85_fast_shift:
@@ -875,9 +911,9 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 ;
 ; GISEL-LABEL: muladd_demand_commute:
 ; GISEL:       // %bb.0:
-; GISEL-NEXT:    adrp x8, .LCPI50_0
+; GISEL-NEXT:    adrp x8, .LCPI52_0
 ; GISEL-NEXT:    movi v3.4s, #1, msl #16
-; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI50_0]
+; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI52_0]
 ; GISEL-NEXT:    mla v1.4s, v0.4s, v2.4s
 ; GISEL-NEXT:    and v0.16b, v1.16b, v3.16b
 ; GISEL-NEXT:    ret



More information about the llvm-commits mailing list