[llvm] 3304d51 - [RISCV] Add performMULcombine to perform strength-reduction

Philipp Tomsich via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 7 22:57:38 PST 2023


Author: Philipp Tomsich
Date: 2023-02-08T07:57:27+01:00
New Revision: 3304d51b676ea511feca28089cb60eba3873132e

URL: https://github.com/llvm/llvm-project/commit/3304d51b676ea511feca28089cb60eba3873132e
DIFF: https://github.com/llvm/llvm-project/commit/3304d51b676ea511feca28089cb60eba3873132e.diff

LOG: [RISCV] Add performMULcombine to perform strength-reduction

The RISC-V backend thus far does not provide strength-reduction, which
has led to a long (but not complete) list of 3-instruction patterns
being added to utilize the shift-and-add instructions from Zba and
XTHeadBa for strength-reduction.

This adds the logic to perform strength-reduction through the DAG
combine for ISD::MUL.  Initially, we wire this up for XTHeadBa only,
until this has had some time to settle and get real-world test
exposure.

The following strength-reduction strategies are currently supported:
  - XTheadBa
    - C = (n + 1)           // th.addsl
    - C = (n + 1)k          // th.addsl, slli
    - C = (n + 1)(m + 1)    // th.addsl, th.addsl
    - C = (n + 1)(m + 1)k   // th.addsl, th.addsl, slli
    - C = ((n + 1)m + 1)    // th.addsl, th.addsl
    - C = ((n + 1)m + 1)k   // th.addsl, th.addsl, slli
  - base ISA
    - C being 2 set-bits    // slli, slli, add
			       (possibly slli, th.addsl)

Even though the slli+slli+add sequence would be supported without
XTHeadBa, it is currently gated to avoid having to update a large
number of test cases (i.e., anything that has a multiplication with a
constant where only 2 bits are set) in this commit.

With the strength reduction now being performed in performMULCombine,
we drop the (now redundant) patterns from RISCVInstrInfoXTHead.td.

Depends on D143029

Differential Revision: https://reviews.llvm.org/D143394

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2eb50800669..eb37679b4d99 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1011,7 +1011,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setJumpIsExpensive();
 
   setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT, ISD::MUL});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -8569,6 +8569,160 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
 }
 
+/// Strength-reduce (mul X, C) for a non-negative, non-power-of-2 constant C
+/// into a short sequence of shifts and adds. With XTHeadBa, each
+/// "(X << s) + Y" step with s in {1, 2, 3} maps onto th.addsl.
+///
+/// Returns the replacement value, or an empty SDValue if no rewrite applies.
+static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  SDLoc DL(N);
+  const MVT XLenVT = Subtarget.getXLenVT();
+  const EVT VT = N->getValueType(0);
+
+  // A MUL is usually smaller than any alternative sequence for a legal type.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (DAG.getMachineFunction().getFunction().hasMinSize() &&
+      TLI.isOperationLegal(ISD::MUL, VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  ConstantSDNode *ConstOp = dyn_cast<ConstantSDNode>(N1);
+  // Any optimization requires a constant RHS.
+  if (!ConstOp)
+    return SDValue();
+
+  const APInt &C = ConstOp->getAPIntValue();
+  // A multiply-by-pow2 will be reduced to a shift by the
+  // architecture-independent code.
+  if (C.isPowerOf2())
+    return SDValue();
+
+  // The below optimizations only work for non-negative constants
+  if (!C.isNonNegative())
+    return SDValue();
+
+  auto Shl = [&](SDValue Value, unsigned ShiftAmount) {
+    if (!ShiftAmount)
+      return Value;
+
+    SDValue ShiftAmountConst = DAG.getConstant(ShiftAmount, DL, XLenVT);
+    return DAG.getNode(ISD::SHL, DL, Value.getValueType(), Value,
+                       ShiftAmountConst);
+  };
+  auto Add = [&](SDValue Addend1, SDValue Addend2) {
+    return DAG.getNode(ISD::ADD, DL, Addend1.getValueType(), Addend1, Addend2);
+  };
+
+  if (Subtarget.hasVendorXTHeadBa()) {
+    // We try to simplify using shift-and-add instructions into up to
+    // 3 instructions (e.g. 2x shift-and-add and 1x shift).
+
+    // If Val is divisible by (2^i + 1) for some i in {3, 2, 1} (tried in
+    // that order), store that divisor in Divisor and Val / Divisor in
+    // Quotient and return true. On failure neither out-parameter is
+    // written, so callers never observe a leftover partial result.
+    auto isDivisibleByShiftedAddConst = [](const APInt &Val, APInt &Divisor,
+                                           APInt &Quotient) {
+      unsigned BitWidth = Val.getBitWidth();
+      for (unsigned i = 3; i >= 1; --i) {
+        APInt X(BitWidth, (1 << i) + 1);
+        APInt Q;
+        APInt Remainder;
+        APInt::sdivrem(Val, X, Q, Remainder);
+        if (Remainder == 0) {
+          Divisor = X;
+          Quotient = Q;
+          return true;
+        }
+      }
+      return false;
+    };
+    // True iff Val is exactly 3, 5 or 9; Divisor then holds that value.
+    auto isShiftedAddConst = [&](const APInt &Val, APInt &Divisor) {
+      APInt Quotient;
+      return isDivisibleByShiftedAddConst(Val, Divisor, Quotient) &&
+             Quotient == 1;
+    };
+    // True iff Val is 2, 4 or 8 (a shift-and-add's shift by 1, 2 or 3).
+    auto isSmallShiftAmount = [](const APInt &Val) {
+      return (Val == 2) || (Val == 4) || (Val == 8);
+    };
+
+    auto ShiftAndAdd = [&](SDValue Value, unsigned ShiftAmount,
+                           SDValue Addend) {
+      return Add(Shl(Value, ShiftAmount), Addend);
+    };
+    auto AnyExt = [&](SDValue Value) {
+      return DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Value);
+    };
+    auto Trunc = [&](SDValue Value) {
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Value);
+    };
+
+    unsigned TrailingZeroes = C.countTrailingZeros();
+    const APInt ShiftedC = C.ashr(TrailingZeroes);
+    const APInt ShiftedCMinusOne = ShiftedC - 1;
+
+    // The comments below use the following notation:
+    // n, m  .. a shift-amount for a shift-and-add instruction
+    //          (i.e. in { 2, 4, 8 })
+    // k     .. a power-of-2 that is equivalent to shifting by
+    //          TrailingZeroes bits
+    // i, j  .. a power-of-2
+
+    // The shift-and-add sequences below are built in XLenVT (ANY_EXTEND,
+    // compute, TRUNCATE), which is only valid if VT is no wider than XLen:
+    // for e.g. i128 on RV64 or i64 on RV32 the ANY_EXTEND would have to
+    // narrow its operand, which SelectionDAG::getNode rejects.
+    if (VT.bitsLE(XLenVT)) {
+      // Divisors of the form (m + 1) / (n + 1), and a division's quotient.
+      APInt Divisor1;
+      APInt Divisor2;
+      APInt Quotient;
+
+      // C = (m + 1) * k
+      if (isShiftedAddConst(ShiftedC, Divisor1)) {
+        SDValue Op0 = AnyExt(N0);
+        SDValue Result = ShiftAndAdd(Op0, Divisor1.logBase2(), Op0);
+        return Trunc(Shl(Result, TrailingZeroes));
+      }
+      // C = (m + 1) * (n + 1) * k
+      if (isDivisibleByShiftedAddConst(ShiftedC, Divisor1, Quotient) &&
+          isShiftedAddConst(Quotient, Divisor2)) {
+        SDValue Op0 = AnyExt(N0);
+        SDValue Result = ShiftAndAdd(Op0, Divisor1.logBase2(), Op0);
+        Result = ShiftAndAdd(Result, Divisor2.logBase2(), Result);
+        return Trunc(Shl(Result, TrailingZeroes));
+      }
+      // C = ((m + 1) * n + 1) * k
+      // Dividing (C / k - 1) by (m + 1) leaves n itself as the quotient, so
+      // the second shift-and-add must shift by the Quotient produced by
+      // *this* division (not one left over from the pattern above).
+      if (isDivisibleByShiftedAddConst(ShiftedCMinusOne, Divisor1, Quotient) &&
+          isSmallShiftAmount(Quotient)) {
+        SDValue Op0 = AnyExt(N0);
+        SDValue Result = ShiftAndAdd(Op0, Divisor1.logBase2(), Op0);
+        Result = ShiftAndAdd(Result, Quotient.logBase2(), Op0);
+        return Trunc(Shl(Result, TrailingZeroes));
+      }
+    }
+
+    // C has 2 bits set: synthesize using 2 shifts and 1 add (which may
+    // see one of the shifts merged into a shift-and-add, if feasible).
+    // The shift amounts are the two bit positions; do not materialize the
+    // high bit with an `int` shift, which overflows for positions >= 31.
+    if (C.countPopulation() == 2) {
+      unsigned HighBitPos = C.logBase2();
+      unsigned LowBitPos = TrailingZeroes;
+      return Add(Shl(N0, HighBitPos), Shl(N0, LowBitPos));
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -10218,6 +10346,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return performADDCombine(N, DAG, Subtarget);
   case ISD::SUB:
     return performSUBCombine(N, DAG, Subtarget);
+  case ISD::MUL:
+    return performMULCombine(N, DAG, Subtarget);
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR:

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
index c7da1c557d1a..9cf61ffa00e8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -161,67 +161,6 @@ def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
           (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>;
 def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
           (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
-
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 3)>;
-
-def : Pat<(add GPR:$r, CSImm12MulBy4:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy2XForm CSImm12MulBy4:$i)), 2)>;
-def : Pat<(add GPR:$r, CSImm12MulBy8:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy3XForm CSImm12MulBy8:$i)), 3)>;
-
-def : Pat<(mul GPR:$r, C3LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 1),
-                (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C5LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 2),
-                (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C9LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 3),
-                (TrailingZeros C9LeftShift:$i))>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 1), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-	  (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-	  (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-	  (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-	  (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-	  (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
-	  (SLLI (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2),
-		       (TH_ADDSL GPR:$r, GPR:$r, 2), 2), 3)>;
 } // Predicates = [HasVendorXTHeadBa]
 
 defm PseudoTHVdotVMAQA      : VPseudoVMAQA_VV_VX;


        


More information about the llvm-commits mailing list