[llvm] 3304d51 - [RISCV] Add performMULCombine to perform strength-reduction
Philipp Tomsich via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 7 22:57:38 PST 2023
Author: Philipp Tomsich
Date: 2023-02-08T07:57:27+01:00
New Revision: 3304d51b676ea511feca28089cb60eba3873132e
URL: https://github.com/llvm/llvm-project/commit/3304d51b676ea511feca28089cb60eba3873132e
DIFF: https://github.com/llvm/llvm-project/commit/3304d51b676ea511feca28089cb60eba3873132e.diff
LOG: [RISCV] Add performMULCombine to perform strength-reduction
The RISC-V backend thus far does not perform strength-reduction on
multiplications; instead, a long (but still incomplete) list of hand-written
3-instruction patterns utilizes the shift-and-add instructions from Zba and
XTheadBa for strength-reduction.
This adds the logic to perform strength-reduction through the DAG
combine for ISD::MUL. Initially, we wire this up for XTheadBa only,
until this has had some time to settle and get real-world test
exposure.
The following strength-reduction strategies are currently supported:
- XTheadBa
  - C = (n + 1)           // th.addsl
  - C = (n + 1)k          // th.addsl, slli
  - C = (n + 1)(m + 1)    // th.addsl, th.addsl
  - C = (n + 1)(m + 1)k   // th.addsl, th.addsl, slli
  - C = ((n + 1)m + 1)    // th.addsl, th.addsl
  - C = ((n + 1)m + 1)k   // th.addsl, th.addsl, slli
- base ISA
  - C having 2 set bits   // slli, slli, add
                          // (possibly slli, th.addsl)
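To illustrate: x * 11 uses the C = ((n + 1)m + 1) form, since 11 = (4 + 1) * 2 + 1;
a first th.addsl computes t = x + (x << 2) = 5x and a second computes
x + (t << 1) = 11x, matching the hand-written pattern for 11 that this change
removes. The decomposition search itself can be sketched in standalone C++
(illustrative only, not the combine's actual implementation; __builtin_ctzll
is a GCC/Clang builtin):

  #include <cstdio>

  int main() {
    // Constants taken from the removed patterns: 6 = 3 * 2, 11 = 5 * 2 + 1,
    // 25 = 5 * 5, 200 = 5 * 5 * 8. A constant may match more than one form;
    // the combine takes the first hit.
    for (unsigned long long C : {6ULL, 11ULL, 25ULL, 200ULL}) {
      unsigned K = __builtin_ctzll(C);  // trailing zeros -> final slli
      unsigned long long S = C >> K;    // odd part, covered by th.addsl
      for (unsigned long long A : {3ULL, 5ULL, 9ULL}) {
        if (S == A)                     // C = (n + 1)k
          printf("%llu = %llu * 2^%u\n", C, A, K);
        for (unsigned long long B : {3ULL, 5ULL, 9ULL})
          if (S == A * B)               // C = (n + 1)(m + 1)k
            printf("%llu = %llu * %llu * 2^%u\n", C, A, B, K);
        for (unsigned long long M : {2ULL, 4ULL, 8ULL})
          if (S == A * M + 1)           // C = ((n + 1)m + 1)k
            printf("%llu = (%llu * %llu + 1) * 2^%u\n", C, A, M, K);
      }
    }
  }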
Even though the slli+slli+add sequence would be supported without XTheadBa
as well, it is currently gated on XTheadBa to avoid having to update a large
number of test cases (i.e., anything that multiplies by a constant with only
2 bits set) in this commit.
With the strength-reduction now being performed in performMULCombine, we
drop the (now redundant) patterns from RISCVInstrInfoXTHead.td.
Depends on D143029
Differential Revision: https://reviews.llvm.org/D143394
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2eb50800669..eb37679b4d99 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1011,7 +1011,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setJumpIsExpensive();
setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
- ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+ ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT, ISD::MUL});
if (Subtarget.is64Bit())
setTargetDAGCombine(ISD::SRA);
@@ -8569,6 +8569,134 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
}
+static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ SDLoc DL(N);
+ const MVT XLenVT = Subtarget.getXLenVT();
+ const EVT VT = N->getValueType(0);
+
+  // A single MUL is usually smaller than any alternative sequence for a
+  // legal type, so don't expand it when optimizing for minimum size.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (DAG.getMachineFunction().getFunction().hasMinSize() &&
+ TLI.isOperationLegal(ISD::MUL, VT))
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ ConstantSDNode *ConstOp = dyn_cast<ConstantSDNode>(N1);
+ // Any optimization requires a constant RHS.
+ if (!ConstOp)
+ return SDValue();
+
+ const APInt &C = ConstOp->getAPIntValue();
+ // A multiply-by-pow2 will be reduced to a shift by the
+ // architecture-independent code.
+ if (C.isPowerOf2())
+ return SDValue();
+
+  // The optimizations below only work for non-negative constants.
+ if (!C.isNonNegative())
+ return SDValue();
+
+ auto Shl = [&](SDValue Value, unsigned ShiftAmount) {
+ if (!ShiftAmount)
+ return Value;
+
+ SDValue ShiftAmountConst = DAG.getConstant(ShiftAmount, DL, XLenVT);
+ return DAG.getNode(ISD::SHL, DL, Value.getValueType(), Value,
+ ShiftAmountConst);
+ };
+ auto Add = [&](SDValue Addend1, SDValue Addend2) {
+ return DAG.getNode(ISD::ADD, DL, Addend1.getValueType(), Addend1, Addend2);
+ };
+
+ if (Subtarget.hasVendorXTHeadBa()) {
+    // Try to simplify the multiplication to at most 3 instructions using
+    // shift-and-add (e.g. 2x shift-and-add and 1x shift).
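+    // A single th.addsl rd, rs1, rs2, imm computes rs1 + (rs2 << imm) with
+    // imm in [0, 3], so with rs1 == rs2 it can multiply by 3, 5, or 9.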
+
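+    // Find a divisor of C of the form (2^i + 1), i.e. 9, 5, or 3 (tried in
+    // that order); on success, the divisor is returned in N along with the
+    // quotient.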
+ auto isDivisibleByShiftedAddConst = [&](APInt C, APInt &N,
+ APInt &Quotient) {
+ unsigned BitWidth = C.getBitWidth();
+ for (unsigned i = 3; i >= 1; --i) {
+ APInt X(BitWidth, (1 << i) + 1);
+ APInt Remainder;
+ APInt::sdivrem(C, X, Quotient, Remainder);
+ if (Remainder == 0) {
+ N = X;
+ return true;
+ }
+ }
+ return false;
+ };
+ auto isShiftedAddConst = [&](APInt C, APInt &N) {
+ APInt Quotient;
+ return isDivisibleByShiftedAddConst(C, N, Quotient) && Quotient == 1;
+ };
+ auto isSmallShiftAmount = [](APInt C) {
+ return (C == 2) || (C == 4) || (C == 8);
+ };
+
+ auto ShiftAndAdd = [&](SDValue Value, unsigned ShiftAmount,
+ SDValue Addend) {
+ return Add(Shl(Value, ShiftAmount), Addend);
+ };
+ auto AnyExt = [&](SDValue Value) {
+ return DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Value);
+ };
+ auto Trunc = [&](SDValue Value) {
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, Value);
+ };
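+    // The replacement sequence is built in XLenVT and truncated back to VT
+    // at the end, so that (e.g.) i32 multiplies on RV64 take the same path.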
+
+ unsigned TrailingZeroes = C.countTrailingZeros();
+ const APInt ShiftedC = C.ashr(TrailingZeroes);
+ const APInt ShiftedCMinusOne = ShiftedC - 1;
+
+    // The comments below use the following notation:
+    //   n, m .. a shift amount for a shift-and-add instruction
+    //           (i.e. in { 2, 4, 8 })
+    //   k    .. a power-of-2 equivalent to shifting by TrailingZeroes bits
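+    //   e.g. C = 200 = (4 + 1) * (4 + 1) * 8 has n = m = 4 and k = 8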
+
+ APInt ShiftAmt1;
+ APInt ShiftAmt2;
+ APInt Quotient;
+
+ // C = (m + 1) * k
+ if (isShiftedAddConst(ShiftedC, ShiftAmt1)) {
+ SDValue Op0 = AnyExt(N0);
+ SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+ return Trunc(Shl(Result, TrailingZeroes));
+ }
+ // C = (m + 1) * (n + 1) * k
+ if (isDivisibleByShiftedAddConst(ShiftedC, ShiftAmt1, Quotient) &&
+ isShiftedAddConst(Quotient, ShiftAmt2)) {
+ SDValue Op0 = AnyExt(N0);
+ SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+ Result = ShiftAndAdd(Result, ShiftAmt2.logBase2(), Result);
+ return Trunc(Shl(Result, TrailingZeroes));
+ }
+ // C = ((m + 1) * n + 1) * k
+ if (isDivisibleByShiftedAddConst(ShiftedCMinusOne, ShiftAmt1, ShiftAmt2) &&
+ isSmallShiftAmount(ShiftAmt2)) {
+ SDValue Op0 = AnyExt(N0);
+ SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      // ShiftAmt2 received the quotient n from the divisibility check above.
+      Result = ShiftAndAdd(Result, ShiftAmt2.logBase2(), Op0);
+ return Trunc(Shl(Result, TrailingZeroes));
+ }
+
+ // C has 2 bits set: synthesize using 2 shifts and 1 add (which may
+ // see one of the shifts merged into a shift-and-add, if feasible)
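+    // (e.g. C = 33 = 32 + 1, which none of the forms above match, becomes
+    // (x << 5) + x)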
+ if (C.countPopulation() == 2) {
+      // Use getOneBitSet so wide constants don't overflow the 32-bit '1'.
+      APInt HighBit = APInt::getOneBitSet(C.getBitWidth(), C.logBase2());
+ APInt LowBit = C - HighBit;
+ return Add(Shl(N0, HighBit.logBase2()), Shl(N0, LowBit.logBase2()));
+ }
+ }
+
+ return SDValue();
+}
+
static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
@@ -10218,6 +10346,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performADDCombine(N, DAG, Subtarget);
case ISD::SUB:
return performSUBCombine(N, DAG, Subtarget);
+ case ISD::MUL:
+ return performMULCombine(N, DAG, Subtarget);
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
case ISD::OR:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index c7da1c557d1a..9cf61ffa00e8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -161,67 +161,6 @@ def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
(TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>;
def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
(TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
-
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
- (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 3)>;
-
-def : Pat<(add GPR:$r, CSImm12MulBy4:$i),
- (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy2XForm CSImm12MulBy4:$i)), 2)>;
-def : Pat<(add GPR:$r, CSImm12MulBy8:$i),
- (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy3XForm CSImm12MulBy8:$i)), 3)>;
-
-def : Pat<(mul GPR:$r, C3LeftShift:$i),
- (SLLI (TH_ADDSL GPR:$r, GPR:$r, 1),
- (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C5LeftShift:$i),
- (SLLI (TH_ADDSL GPR:$r, GPR:$r, 2),
- (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C9LeftShift:$i),
- (SLLI (TH_ADDSL GPR:$r, GPR:$r, 3),
- (TrailingZeros C9LeftShift:$i))>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 1), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
- (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
- (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
- (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
- (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
- (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
- (SLLI (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2),
- (TH_ADDSL GPR:$r, GPR:$r, 2), 2), 3)>;
} // Predicates = [HasVendorXTHeadBa]
defm PseudoTHVdotVMAQA : VPseudoVMAQA_VV_VX;