[llvm] 62c4ac7 - [RISCV] Optimize splats of extracted vector elements
Fraser Cormack via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 8 02:46:01 PST 2022
Author: Fraser Cormack
Date: 2022-02-08T10:35:25Z
New Revision: 62c4ac764bc0160ac4eed70a5becd4e015bcb65b
URL: https://github.com/llvm/llvm-project/commit/62c4ac764bc0160ac4eed70a5becd4e015bcb65b
DIFF: https://github.com/llvm/llvm-project/commit/62c4ac764bc0160ac4eed70a5becd4e015bcb65b.diff
LOG: [RISCV] Optimize splats of extracted vector elements
This patch adds an optimization for splat-like operations where the
splatted value is extracted from an identically-sized vector. On RVV we
can perform such a splat with vrgather.vx/vrgather.vi directly on the
source vector, without dropping to scalar beforehand.
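For illustration, the fixed-length form of the pattern (as exercised by
the tests below) is:

  %x = extractelement <4 x i32> %v, i64 %idx
  %ins = insertelement <4 x i32> poison, i32 %x, i32 0
  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer

which previously lowered to a vslidedown.vx/vmv.x.s/vmv.v.x sequence and
now becomes a single vrgather.vx.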
We already have a similar VECTOR_SHUFFLE-specific optimization, but it
only works on fixed-length vector types and only when the splat lane is
a constant. This patch extends the optimization to scalable-vector
types and to unknown (variable) extract indices.
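The scalable-vector equivalent of the pattern (again taken from the
tests below) is:

  %x = extractelement <vscale x 4 x i32> %v, i64 %idx
  %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
  %splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

where %idx need not be a compile-time constant.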
The optimization is performed during BUILD_VECTOR lowering for
fixed-length vectors and during a new DAGCombine on SPLAT_VECTOR for
scalable vectors.
Reviewed By: craig.topper, khchen
Differential Revision: https://reviews.llvm.org/D118456
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7dd56d4bb04b1..dbbb3202dda3b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1086,6 +1086,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SRL);
setTargetDAGCombine(ISD::SHL);
setTargetDAGCombine(ISD::STORE);
+ setTargetDAGCombine(ISD::SPLAT_VECTOR);
}
setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2000,6 +2001,40 @@ static Optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend};
}
+// Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT
+// and lower it as a VRGATHER_VX_VL from the source vector.
+static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ SDValue Vec = SplatVal.getOperand(0);
+ // Only perform this optimization on vectors of the same size for simplicity.
+ if (Vec.getValueType() != VT)
+ return SDValue();
+ SDValue Idx = SplatVal.getOperand(1);
+ // The index must be a legal type.
+ if (Idx.getValueType() != Subtarget.getXLenVT())
+ return SDValue();
+
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
+ }
+
+ SDValue Mask, VL;
+ std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+ SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
+ Idx, Mask, VL);
+
+ if (!VT.isFixedLengthVector())
+ return Gather;
+
+ return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+}
+
static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
@@ -2123,6 +2158,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
}
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
+ if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
+ return Gather;
unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
: RISCVISD::VMV_V_X_VL;
Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL);
@@ -8260,6 +8297,16 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
+ case ISD::SPLAT_VECTOR: {
+ EVT VT = N->getValueType(0);
+ // Only perform this combine on legal MVT types.
+ if (!isTypeLegal(VT))
+ break;
+ if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N,
+ DAG, Subtarget))
+ return Gather;
+ break;
+ }
}
return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
index 9db7b3c39e8b4..80abc47a23bb7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
@@ -217,11 +217,9 @@ define <4 x half> @splat_c3_v4f16(<4 x half> %v) {
define <4 x half> @splat_idx_v4f16(<4 x half> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_v4f16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vfmv.f.s ft0, v8
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, mu
-; CHECK-NEXT: vfmv.v.f v8, ft0
+; CHECK-NEXT: vrgather.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%x = extractelement <4 x half> %v, i64 %idx
%ins = insertelement <4 x half> poison, half %x, i32 0
@@ -270,11 +268,9 @@ define <8 x float> @splat_idx_v8f32(<8 x float> %v, i64 %idx) {
;
; LMULMAX2-LABEL: splat_idx_v8f32:
; LMULMAX2: # %bb.0:
-; LMULMAX2-NEXT: vsetivli zero, 1, e32, m2, ta, mu
-; LMULMAX2-NEXT: vslidedown.vx v8, v8, a0
-; LMULMAX2-NEXT: vfmv.f.s ft0, v8
; LMULMAX2-NEXT: vsetivli zero, 8, e32, m2, ta, mu
-; LMULMAX2-NEXT: vfmv.v.f v8, ft0
+; LMULMAX2-NEXT: vrgather.vx v10, v8, a0
+; LMULMAX2-NEXT: vmv.v.v v8, v10
; LMULMAX2-NEXT: ret
%x = extractelement <8 x float> %v, i64 %idx
%ins = insertelement <8 x float> poison, float %x, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
index 80c4ad8662b14..06898e13bb050 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
@@ -665,11 +665,9 @@ define <4 x i32> @splat_c3_v4i32(<4 x i32> %v) {
define <4 x i32> @splat_idx_v4i32(<4 x i32> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_v4i32:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vrgather.vx v9, v8, a0
+; CHECK-NEXT: vmv.v.v v8, v9
; CHECK-NEXT: ret
%x = extractelement <4 x i32> %v, i64 %idx
%ins = insertelement <4 x i32> poison, i32 %x, i32 0
@@ -693,11 +691,9 @@ define <8 x i16> @splat_c4_v8i16(<8 x i16> %v) {
define <8 x i16> @splat_idx_v8i16(<8 x i16> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_v8i16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vrgather.vx v9, v8, a0
+; CHECK-NEXT: vmv.v.v v8, v9
; CHECK-NEXT: ret
%x = extractelement <8 x i16> %v, i64 %idx
%ins = insertelement <8 x i16> poison, i16 %x, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll b/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
index c68e31051a5f9..b6887d368dba6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
@@ -5,11 +5,9 @@
define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
; CHECK-LABEL: splat_c3_nxv4i32:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT: vslidedown.vi v8, v8, 3
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT: vrgather.vi v10, v8, 3
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 4 x i32> %v, i32 3
%ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -20,11 +18,9 @@ define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_nxv4i32:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vrgather.vx v10, v8, a0
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 4 x i32> %v, i64 %idx
%ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -35,11 +31,9 @@ define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
; CHECK-LABEL: splat_c4_nxv8i16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT: vslidedown.vi v8, v8, 4
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, mu
+; CHECK-NEXT: vrgather.vi v10, v8, 4
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 8 x i16> %v, i32 4
%ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -50,11 +44,9 @@ define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_nxv8i16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT: vmv.v.x v8, a0
+; CHECK-NEXT: vrgather.vx v10, v8, a0
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 8 x i16> %v, i64 %idx
%ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -65,11 +57,9 @@ define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
; CHECK-LABEL: splat_c1_nxv2f16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT: vslidedown.vi v8, v8, 1
-; CHECK-NEXT: vfmv.f.s ft0, v8
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT: vfmv.v.f v8, ft0
+; CHECK-NEXT: vrgather.vi v9, v8, 1
+; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%x = extractelement <vscale x 2 x half> %v, i32 1
%ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -80,11 +70,9 @@ define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_nxv2f16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vfmv.f.s ft0, v8
-; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT: vfmv.v.f v8, ft0
+; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT: vrgather.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%x = extractelement <vscale x 2 x half> %v, i64 %idx
%ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -95,11 +83,9 @@ define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx)
define <vscale x 4 x float> @splat_c3_nxv4f32(<vscale x 4 x float> %v) {
; CHECK-LABEL: splat_c3_nxv4f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT: vslidedown.vi v8, v8, 3
-; CHECK-NEXT: vfmv.f.s ft0, v8
; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, mu
-; CHECK-NEXT: vfmv.v.f v8, ft0
+; CHECK-NEXT: vrgather.vi v10, v8, 3
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 4 x float> %v, i64 3
%ins = insertelement <vscale x 4 x float> poison, float %x, i32 0
@@ -110,11 +96,9 @@ define <vscale x 4 x float> @splat_c3_nxv4f32(<vscale x 4 x float> %v) {
define <vscale x 4 x float> @splat_idx_nxv4f32(<vscale x 4 x float> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_nxv4f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT: vslidedown.vx v8, v8, a0
-; CHECK-NEXT: vfmv.f.s ft0, v8
-; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, mu
-; CHECK-NEXT: vfmv.v.f v8, ft0
+; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, mu
+; CHECK-NEXT: vrgather.vx v10, v8, a0
+; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 4 x float> %v, i64 %idx
%ins = insertelement <vscale x 4 x float> poison, float %x, i32 0