[llvm] 974f00a - [AArch64][SVE] Fold constant multiply of element count

Cullen Rhodes via llvm-commits <llvm-commits at lists.llvm.org>
Fri Dec 20 03:58:25 PST 2019


Author: Cullen Rhodes
Date: 2019-12-20T11:58:00Z
New Revision: 974f00a4369371fae9d25477753c0f68f331e05a

URL: https://github.com/llvm/llvm-project/commit/974f00a4369371fae9d25477753c0f68f331e05a
DIFF: https://github.com/llvm/llvm-project/commit/974f00a4369371fae9d25477753c0f68f331e05a.diff

LOG: [AArch64][SVE] Fold constant multiply of element count

Summary:
E.g.

  %0 = tail call i64 @llvm.aarch64.sve.cntw(i32 31)
  %mul = mul i64 %0, <const>

Should emit:

  cntw    x0, all, mul #<const>

For <const> in the range 1-16.

Patch by Kerry McLaughlin

Reviewers: sdesmalen, huntergr, dancgr, rengolin, efriedma

Reviewed By: sdesmalen

Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71014
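For illustration, a minimal C++ sketch of source that produces IR of this
shape, assuming the ACLE intrinsic svcntw() from arm_sve.h (the function and
its name are illustrative, not part of this commit):

  #include <arm_sve.h>
  #include <cstdint>

  uint64_t three_vectors_of_words() {
    // svcntw() lowers to @llvm.aarch64.sve.cntw(i32 31), the "all" pattern;
    // with this fold the multiply becomes a single "cntw x0, all, mul #3".
    return svcntw() * 3;
  }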

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index e875844ed707..ef06993d618d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -169,6 +169,28 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
     return SelectSVELogicalImm(N, VT, Imm);
   }
 
+  // Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
+  template<signed Min, signed Max, signed Scale, bool Shift>
+  bool SelectCntImm(SDValue N, SDValue &Imm) {
+    if (!isa<ConstantSDNode>(N))
+      return false;
+
+    int64_t MulImm = cast<ConstantSDNode>(N)->getSExtValue();
+    if (Shift)
+      MulImm = 1 << MulImm;
+
+    if ((MulImm % std::abs(Scale)) != 0)
+      return false;
+
+    MulImm /= Scale;
+    if ((MulImm >= Min) && (MulImm <= Max)) {
+      Imm = CurDAG->getTargetConstant(MulImm, SDLoc(N), MVT::i32);
+      return true;
+    }
+
+    return false;
+  }
+
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
   /// between 1 and 4 elements. If it contains a single element that is returned

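As a standalone illustration of the check SelectCntImm performs, here is a
C++ sketch of the same logic outside of SelectionDAG (the driver values and
the free function are made up, not LLVM code):

  #include <cstdint>
  #include <cstdlib>
  #include <iostream>

  template <int Min, int Max, int Scale, bool Shift>
  bool selectCntImm(int64_t MulImm, int32_t &Imm) {
    if (Shift)
      MulImm = int64_t(1) << MulImm; // a shift amount encodes mul by 2^N
    if ((MulImm % std::abs(Scale)) != 0)
      return false;                  // must be a whole multiple of Scale
    MulImm /= Scale;
    if (MulImm < Min || MulImm > Max)
      return false;                  // outside the encodable mul #Min..#Max
    Imm = int32_t(MulImm);
    return true;
  }

  int main() {
    int32_t Imm;
    std::cout << selectCntImm<1, 16, 1, false>(3, Imm) << '\n';  // 1 (mul #3)
    std::cout << selectCntImm<1, 16, 1, false>(17, Imm) << '\n'; // 0 (> 16)
    std::cout << selectCntImm<1, 16, 1, true>(2, Imm) << '\n';   // 1 (shl 2 -> mul #4)
  }
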
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7f9a7bd97467..a3dd2e65a121 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9541,6 +9541,19 @@ AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
 }
 
+static bool IsSVECntIntrinsic(SDValue S) {
+  switch(getIntrinsicID(S.getNode())) {
+  default:
+    break;
+  case Intrinsic::aarch64_sve_cntb:
+  case Intrinsic::aarch64_sve_cnth:
+  case Intrinsic::aarch64_sve_cntw:
+  case Intrinsic::aarch64_sve_cntd:
+    return true;
+  }
+  return false;
+}
+
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -9551,9 +9564,18 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   if (!isa<ConstantSDNode>(N->getOperand(1)))
     return SDValue();
 
+  SDValue N0 = N->getOperand(0);
   ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(1));
   const APInt &ConstValue = C->getAPIntValue();
 
+  // Allow the scaling to be folded into the `cnt` instruction by preventing
+  // the scaling from being obscured here. This makes it easier to pattern
+  // match.
+  if (IsSVECntIntrinsic(N0) ||
+     (N0->getOpcode() == ISD::TRUNCATE &&
+      (IsSVECntIntrinsic(N0->getOperand(0)))))
+       if (ConstValue.sge(1) && ConstValue.sle(16))
+         return SDValue();
+
   // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
   // future CPUs have a cheaper MADD instruction, this may need to be
@@ -9564,7 +9586,6 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
   // e.g. 6=3*2=(2+1)*2.
   // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
   // which equals to (1+2)*16-(1+2).
-  SDValue N0 = N->getOperand(0);
   // TrailingZeroes is used to test if the mul can be lowered to
   // shift+add+shift.
   unsigned TrailingZeroes = ConstValue.countTrailingZeros();

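The bail-out added above exists because the shift+add rewrite further down in
performMulCombine would otherwise hide the constant from the cnt patterns. A
small C++ sketch of the transformation it avoids here (illustrative names,
not LLVM APIs):

  #include <cassert>
  #include <cstdint>

  // mul X, (2^K + 1)  ==>  (X << K) + X, the shape the combine produces.
  uint64_t mulAsShiftAdd(uint64_t X, unsigned K) { return (X << K) + X; }

  int main() {
    uint64_t CntW = 8; // stand-in for a cntw result when vscale gives 8 words
    assert(CntW * 3 == mulAsShiftAdd(CntW, 1)); // 3 = 2^1 + 1
    assert(CntW * 5 == mulAsShiftAdd(CntW, 2)); // 5 = 2^2 + 1
    // Leaving constants 1-16 as a plain mul keeps them visible to the
    // TableGen patterns that fold them into "cnt ..., mul #N".
  }
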
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 0a3df4f2b71d..764ff99a1dd0 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -244,6 +244,10 @@ def sve_incdec_imm : Operand<i32>, TImmLeaf<i32, [{
   let DecoderMethod = "DecodeSVEIncDecImm";
 }
 
+// This allows i32 immediate extraction from i64-based arithmetic.
+def sve_cnt_mul_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, false>">;
+def sve_cnt_shl_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, true>">;
+
 //===----------------------------------------------------------------------===//
 // SVE PTrue - These are used extensively throughout the pattern matching so
 //             it's important we define them first.
@@ -635,6 +639,12 @@ multiclass sve_int_count<bits<3> opc, string asm, SDPatternOperator op> {
   def : InstAlias<asm # "\t$Rd",
                   (!cast<Instruction>(NAME) GPR64:$Rd, 0b11111, 1), 2>;
 
+  def : Pat<(i64 (mul (op sve_pred_enum:$pattern), (sve_cnt_mul_imm i32:$imm))),
+            (!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
+
+  def : Pat<(i64 (shl (op sve_pred_enum:$pattern), (i64 (sve_cnt_shl_imm i32:$imm)))),
+            (!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
+
   def : Pat<(i64 (op sve_pred_enum:$pattern)),
             (!cast<Instruction>(NAME) sve_pred_enum:$pattern, 1)>;
 }

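The second pattern above handles multiplies that reach instruction selection
as shifts. A quick C++ sanity check of the equivalence it relies on (a
sketch, not part of the commit):

  #include <cassert>
  #include <cstdint>

  int main() {
    uint64_t Cnt = 4; // stand-in for a cnt result
    for (unsigned Amt = 0; Amt <= 4; ++Amt)
      assert((Cnt << Amt) == Cnt * (uint64_t(1) << Amt)); // shl == mul by 2^Amt
    // So IR written as "shl %cnt, 4" is also folded to "cnt ..., mul #16".
  }
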
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
index a3fd4faf196f..b37e3d8b8c82 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll
@@ -12,6 +12,24 @@ define i64 @cntb() {
   ret i64 %out
 }
 
+define i64 @cntb_mul3() {
+; CHECK-LABEL: cntb_mul3:
+; CHECK: cntb x0, vl6, mul #3
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntb(i32 6)
+  %out = mul i64 %cnt, 3
+  ret i64 %out
+}
+
+define i64 @cntb_mul4() {
+; CHECK-LABEL: cntb_mul4:
+; CHECK: cntb x0, vl8, mul #4
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntb(i32 8)
+  %out = mul i64 %cnt, 4
+  ret i64 %out
+}
+
 ;
 ; CNTH
 ;
@@ -24,6 +42,24 @@ define i64 @cnth() {
   ret i64 %out
 }
 
+define i64 @cnth_mul5() {
+; CHECK-LABEL: cnth_mul5:
+; CHECK: cnth x0, vl7, mul #5
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cnth(i32 7)
+  %out = mul i64 %cnt, 5
+  ret i64 %out
+}
+
+define i64 @cnth_mul8() {
+; CHECK-LABEL: cnth_mul8:
+; CHECK: cnth x0, vl5, mul #8
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cnth(i32 5)
+  %out = mul i64 %cnt, 8
+  ret i64 %out
+}
+
 ;
 ; CNTW
 ;
@@ -36,6 +72,24 @@ define i64 @cntw() {
   ret i64 %out
 }
 
+define i64 @cntw_mul11() {
+; CHECK-LABEL: cntw_mul11:
+; CHECK: cntw x0, vl8, mul #11
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntw(i32 8)
+  %out = mul i64 %cnt, 11
+  ret i64 %out
+}
+
+define i64 @cntw_mul2() {
+; CHECK-LABEL: cntw_mul2:
+; CHECK: cntw x0, vl6, mul #2
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntw(i32 6)
+  %out = mul i64 %cnt, 2
+  ret i64 %out
+}
+
 ;
 ; CNTD
 ;
@@ -48,6 +102,24 @@ define i64 @cntd() {
   ret i64 %out
 }
 
+define i64 @cntd_mul15() {
+; CHECK-LABEL: cntd_mul15:
+; CHECK: cntd x0, vl16, mul #15
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntd(i32 9)
+  %out = mul i64 %cnt, 15
+  ret i64 %out
+}
+
+define i64 @cntd_mul16() {
+; CHECK-LABEL: cntd_mul16:
+; CHECK: cntd x0, vl32, mul #16
+; CHECK-NEXT: ret
+  %cnt = call i64 @llvm.aarch64.sve.cntd(i32 10)
+  %out = mul i64 %cnt, 16
+  ret i64 %out
+}
+
 ;
 ; CNTP
 ;

