[llvm] 62c4ac7 - [RISCV] Optimize splats of extracted vector elements

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 8 02:46:01 PST 2022


Author: Fraser Cormack
Date: 2022-02-08T10:35:25Z
New Revision: 62c4ac764bc0160ac4eed70a5becd4e015bcb65b

URL: https://github.com/llvm/llvm-project/commit/62c4ac764bc0160ac4eed70a5becd4e015bcb65b
DIFF: https://github.com/llvm/llvm-project/commit/62c4ac764bc0160ac4eed70a5becd4e015bcb65b.diff

LOG: [RISCV] Optimize splats of extracted vector elements

This patch adds an optimization to splat-like operations where the
splatted value is extracted from an identically-sized vector. On RVV we
can splat that value directly via vrgather.vx/vrgather.vi without first
dropping it to a scalar register.
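
For illustration, the scalable-vector pattern in question is the one in
the splat_idx_nxv4i32 test updated below (abridged here): an element
extracted at a runtime index and splatted across an identically-sized
result. With this patch the splat selects to a single vrgather.vx (or
vrgather.vi when the index is constant) instead of a
vslidedown/vmv.x.s/vmv.v.x sequence:

  define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
    ; Extract an element at an index known only at runtime...
    %x = extractelement <vscale x 4 x i32> %v, i64 %idx
    ; ...then splat it across the whole result vector.
    %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
    %splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
    ret <vscale x 4 x i32> %splat
  }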

We already have a similar VECTOR_SHUFFLE-specific optimization, but it
only works on fixed-length vector types and only when the splat lane is
a constant. This patch extends the optimization to scalable-vector
types and to unknown extract indices.

The optimization is performed during BUILD_VECTOR lowering for
fixed-length vectors and during a new DAGCombine on SPLAT_VECTOR for
scalable vectors.
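
For the fixed-length path, a representative case is the splat_idx_v4i32
test from the BUILD_VECTOR tests below (abridged); the splat of the
extracted element is matched during lowering and likewise selects to
vrgather.vx:

  define <4 x i32> @splat_idx_v4i32(<4 x i32> %v, i64 %idx) {
    ; The same splat-of-extract pattern, but on a fixed-length vector;
    ; this reaches matchSplatAsGather via the getSplatValue check in
    ; lowerBUILD_VECTOR.
    %x = extractelement <4 x i32> %v, i64 %idx
    %ins = insertelement <4 x i32> poison, i32 %x, i32 0
    %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
    ret <4 x i32> %splat
  }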

Reviewed By: craig.topper, khchen

Differential Revision: https://reviews.llvm.org/D118456

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
    llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7dd56d4bb04b1..dbbb3202dda3b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1086,6 +1086,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::SRL);
     setTargetDAGCombine(ISD::SHL);
     setTargetDAGCombine(ISD::STORE);
+    setTargetDAGCombine(ISD::SPLAT_VECTOR);
   }
 
   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2000,6 +2001,40 @@ static Optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
   return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend};
 }
 
+// Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT
+// and lower it as a VRGATHER_VX_VL from the source vector.
+static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
+                                  SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget) {
+  if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    return SDValue();
+  SDValue Vec = SplatVal.getOperand(0);
+  // Only perform this optimization on vectors of the same size for simplicity.
+  if (Vec.getValueType() != VT)
+    return SDValue();
+  SDValue Idx = SplatVal.getOperand(1);
+  // The index must be a legal type.
+  if (Idx.getValueType() != Subtarget.getXLenVT())
+    return SDValue();
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+    Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
+  }
+
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
+                               Idx, Mask, VL);
+
+  if (!VT.isFixedLengthVector())
+    return Gather;
+
+  return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+}
+
 static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -2123,6 +2158,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   }
 
   if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
+    if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
+      return Gather;
     unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
                                         : RISCVISD::VMV_V_X_VL;
     Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL);
@@ -8260,6 +8297,16 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
+  case ISD::SPLAT_VECTOR: {
+    EVT VT = N->getValueType(0);
+    // Only perform this combine on legal MVT types.
+    if (!isTypeLegal(VT))
+      break;
+    if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N,
+                                         DAG, Subtarget))
+      return Gather;
+    break;
+  }
   }
 
   return SDValue();

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
index 9db7b3c39e8b4..80abc47a23bb7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
@@ -217,11 +217,9 @@ define <4 x half> @splat_c3_v4f16(<4 x half> %v) {
 define <4 x half> @splat_idx_v4f16(<4 x half> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v4f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <4 x half> %v, i64 %idx
   %ins = insertelement <4 x half> poison, half %x, i32 0
@@ -270,11 +268,9 @@ define <8 x float> @splat_idx_v8f32(<8 x float> %v, i64 %idx) {
 ;
 ; LMULMAX2-LABEL: splat_idx_v8f32:
 ; LMULMAX2:       # %bb.0:
-; LMULMAX2-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; LMULMAX2-NEXT:    vslidedown.vx v8, v8, a0
-; LMULMAX2-NEXT:    vfmv.f.s ft0, v8
 ; LMULMAX2-NEXT:    vsetivli zero, 8, e32, m2, ta, mu
-; LMULMAX2-NEXT:    vfmv.v.f v8, ft0
+; LMULMAX2-NEXT:    vrgather.vx v10, v8, a0
+; LMULMAX2-NEXT:    vmv.v.v v8, v10
 ; LMULMAX2-NEXT:    ret
   %x = extractelement <8 x float> %v, i64 %idx
   %ins = insertelement <8 x float> poison, float %x, i32 0

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
index 80c4ad8662b14..06898e13bb050 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
@@ -665,11 +665,9 @@ define <4 x i32> @splat_c3_v4i32(<4 x i32> %v) {
 define <4 x i32> @splat_idx_v4i32(<4 x i32> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <4 x i32> %v, i64 %idx
   %ins = insertelement <4 x i32> poison, i32 %x, i32 0
@@ -693,11 +691,9 @@ define <8 x i16> @splat_c4_v8i16(<8 x i16> %v) {
 define <8 x i16> @splat_idx_v8i16(<8 x i16> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m1, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <8 x i16> %v, i64 %idx
   %ins = insertelement <8 x i16> poison, i16 %x, i32 0

diff --git a/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll b/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
index c68e31051a5f9..b6887d368dba6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
@@ -5,11 +5,9 @@
 define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
 ; CHECK-LABEL: splat_c3_nxv4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 3
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vrgather.vi v10, v8, 3
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x i32> %v, i32 3
   %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -20,11 +18,9 @@ define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
 define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v10, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x i32> %v, i64 %idx
   %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -35,11 +31,9 @@ define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
 define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
 ; CHECK-LABEL: splat_c4_nxv8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 4
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, mu
+; CHECK-NEXT:    vrgather.vi v10, v8, 4
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 8 x i16> %v, i32 4
   %ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -50,11 +44,9 @@ define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
 define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v10, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 8 x i16> %v, i64 %idx
   %ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -65,11 +57,9 @@ define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
 define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
 ; CHECK-LABEL: splat_c1_nxv2f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 1
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vi v9, v8, 1
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 2 x half> %v, i32 1
   %ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -80,11 +70,9 @@ define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
 define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv2f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vfmv.f.s ft0, v8
-; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 2 x half> %v, i64 %idx
   %ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -95,11 +83,9 @@ define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx)
 define <vscale x 4 x float> @splat_c3_nxv4f32(<vscale x 4 x float> %v) {
 ; CHECK-LABEL: splat_c3_nxv4f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 3
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vi v10, v8, 3
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x float> %v, i64 3
   %ins = insertelement <vscale x 4 x float> poison, float %x, i32 0
@@ -110,11 +96,9 @@ define <vscale x 4 x float> @splat_c3_nxv4f32(<vscale x 4 x float> %v) {
 define <vscale x 4 x float> @splat_idx_nxv4f32(<vscale x 4 x float> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv4f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vfmv.f.s ft0, v8
-; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vrgather.vx v10, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x float> %v, i64 %idx
   %ins = insertelement <vscale x 4 x float> poison, float %x, i32 0

