[llvm] 74c2d4f - [AArch64][SelectionDAG] Lower multiplication by a constant to shl+add+shl+add

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 20 09:37:27 PDT 2022


Author: zhongyunde
Date: 2022-10-21T00:33:49+08:00
New Revision: 74c2d4f6024c8f160871a2baa928d0b42415f183

URL: https://github.com/llvm/llvm-project/commit/74c2d4f6024c8f160871a2baa928d0b42415f183
DIFF: https://github.com/llvm/llvm-project/commit/74c2d4f6024c8f160871a2baa928d0b42415f183.diff

LOG: [AArch64][SelectionDAG] Lower multiplication by a constant to shl+add+shl+add

Change the costmodel to lower a = b * C where C = (1 + 2^m) * (1 + 2^n) to
      add   w8, w0, w0, lsl #m
      add   w0, w8, w8, lsl #n
Note: The latency can vary depending on the shift amount

Reviewed By: efriedma, dmgreen
Differential Revision: https://reviews.llvm.org/D135441

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/mul_pow2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e7892c8d82dc..fba1da74e810 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15065,8 +15065,8 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   // 64-bit is 5 cycles, so this is always a win.
   // More aggressively, some multiplications N0 * C can be lowered to
   // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
-  // e.g. 6=3*2=(2+1)*2.
-  // TODO: lower more cases, e.g. C = 45 which equals to (1+2)*16-(1+2).
+  // e.g. 6=3*2=(2+1)*2, 45=(1+4)*(1+8)
+  // TODO: lower more cases.
 
   // TrailingZeroes is used to test if the mul can be lowered to
   // shift+add+shift.
@@ -15103,13 +15103,34 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     return DAG.getNode(ISD::SUB, DL, VT, Zero, N);
   };
 
+  // Can the const C be decomposed into (1+2^M1)*(1+2^N1), eg:
+  // C = 45 is equal to (1+4)*(1+8), we don't decompose it into (1+2)*(16-1) as
+  // the (2^N - 1) can't be executed via a single instruction.
+  auto isPowPlusPlusConst = [](APInt C, APInt &M, APInt &N) {
+    unsigned BitWidth = C.getBitWidth();
+    for (unsigned i = 1; i < BitWidth / 2; i++) {
+      APInt Rem;
+      APInt X(BitWidth, (1 << i) + 1);
+      APInt::sdivrem(C, X, N, Rem);
+      APInt NVMinus1 = N - 1;
+      if (Rem == 0 && NVMinus1.isPowerOf2()) {
+        M = X;
+        return true;
+      }
+    }
+    return false;
+  };
+
   if (ConstValue.isNonNegative()) {
     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
     // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
+    // (mul x, (2^M + 1) * (2^N + 1))
+    //     => MV = (add (shl x, M), x); (add (shl MV, N), MV)
     APInt SCVMinus1 = ShiftedConstValue - 1;
     APInt SCVPlus1 = ShiftedConstValue + 1;
     APInt CVPlus1 = ConstValue + 1;
+    APInt CVM, CVN;
     if (SCVMinus1.isPowerOf2()) {
       ShiftAmt = SCVMinus1.logBase2();
       return Shl(Add(Shl(N0, ShiftAmt), N0), TrailingZeroes);
@@ -15119,6 +15140,17 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     } else if (SCVPlus1.isPowerOf2()) {
       ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
       return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes));
+    } else if (Subtarget->hasLSLFast() &&
+               isPowPlusPlusConst(ConstValue, CVM, CVN)) {
+      APInt CVMMinus1 = CVM - 1;
+      APInt CVNMinus1 = CVN - 1;
+      unsigned ShiftM1 = CVMMinus1.logBase2();
+      unsigned ShiftN1 = CVNMinus1.logBase2();
+      // LSLFast implies that shifts of <= 3 places are fast
+      if (ShiftM1 <= 3 && ShiftN1 <= 3) {
+        SDValue MVal = Add(Shl(N0, ShiftM1), N0);
+        return Add(Shl(MVal, ShiftN1), MVal);
+      }
     }
   } else {
     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))

diff  --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 6ec0b62c0210..cbdf6337847c 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -493,6 +493,94 @@ define i32 @test16(i32 %x) {
   ret i32 %mul
 }
 
+define i32 @test25_fast_shift(i32 %x) "target-features"="+lsl-fast" {
+; CHECK-LABEL: test25_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add w8, w0, w0, lsl #2
+; CHECK-NEXT:    add w0, w8, w8, lsl #2
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test25_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #25
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 25 ; 25 = (1+4)*(1+4)
+  ret i32 %mul
+}
+
+define i32 @test45_fast_shift(i32 %x) "target-features"="+lsl-fast" {
+; CHECK-LABEL: test45_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add w8, w0, w0, lsl #2
+; CHECK-NEXT:    add w0, w8, w8, lsl #3
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test45_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #45
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 45 ; 45 = (1+4)*(1+8)
+  ret i32 %mul
+}
+
+; Negative test: Keep MUL as the target doesn't have the LSLFast feature
+define i32 @test45(i32 %x) {
+; CHECK-LABEL: test45:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #45
+; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test45:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #45
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 45 ; 45 = (1+4)*(1+8)
+  ret i32 %mul
+}
+
+; Negative test: The shift amount 4 is larger than 3
+define i32 @test85_fast_shift(i32 %x) "target-features"="+lsl-fast" {
+; CHECK-LABEL: test85_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #85
+; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test85_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #85
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 85 ; 85 = (1+4)*(1+16)
+  ret i32 %mul
+}
+
+; Negative test: The shift amount 5 is larger than 3
+define i32 @test297_fast_shift(i32 %x) "target-features"="+lsl-fast" {
+; CHECK-LABEL: test297_fast_shift:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #297
+; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    ret
+;
+; GISEL-LABEL: test297_fast_shift:
+; GISEL:       // %bb.0:
+; GISEL-NEXT:    mov w8, #297
+; GISEL-NEXT:    mul w0, w0, w8
+; GISEL-NEXT:    ret
+
+  %mul = mul nsw i32 %x, 297 ; 297 = (1+8)*(1+32)
+  ret i32 %mul
+}
+
 ; Convert mul x, -pow2 to shift.
 ; Convert mul x, -(pow2 +/- 1) to shift + add/sub.
 ; Lowering other negative constants are not supported yet.
@@ -770,11 +858,11 @@ define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 ;
 ; GISEL-LABEL: muladd_demand_commute:
 ; GISEL:       // %bb.0:
-; GISEL-NEXT:    adrp x8, .LCPI44_1
-; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI44_1]
-; GISEL-NEXT:    adrp x8, .LCPI44_0
+; GISEL-NEXT:    adrp x8, .LCPI49_1
+; GISEL-NEXT:    ldr q2, [x8, :lo12:.LCPI49_1]
+; GISEL-NEXT:    adrp x8, .LCPI49_0
 ; GISEL-NEXT:    mla v1.4s, v0.4s, v2.4s
-; GISEL-NEXT:    ldr q0, [x8, :lo12:.LCPI44_0]
+; GISEL-NEXT:    ldr q0, [x8, :lo12:.LCPI49_0]
 ; GISEL-NEXT:    and v0.16b, v1.16b, v0.16b
 ; GISEL-NEXT:    ret
   %m = mul <4 x i32> %x, <i32 131008, i32 131008, i32 131008, i32 131008>


        


More information about the llvm-commits mailing list