[llvm] 4298fc5 - [RISCV] Move strength reduction of mul X, 3/5/9*2^N to combine (#89966)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 8 10:13:05 PDT 2024
Author: Philip Reames
Date: 2024-05-08T10:13:01-07:00
New Revision: 4298fc5eb5c483fb72db6fce062352087dfd0acf
URL: https://github.com/llvm/llvm-project/commit/4298fc5eb5c483fb72db6fce062352087dfd0acf
DIFF: https://github.com/llvm/llvm-project/commit/4298fc5eb5c483fb72db6fce062352087dfd0acf.diff
LOG: [RISCV] Move strength reduction of mul X, 3/5/9*2^N to combine (#89966)
This moves our last major category of tablegen-driven multiply strength
reduction into the post-legalize combine framework. The one slightly
tricky bit is making sure that we use a leading shl if we can form a
slli.uw, and a trailing shl otherwise. Having the trailing shl is
critical for shNadd matching and for folding any following sext.w.
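For illustration, here is a small standalone sketch of the decomposition
this combine performs. It is plain C++ with made-up names (decomposeMul,
HasZeroExtendedInput), not LLVM's actual APIs, and it only prints the
instruction shape chosen for each case; __builtin_ctzll is the
GCC/Clang builtin standing in for Log2_64.

#include <cstdint>
#include <cstdio>

// Sketch only: given MulAmt = D * 2^N with D in {3, 5, 9}, print the
// shXadd/slli pair the combine would emit. HasZeroExtendedInput models
// the "operand is (and X, 0xffffffff)" case where a leading shift can
// become slli.uw.
static bool decomposeMul(uint64_t MulAmt, bool HasZeroExtendedInput) {
  for (uint64_t Divisor : {3, 5, 9}) {
    if (MulAmt % Divisor != 0)
      continue;
    uint64_t MulAmt2 = MulAmt / Divisor;
    // isPowerOf2_64 equivalent: exactly one bit set.
    if (MulAmt2 == 0 || (MulAmt2 & (MulAmt2 - 1)) != 0)
      continue;
    unsigned ShAmt = __builtin_ctzll(MulAmt2);              // Log2_64(MulAmt2)
    unsigned ShN = Divisor == 3 ? 1 : Divisor == 5 ? 2 : 3; // Log2_64(Divisor - 1)
    if (HasZeroExtendedInput)
      // Leading shift so it can fold the zext and become slli.uw.
      printf("slli.uw t0, a0, %u ; sh%uadd a0, t0, t0\n", ShAmt, ShN);
    else
      // Trailing shift so a following sext.w or shNadd can fold it.
      printf("sh%uadd a0, a0, a0 ; slli a0, a0, %u\n", ShN, ShAmt);
    return true;
  }
  return false;
}

int main() {
  decomposeMul(24, /*HasZeroExtendedInput=*/false); // 3 * 8 -> sh1add ; slli 3
  decomposeMul(40, /*HasZeroExtendedInput=*/false); // 5 * 8 -> sh2add ; slli 3
  decomposeMul(20, /*HasZeroExtendedInput=*/true);  // 5 * 4 -> slli.uw 2 ; sh2add
}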
As can be seen in the TD deltas, this allows us to kill off both the
actual multiply patterns and the explicit add (mul X, C), Y patterns.
The latter are now handled by the generic shNadd matching code, with the
exception of the THead-only C=200 case, because we don't (yet) have a
multiply expansion with two shNadd + a shift.
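(Worked example of that remaining exception: 200 = 5 * 5 * 8, so the
kept XTHeadBa pattern below expands mul X, 200 as
  slli (th.addsl (th.addsl x, x, 2), (th.addsl x, x, 2), 2), 3
i.e. two shift-adds plus a shift, a shape this combine does not produce
yet.)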
---------
Co-authored-by: Yingwei Zheng <dtcxzyw at qq.com>
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
llvm/test/CodeGen/RISCV/addimm-mulimm.ll
llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
llvm/test/CodeGen/RISCV/rv64zba.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3536eb4c0ba4b..846768f6d631e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13565,10 +13565,27 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
if (MulAmt % Divisor != 0)
continue;
uint64_t MulAmt2 = MulAmt / Divisor;
- // 3/5/9 * 2^N -> shXadd (sll X, C), (sll X, C)
- // Matched in tablegen, avoid perturbing patterns.
- if (isPowerOf2_64(MulAmt2))
- return SDValue();
+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
+ if (isPowerOf2_64(MulAmt2)) {
+ SDLoc DL(N);
+ SDValue X = N->getOperand(0);
+ // Put the shift first if we can fold a zext into the
+ // shift forming a slli.uw.
+ if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
+ SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), Shl);
+ }
+ // Otherwise, put the shl second so that it can fold with following
+ // instructions (e.g. sext or add).
+ SDValue Mul359 =
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
+ }
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index b398c5e7fec2e..bc14f165d9622 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -549,40 +549,11 @@ def : Pat<(add_non_imm12 sh2add_op:$rs1, (XLenVT GPR:$rs2)),
def : Pat<(add_non_imm12 sh3add_op:$rs1, (XLenVT GPR:$rs2)),
(TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 1)), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 2)), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (XLenVT (TH_ADDSL GPR:$rs1, GPR:$rs1, 3)), 3)>;
-
def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy4:$i),
(TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy2XForm CSImm12MulBy4:$i))), 2)>;
def : Pat<(add (XLenVT GPR:$r), CSImm12MulBy8:$i),
(TH_ADDSL GPR:$r, (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy3XForm CSImm12MulBy8:$i))), 3)>;
-def : Pat<(mul (XLenVT GPR:$r), C3LeftShift:$i),
- (SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 1)),
- (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul (XLenVT GPR:$r), C5LeftShift:$i),
- (SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
- (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
- (SLLI (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 3)),
- (TrailingZeros C9LeftShift:$i))>;
-
def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
(SLLI (XLenVT (TH_ADDSL (XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)),
(XLenVT (TH_ADDSL GPR:$r, GPR:$r, 2)), 2)), 3)>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index ffe2b7e271208..8a0bbf6abd334 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -173,42 +173,6 @@ def BCLRIANDIMaskLow : SDNodeXForm<imm, [{
SDLoc(N), N->getValueType(0));
}]>;
-def C3LeftShift : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- return C > 3 && (C >> llvm::countr_zero(C)) == 3;
-}]>;
-
-def C5LeftShift : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- return C > 5 && (C >> llvm::countr_zero(C)) == 5;
-}]>;
-
-def C9LeftShift : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- return C > 9 && (C >> llvm::countr_zero(C)) == 9;
-}]>;
-
-// Constant of the form (3 << C) where C is less than 32.
-def C3LeftShiftUW : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- unsigned Shift = llvm::countr_zero(C);
- return 1 <= Shift && Shift < 32 && (C >> Shift) == 3;
-}]>;
-
-// Constant of the form (5 << C) where C is less than 32.
-def C5LeftShiftUW : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- unsigned Shift = llvm::countr_zero(C);
- return 1 <= Shift && Shift < 32 && (C >> Shift) == 5;
-}]>;
-
-// Constant of the form (9 << C) where C is less than 32.
-def C9LeftShiftUW : PatLeaf<(imm), [{
- uint64_t C = N->getZExtValue();
- unsigned Shift = llvm::countr_zero(C);
- return 1 <= Shift && Shift < 32 && (C >> Shift) == 9;
-}]>;
-
def CSImm12MulBy4 : PatLeaf<(imm), [{
if (!N->hasOneUse())
return false;
@@ -693,25 +657,6 @@ foreach i = {1,2,3} in {
(shxadd pat:$rs1, GPR:$rs2)>;
}
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
- (SH1ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
- (SH1ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
- (SH1ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
- (SH2ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
- (SH2ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
- (SH2ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
- (SH3ADD (XLenVT (SH1ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
- (SH3ADD (XLenVT (SH2ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-def : Pat<(add_like (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
- (SH3ADD (XLenVT (SH3ADD GPR:$rs1, GPR:$rs1)), GPR:$rs2)>;
-
def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy4:$i),
(SH2ADD (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy2XForm CSImm12MulBy4:$i))),
GPR:$r)>;
@@ -719,16 +664,6 @@ def : Pat<(add_like (XLenVT GPR:$r), CSImm12MulBy8:$i),
(SH3ADD (XLenVT (ADDI (XLenVT X0), (SimmShiftRightBy3XForm CSImm12MulBy8:$i))),
GPR:$r)>;
-def : Pat<(mul (XLenVT GPR:$r), C3LeftShift:$i),
- (SLLI (XLenVT (SH1ADD GPR:$r, GPR:$r)),
- (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul (XLenVT GPR:$r), C5LeftShift:$i),
- (SLLI (XLenVT (SH2ADD GPR:$r, GPR:$r)),
- (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul (XLenVT GPR:$r), C9LeftShift:$i),
- (SLLI (XLenVT (SH3ADD GPR:$r, GPR:$r)),
- (TrailingZeros C9LeftShift:$i))>;
-
} // Predicates = [HasStdExtZba]
let Predicates = [HasStdExtZba, IsRV64] in {
@@ -780,15 +715,6 @@ def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x3FFFFFFFC), (XLenVT GPR:$rs2
def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD_UW (XLenVT (SRLI GPR:$rs1, 3)), GPR:$rs2)>;
-def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
- (SH1ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C3LeftShiftUW:$i))),
- (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C3LeftShiftUW:$i))))>;
-def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C5LeftShiftUW:$i)),
- (SH2ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C5LeftShiftUW:$i))),
- (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C5LeftShiftUW:$i))))>;
-def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C9LeftShiftUW:$i)),
- (SH3ADD (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C9LeftShiftUW:$i))),
- (XLenVT (SLLI_UW GPR:$r, (TrailingZeros C9LeftShiftUW:$i))))>;
} // Predicates = [HasStdExtZba, IsRV64]
let Predicates = [HasStdExtZbcOrZbkc] in {
diff --git a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
index 8fb251a75bd14..e2f7be2e6d7f7 100644
--- a/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
+++ b/llvm/test/CodeGen/RISCV/addimm-mulimm.ll
@@ -600,8 +600,9 @@ define i64 @add_mul_combine_infinite_loop(i64 %x) {
; RV32IMB-NEXT: sh3add a1, a1, a2
; RV32IMB-NEXT: sh1add a0, a0, a0
; RV32IMB-NEXT: slli a2, a0, 3
-; RV32IMB-NEXT: addi a0, a2, 2047
-; RV32IMB-NEXT: addi a0, a0, 1
+; RV32IMB-NEXT: li a3, 1
+; RV32IMB-NEXT: slli a3, a3, 11
+; RV32IMB-NEXT: sh3add a0, a0, a3
; RV32IMB-NEXT: sltu a2, a0, a2
; RV32IMB-NEXT: add a1, a1, a2
; RV32IMB-NEXT: ret
@@ -610,8 +611,8 @@ define i64 @add_mul_combine_infinite_loop(i64 %x) {
; RV64IMB: # %bb.0:
; RV64IMB-NEXT: addi a0, a0, 86
; RV64IMB-NEXT: sh1add a0, a0, a0
-; RV64IMB-NEXT: li a1, -16
-; RV64IMB-NEXT: sh3add a0, a0, a1
+; RV64IMB-NEXT: slli a0, a0, 3
+; RV64IMB-NEXT: addi a0, a0, -16
; RV64IMB-NEXT: ret
%tmp0 = mul i64 %x, 24
%tmp1 = add i64 %tmp0, 2048
diff --git a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
index c3ae40124ba04..2db0d40b0ce52 100644
--- a/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zba.ll
@@ -646,8 +646,8 @@ define i64 @zext_mul12884901888(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul12884901888:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 12884901888
@@ -667,8 +667,8 @@ define i64 @zext_mul21474836480(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul21474836480:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 21474836480
@@ -688,8 +688,8 @@ define i64 @zext_mul38654705664(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul38654705664:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 38654705664
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 817e2b7d0bd99..5931e0982a4a6 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -865,8 +865,8 @@ define i64 @zext_mul12884901888(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul12884901888:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh1add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 12884901888
@@ -886,8 +886,8 @@ define i64 @zext_mul21474836480(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul21474836480:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh2add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 21474836480
@@ -907,8 +907,8 @@ define i64 @zext_mul38654705664(i32 signext %a) {
;
; RV64ZBA-LABEL: zext_mul38654705664:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: slli a0, a0, 32
+; RV64ZBA-NEXT: sh3add a0, a0, a0
; RV64ZBA-NEXT: ret
%b = zext i32 %a to i64
%c = mul i64 %b, 38654705664