[llvm] 62a51c3 - [AArch64] Lower multiplication by a constant int to shl+sub+shl

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 29 10:33:35 PDT 2022


Author: zhongyunde
Date: 2022-09-30T01:31:06+08:00
New Revision: 62a51c357cf47127d8044e1bf367f6d1d0612969

URL: https://github.com/llvm/llvm-project/commit/62a51c357cf47127d8044e1bf367f6d1d0612969
DIFF: https://github.com/llvm/llvm-project/commit/62a51c357cf47127d8044e1bf367f6d1d0612969.diff

LOG: [AArch64] Lower multiplication by a constant int to shl+sub+shl

Decompose multiplication by the constant 14 (this part was split out of D132322).
Extend the AArch64 mul combine (performMulCombine) to lower a = b * C, where
C = 2^n - 2^m (e.g. 14 = 2^4 - 2^1), to
        lsl     w8, w0, n
        sub     w0, w8, w0, lsl m

Reviewed By: efriedma
Differential Revision: https://reviews.llvm.org/D134706

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/mul_pow2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9750f01de2956..9f3af41689ec8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14798,7 +14798,7 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   // More aggressively, some multiplications N0 * C can be lowered to
   // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
   // e.g. 6=3*2=(2+1)*2.
-  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
+  // TODO: consider lowering more cases, e.g. C = -6, -14 or even 45
   // which equals to (1+2)*16-(1+2).
 
   // TrailingZeroes is used to test if the mul can be lowered to
@@ -14826,11 +14826,21 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   // Do we need to negate the result?
   bool NegateResult = false;
 
+  auto Shl = [&](SDValue N0, unsigned N1) {
+    SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
+    return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
+  };
+  auto Sub = [&](SDValue N0, SDValue N1) {
+    return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
+  };
+
   if (ConstValue.isNonNegative()) {
     // (mul x, 2^N + 1) => (add (shl x, N), x)
     // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
+    // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
     APInt SCVMinus1 = ShiftedConstValue - 1;
+    APInt SCVPlus1 = ShiftedConstValue + 1;
     APInt CVPlus1 = ConstValue + 1;
     if (SCVMinus1.isPowerOf2()) {
       ShiftAmt = SCVMinus1.logBase2();
@@ -14838,6 +14848,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
     } else if (CVPlus1.isPowerOf2()) {
       ShiftAmt = CVPlus1.logBase2();
       AddSubOpc = ISD::SUB;
+    } else if (SCVPlus1.isPowerOf2()) {
+      ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
+      return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes));
     } else
       return SDValue();
   } else {
@@ -14857,21 +14870,18 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
       return SDValue();
   }
 
-  SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N0,
-                                   DAG.getConstant(ShiftAmt, DL, MVT::i64));
-
-  SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal : N0;
-  SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal;
+  SDValue ShiftedVal0 = Shl(N0, ShiftAmt);
+  SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal0 : N0;
+  SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal0;
   SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1);
   assert(!(NegateResult && TrailingZeroes) &&
          "NegateResult and TrailingZeroes cannot both be true for now.");
   // Negate the result.
   if (NegateResult)
-    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res);
+    return Sub(DAG.getConstant(0, DL, VT), Res);
   // Shift the result.
   if (TrailingZeroes)
-    return DAG.getNode(ISD::SHL, DL, VT, Res,
-                       DAG.getConstant(TrailingZeroes, DL, MVT::i64));
+    return Shl(Res, TrailingZeroes);
   return Res;
 }
 

diff  --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll
index 2c0bec9b87902..fa756f82f5423 100644
--- a/llvm/test/CodeGen/AArch64/mul_pow2.ll
+++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll
@@ -408,8 +408,8 @@ define i32 @test13(i32 %x) {
 define i32 @test14(i32 %x) {
 ; CHECK-LABEL: test14:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #14
-; CHECK-NEXT:    mul w0, w0, w8
+; CHECK-NEXT:    lsl w8, w0, #4
+; CHECK-NEXT:    sub w0, w8, w0, lsl #1
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test14:


        


More information about the llvm-commits mailing list