[llvm] 7e1422c - [DAGCombiner] Fold step_vector with add/mul/shl

Jun Ma via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 15 03:06:59 PDT 2021


Author: Jun Ma
Date: 2021-04-15T18:06:35+08:00
New Revision: 7e1422c1e43023e4cd5fbb3305f8dbf2d626e87e

URL: https://github.com/llvm/llvm-project/commit/7e1422c1e43023e4cd5fbb3305f8dbf2d626e87e
DIFF: https://github.com/llvm/llvm-project/commit/7e1422c1e43023e4cd5fbb3305f8dbf2d626e87e.diff

LOG: [DAGCombiner] Fold step_vector with add/mul/shl

This patch implements some DAG combines for STEP_VECTOR:
add step_vector(C1), step_vector(C2) -> step_vector(C1+C2)
add (add X, step_vector(C1)), step_vector(C2) -> add X, step_vector(C1+C2)
mul step_vector(C1), C2 -> step_vector(C1*C2)
shl step_vector(C1), C2 -> step_vector(C1<<C2)

TestPlan: check-llvm

Differential Revision: https://reviews.llvm.org/D100088

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/sve-stepvector.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ba5cb777138a4..c7d619d35ff55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2503,6 +2503,31 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
   }
 
+  // Fold (add step_vector(c1), step_vector(c2)) to (step_vector(c1+c2)).
+  if (N0.getOpcode() == ISD::STEP_VECTOR &&
+      N1.getOpcode() == ISD::STEP_VECTOR) {
+    const APInt &C0 = N0->getConstantOperandAPInt(0);
+    const APInt &C1 = N1->getConstantOperandAPInt(0);
+    EVT SVT = N0.getOperand(0).getValueType();
+    SDValue NewStep = DAG.getConstant(C0 + C1, DL, SVT);
+    return DAG.getStepVector(DL, VT, NewStep);
+  }
+
+  // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
+  if ((N0.getOpcode() == ISD::ADD) &&
+      (N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
+      (N1.getOpcode() == ISD::STEP_VECTOR)) {
+    const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
+    const APInt &SV1 = N1->getConstantOperandAPInt(0);
+    EVT SVT = N1.getOperand(0).getValueType();
+    assert(N1.getOperand(0).getValueType() ==
+               N0.getOperand(1)->getOperand(0).getValueType() &&
+           "Different operand types of STEP_VECTOR.");
+    SDValue NewStep = DAG.getConstant(SV0 + SV1, DL, SVT);
+    SDValue SV = DAG.getStepVector(DL, VT, NewStep);
+    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
+  }
+
   return SDValue();
 }
 
@@ -3893,6 +3918,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
       return DAG.getVScale(SDLoc(N), VT, C0 * C1);
     }
 
+  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
+  APInt MulVal;
+  if (N0.getOpcode() == ISD::STEP_VECTOR)
+    if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
+      const APInt &C0 = N0.getConstantOperandAPInt(0);
+      EVT SVT = N0.getOperand(0).getValueType();
+      SDValue NewStep = DAG.getConstant(
+          C0 * MulVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
+      return DAG.getStepVector(SDLoc(N), VT, NewStep);
+    }
+
   // Fold ((mul x, 0/undef) -> 0,
   //       (mul x, 1) -> x) -> x)
   // -> and(x, mask)
@@ -8381,6 +8417,17 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
       return DAG.getVScale(SDLoc(N), VT, C0 << C1);
     }
 
+  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
+  APInt ShlVal;
+  if (N0.getOpcode() == ISD::STEP_VECTOR)
+    if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
+      const APInt &C0 = N0.getConstantOperandAPInt(0);
+      EVT SVT = N0.getOperand(0).getValueType();
+      SDValue NewStep = DAG.getConstant(
+          C0 << ShlVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
+      return DAG.getStepVector(SDLoc(N), VT, NewStep);
+    }
+
   return SDValue();
 }
 

diff  --git a/llvm/test/CodeGen/AArch64/sve-stepvector.ll b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
index 4d31e2626ec9e..d121fb46be446 100644
--- a/llvm/test/CodeGen/AArch64/sve-stepvector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-stepvector.ll
@@ -105,6 +105,59 @@ entry:
   ret <vscale x 8 x i8> %0
 }
 
+define <vscale x 8 x i8> @add_stepvector_nxv8i8() {
+; CHECK-LABEL: add_stepvector_nxv8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #0, #2
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %1 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %2 = add <vscale x 8 x i8> %0, %1
+  ret <vscale x 8 x i8> %2
+}
+
+define <vscale x 8 x i8> @add_stepvector_nxv8i8_1(<vscale x 8 x i8> %p) {
+; CHECK-LABEL: add_stepvector_nxv8i8_1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z1.h, #0, #2
+; CHECK-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
+entry:
+  %0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %1 = add <vscale x 8 x i8> %p, %0
+  %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %3 = add <vscale x 8 x i8> %1, %2
+  ret <vscale x 8 x i8> %3
+}
+
+define <vscale x 8 x i8> @mul_stepvector_nxv8i8() {
+; CHECK-LABEL: mul_stepvector_nxv8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #0, #2
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
+  %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
+  %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %3 = mul <vscale x 8 x i8> %2, %1
+  ret <vscale x 8 x i8> %3
+}
+
+define <vscale x 8 x i8> @shl_stepvector_nxv8i8() {
+; CHECK-LABEL: shl_stepvector_nxv8i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.h, #0, #4
+; CHECK-NEXT:    ret
+entry:
+  %0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
+  %1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
+  %2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
+  %3 = shl <vscale x 8 x i8> %2, %1
+  ret <vscale x 8 x i8> %3
+}
+
+
 declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
 declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
 declare <vscale x 8 x i16> @llvm.experimental.stepvector.nxv8i16()


        


More information about the llvm-commits mailing list