[llvm] [RISCV] Combine vslidedown_vl with known VL to a smaller LMUL (PR #66267)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 18 03:05:26 PDT 2023
https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/66267
From 569fb00a59f9442063ca17cca94c834bdd665641 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 13 Sep 2023 19:07:55 +0100
Subject: [PATCH] [RISCV] Combine vslidedown_vl with known VL and offset to a
smaller LMUL
If we know the VL and offset of a vslidedown_vl, we can work out the minimum
number of registers it will operate across, and we can reuse the logic from
extract_vector_elt to perform the slidedown in a smaller type and reduce the LMUL.
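As a worked example of that reasoning, here is a standalone sketch (not the
patch itself; the helper and variable names are illustrative, not LLVM APIs):
for a vslidedown with a constant offset and VL, the highest source element read
is offset + VL - 1, and that bounds the LMUL the instruction actually needs.

#include <cstdio>

// Smallest power-of-two LMUL whose register group covers elements
// [0, maxIdx], given VLEN and SEW in bits.
static unsigned smallestLMUL(unsigned vlenBits, unsigned sewBits,
                             unsigned maxIdx) {
  unsigned elemsPerReg = vlenBits / sewBits;
  unsigned lmul = 1;
  while (lmul * elemsPerReg < maxIdx + 1 && lmul < 8)
    lmul *= 2;
  return lmul;
}

int main() {
  // Zvl128b, e32, offset 4, VL 2: only elements 4..5 are read, so the
  // slidedown fits in LMUL=2 even if the original node was LMUL=4 (nxv8i32).
  unsigned maxIdx = 4 + 2 - 1;
  std::printf("needed LMUL = %u\n", smallestLMUL(128, 32, maxIdx));
}

The combine below computes the same maximum index from the node's constant
offset and VL operands and asks getSmallestVTForIndex for the matching
container type.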
The aim is to generalize #65598 and hopefully extend this to vslideup_vl too so
that we can get the same optimisation for insert_subvector and
insert_vector_elt.
One observation from adding this is that the vslide*_vl nodes all take a mask
operand, but currently anything other than vmset_vl will fail to select, as all
the patterns expect true_mask. So we need to create a new vmset_vl instead of
using extract_subvector on the existing vmset_vl.
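To make that concrete, a minimal sketch of the mask handling, mirroring the
hunk below (Mask and VL stand in for the node's mask and VL operands,
i.e. N->getOperand(3) and N->getOperand(4)):

// This would not select: the vslide*_vl patterns only match a true_mask
// (vmset_vl), not an extract_subvector of one.
//   SDValue ShrunkMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
//                                    getMaskTypeFor(*ShrunkVT), Mask,
//                                    DAG.getVectorIdxConstant(0, DL));
// Instead, build a fresh all-ones mask at the shrunken type, reusing the
// original VL operand.
SDValue ShrunkMask = DAG.getNode(RISCVISD::VMSET_VL, DL,
                                 getMaskTypeFor(*ShrunkVT), VL);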
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 46 ++++-
.../CodeGen/RISCV/rvv/extractelt-int-rv32.ll | 18 +-
.../rvv/fixed-vectors-int-explodevector.ll | 160 ++++++++++--------
3 files changed, 135 insertions(+), 89 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 52b19ce7a228dbe..5e3208f59a64d10 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8805,15 +8805,6 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}
- // Shrink down Vec so we're performing the slidedown on a smaller LMUL.
- unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1;
- if (auto ShrunkVT =
- getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) {
- ContainerVT = *ShrunkVT;
- Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
- DAG.getVectorIdxConstant(0, DL));
- }
-
SDValue Mask =
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
// Set the vector length to only the number of elements we care about. This
@@ -14266,6 +14257,43 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this))
return V;
break;
+ case RISCVISD::VSLIDEDOWN_VL: {
+ MVT OrigVT = N->getSimpleValueType(0);
+ auto *CVL = dyn_cast<ConstantSDNode>(N->getOperand(4));
+ auto *CIdx = dyn_cast<ConstantSDNode>(N->getOperand(2));
+ if (!CVL || !CIdx)
+ break;
+ unsigned MaxIdx = CVL->getZExtValue() + CIdx->getZExtValue() - 1;
+ // We can try and reduce the LMUL that a vslidedown uses if we know where
+ // the maximum index is. For example, if the target has Zvl128b, a
+ // vslidedown of e32 with an offset of 4 and VL of 2 is only going to
+ // read from the first 2 registers at most. So if we were operating at
+ // LMUL=4 (nxv8i32), we can reduce it to LMUL=2 (nxv4i32).
+ if (auto ShrunkVT =
+ getSmallestVTForIndex(OrigVT, MaxIdx, DL, DAG, Subtarget)) {
+ SDValue ShrunkPassthru =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, *ShrunkVT, N->getOperand(0),
+ DAG.getVectorIdxConstant(0, DL));
+ SDValue ShrunkInVec =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, *ShrunkVT, N->getOperand(1),
+ DAG.getVectorIdxConstant(0, DL));
+
+ // The only mask ever used in vslide*_vl nodes is vmset_vl, and the only
+ // patterns on vslide*_vl only accept vmset_vl. So create a new vmset
+ // since using an extract_subvector breaks patterns.
+ assert(N->getOperand(3).getOpcode() == RISCVISD::VMSET_VL);
+ SDValue ShrunkMask =
+ DAG.getNode(RISCVISD::VMSET_VL, SDLoc(N), getMaskTypeFor(*ShrunkVT),
+ N->getOperand(4));
+ SDValue ShrunkSlidedown =
+ DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, *ShrunkVT,
+ {ShrunkPassthru, ShrunkInVec, N->getOperand(2),
+ ShrunkMask, N->getOperand(4), N->getOperand(5)});
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigVT, N->getOperand(0),
+ ShrunkSlidedown, DAG.getVectorIdxConstant(0, DL));
+ }
+ break;
+ }
case RISCVISD::VFMV_V_F_VL: {
const MVT VT = N->getSimpleValueType(0);
SDValue Passthru = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/extractelt-int-rv32.ll b/llvm/test/CodeGen/RISCV/rvv/extractelt-int-rv32.ll
index fd2f89e26e59809..c3181a296abe06d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/extractelt-int-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/extractelt-int-rv32.ll
@@ -679,12 +679,13 @@ define i64 @extractelt_nxv4i64_0(<vscale x 4 x i64> %v) {
define i64 @extractelt_nxv4i64_imm(<vscale x 4 x i64> %v) {
; CHECK-LABEL: extractelt_nxv4i64_imm:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e64, m4, ta, ma
+; CHECK-NEXT: vsetivli zero, 1, e64, m2, ta, ma
; CHECK-NEXT: vslidedown.vi v8, v8, 2
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vsrl.vx v12, v8, a0
-; CHECK-NEXT: vmv.x.s a1, v12
; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: li a1, 32
+; CHECK-NEXT: vsetivli zero, 1, e64, m4, ta, ma
+; CHECK-NEXT: vsrl.vx v8, v8, a1
+; CHECK-NEXT: vmv.x.s a1, v8
; CHECK-NEXT: ret
%r = extractelement <vscale x 4 x i64> %v, i32 2
ret i64 %r
@@ -720,12 +721,13 @@ define i64 @extractelt_nxv8i64_0(<vscale x 8 x i64> %v) {
define i64 @extractelt_nxv8i64_imm(<vscale x 8 x i64> %v) {
; CHECK-LABEL: extractelt_nxv8i64_imm:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; CHECK-NEXT: vsetivli zero, 1, e64, m2, ta, ma
; CHECK-NEXT: vslidedown.vi v8, v8, 2
-; CHECK-NEXT: li a0, 32
-; CHECK-NEXT: vsrl.vx v16, v8, a0
-; CHECK-NEXT: vmv.x.s a1, v16
; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: li a1, 32
+; CHECK-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; CHECK-NEXT: vsrl.vx v8, v8, a1
+; CHECK-NEXT: vmv.x.s a1, v8
; CHECK-NEXT: ret
%r = extractelement <vscale x 8 x i64> %v, i32 2
ret i64 %r
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-explodevector.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-explodevector.ll
index 4e60edf058450f0..6e0ca4cba6bd6d6 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-explodevector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-explodevector.ll
@@ -875,11 +875,15 @@ define i64 @explode_8xi64(<8 x i64> %v) {
; RV32-NEXT: vsrl.vx v12, v8, a0
; RV32-NEXT: vmv.x.s a1, v12
; RV32-NEXT: vmv.x.s a2, v8
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
; RV32-NEXT: vslidedown.vi v12, v8, 1
-; RV32-NEXT: vsrl.vx v16, v12, a0
-; RV32-NEXT: vmv.x.s a3, v16
+; RV32-NEXT: vmv.x.s a3, v12
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
+; RV32-NEXT: vsrl.vx v12, v12, a0
; RV32-NEXT: vmv.x.s a4, v12
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
; RV32-NEXT: vslidedown.vi v12, v8, 2
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
; RV32-NEXT: vsrl.vx v16, v12, a0
; RV32-NEXT: vmv.x.s a5, v16
; RV32-NEXT: vmv.x.s a6, v12
@@ -903,19 +907,19 @@ define i64 @explode_8xi64(<8 x i64> %v) {
; RV32-NEXT: vsrl.vx v12, v8, a0
; RV32-NEXT: vmv.x.s a0, v12
; RV32-NEXT: vmv.x.s s0, v8
-; RV32-NEXT: add a1, a1, a3
-; RV32-NEXT: add a4, a2, a4
-; RV32-NEXT: sltu a2, a4, a2
+; RV32-NEXT: add a1, a1, a4
+; RV32-NEXT: add a3, a2, a3
+; RV32-NEXT: sltu a2, a3, a2
; RV32-NEXT: add a1, a1, a2
-; RV32-NEXT: add a6, a4, a6
-; RV32-NEXT: sltu a2, a6, a4
+; RV32-NEXT: add a6, a3, a6
+; RV32-NEXT: sltu a2, a6, a3
; RV32-NEXT: add a1, a1, a5
-; RV32-NEXT: add a2, a2, a7
-; RV32-NEXT: add a1, a1, a2
; RV32-NEXT: add t0, a6, t0
-; RV32-NEXT: sltu a2, t0, a6
-; RV32-NEXT: add a2, a2, t1
+; RV32-NEXT: sltu a3, t0, a6
+; RV32-NEXT: add a2, a2, a7
; RV32-NEXT: add a1, a1, a2
+; RV32-NEXT: add a3, a3, t1
+; RV32-NEXT: add a1, a1, a3
; RV32-NEXT: add t2, t0, t2
; RV32-NEXT: sltu a2, t2, t0
; RV32-NEXT: add a2, a2, t3
@@ -1029,115 +1033,127 @@ define i64 @explode_16xi64(<16 x i64> %v) {
; RV32-NEXT: vmv.x.s a0, v16
; RV32-NEXT: sw a0, 8(sp) # 4-byte Folded Spill
; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
; RV32-NEXT: vslidedown.vi v16, v8, 1
+; RV32-NEXT: vmv.x.s a3, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; RV32-NEXT: vsrl.vx v16, v16, a1
+; RV32-NEXT: vmv.x.s a4, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vslidedown.vi v16, v8, 2
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
; RV32-NEXT: vsrl.vx v24, v16, a1
; RV32-NEXT: vmv.x.s a5, v24
; RV32-NEXT: vmv.x.s a6, v16
-; RV32-NEXT: vslidedown.vi v16, v8, 2
-; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s a3, v24
-; RV32-NEXT: vmv.x.s a4, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
; RV32-NEXT: vslidedown.vi v16, v8, 3
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s2, v24
+; RV32-NEXT: vmv.x.s t0, v24
; RV32-NEXT: vmv.x.s a7, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
; RV32-NEXT: vslidedown.vi v16, v8, 4
-; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s3, v24
-; RV32-NEXT: vmv.x.s t0, v16
-; RV32-NEXT: vslidedown.vi v16, v8, 5
-; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s4, v24
; RV32-NEXT: vmv.x.s t1, v16
-; RV32-NEXT: vslidedown.vi v16, v8, 6
-; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s5, v24
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; RV32-NEXT: vsrl.vx v16, v16, a1
+; RV32-NEXT: vmv.x.s t3, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
+; RV32-NEXT: vslidedown.vi v16, v8, 5
; RV32-NEXT: vmv.x.s t2, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; RV32-NEXT: vsrl.vx v16, v16, a1
+; RV32-NEXT: vmv.x.s t5, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m4, ta, ma
+; RV32-NEXT: vslidedown.vi v16, v8, 6
+; RV32-NEXT: vmv.x.s t4, v16
+; RV32-NEXT: vsetivli zero, 1, e64, m8, ta, ma
+; RV32-NEXT: vsrl.vx v16, v16, a1
+; RV32-NEXT: vmv.x.s ra, v16
; RV32-NEXT: vslidedown.vi v16, v8, 7
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s6, v24
-; RV32-NEXT: vmv.x.s t3, v16
+; RV32-NEXT: vmv.x.s s5, v24
+; RV32-NEXT: vmv.x.s t6, v16
; RV32-NEXT: vslidedown.vi v16, v8, 8
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s7, v24
-; RV32-NEXT: vmv.x.s t4, v16
+; RV32-NEXT: vmv.x.s s6, v24
+; RV32-NEXT: vmv.x.s s0, v16
; RV32-NEXT: vslidedown.vi v16, v8, 9
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s8, v24
-; RV32-NEXT: vmv.x.s t5, v16
+; RV32-NEXT: vmv.x.s s7, v24
+; RV32-NEXT: vmv.x.s s1, v16
; RV32-NEXT: vslidedown.vi v16, v8, 10
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s9, v24
-; RV32-NEXT: vmv.x.s t6, v16
+; RV32-NEXT: vmv.x.s s8, v24
+; RV32-NEXT: vmv.x.s s2, v16
; RV32-NEXT: vslidedown.vi v16, v8, 11
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s10, v24
-; RV32-NEXT: vmv.x.s s0, v16
+; RV32-NEXT: vmv.x.s s9, v24
+; RV32-NEXT: vmv.x.s s3, v16
; RV32-NEXT: vslidedown.vi v16, v8, 12
; RV32-NEXT: vsrl.vx v24, v16, a1
-; RV32-NEXT: vmv.x.s s11, v24
-; RV32-NEXT: vmv.x.s s1, v16
+; RV32-NEXT: vmv.x.s s10, v24
+; RV32-NEXT: vmv.x.s s4, v16
; RV32-NEXT: vslidedown.vi v0, v8, 13
; RV32-NEXT: vsrl.vx v16, v0, a1
-; RV32-NEXT: vmv.x.s ra, v16
+; RV32-NEXT: vmv.x.s s11, v16
; RV32-NEXT: vslidedown.vi v16, v8, 14
; RV32-NEXT: vsrl.vx v24, v16, a1
; RV32-NEXT: vslidedown.vi v8, v8, 15
; RV32-NEXT: vmv.x.s a2, v0
; RV32-NEXT: vsrl.vx v0, v8, a1
; RV32-NEXT: lw a1, 8(sp) # 4-byte Folded Reload
-; RV32-NEXT: add a5, a1, a5
-; RV32-NEXT: add a6, a0, a6
-; RV32-NEXT: sltu a0, a6, a0
-; RV32-NEXT: add a0, a5, a0
-; RV32-NEXT: add a0, a0, a3
-; RV32-NEXT: add a4, a6, a4
-; RV32-NEXT: sltu a1, a4, a6
-; RV32-NEXT: add a1, a1, s2
+; RV32-NEXT: add a4, a1, a4
+; RV32-NEXT: add a3, a0, a3
+; RV32-NEXT: sltu a0, a3, a0
+; RV32-NEXT: add a0, a4, a0
+; RV32-NEXT: add a0, a0, a5
+; RV32-NEXT: add a6, a3, a6
+; RV32-NEXT: sltu a1, a6, a3
+; RV32-NEXT: add a1, a1, t0
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add a7, a4, a7
-; RV32-NEXT: sltu a1, a7, a4
-; RV32-NEXT: add a1, a1, s3
+; RV32-NEXT: add a7, a6, a7
+; RV32-NEXT: sltu a1, a7, a6
+; RV32-NEXT: add a1, a1, t3
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t0, a7, t0
-; RV32-NEXT: sltu a1, t0, a7
-; RV32-NEXT: add a1, a1, s4
-; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t1, t0, t1
-; RV32-NEXT: sltu a1, t1, t0
-; RV32-NEXT: add a1, a1, s5
+; RV32-NEXT: add t1, a7, t1
+; RV32-NEXT: sltu a1, t1, a7
+; RV32-NEXT: add a1, a1, t5
; RV32-NEXT: add a0, a0, a1
; RV32-NEXT: add t2, t1, t2
; RV32-NEXT: sltu a1, t2, t1
+; RV32-NEXT: add a1, a1, ra
+; RV32-NEXT: add a0, a0, a1
+; RV32-NEXT: add t4, t2, t4
+; RV32-NEXT: sltu a1, t4, t2
+; RV32-NEXT: add a1, a1, s5
+; RV32-NEXT: add a0, a0, a1
+; RV32-NEXT: add t6, t4, t6
+; RV32-NEXT: sltu a1, t6, t4
; RV32-NEXT: add a1, a1, s6
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t3, t2, t3
-; RV32-NEXT: sltu a1, t3, t2
+; RV32-NEXT: add s0, t6, s0
+; RV32-NEXT: sltu a1, s0, t6
; RV32-NEXT: add a1, a1, s7
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t4, t3, t4
-; RV32-NEXT: sltu a1, t4, t3
+; RV32-NEXT: add s1, s0, s1
+; RV32-NEXT: sltu a1, s1, s0
; RV32-NEXT: add a1, a1, s8
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t5, t4, t5
-; RV32-NEXT: sltu a1, t5, t4
+; RV32-NEXT: add s2, s1, s2
+; RV32-NEXT: sltu a1, s2, s1
; RV32-NEXT: add a1, a1, s9
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add t6, t5, t6
-; RV32-NEXT: sltu a1, t6, t5
+; RV32-NEXT: add s3, s2, s3
+; RV32-NEXT: sltu a1, s3, s2
; RV32-NEXT: add a1, a1, s10
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add s0, t6, s0
-; RV32-NEXT: sltu a1, s0, t6
+; RV32-NEXT: add s4, s3, s4
+; RV32-NEXT: sltu a1, s4, s3
; RV32-NEXT: add a1, a1, s11
; RV32-NEXT: add a0, a0, a1
-; RV32-NEXT: add s1, s0, s1
-; RV32-NEXT: sltu a1, s1, s0
-; RV32-NEXT: add a1, a1, ra
-; RV32-NEXT: add a0, a0, a1
; RV32-NEXT: vmv.x.s a1, v24
-; RV32-NEXT: add a2, s1, a2
-; RV32-NEXT: sltu a3, a2, s1
+; RV32-NEXT: add a2, s4, a2
+; RV32-NEXT: sltu a3, a2, s4
; RV32-NEXT: add a1, a3, a1
; RV32-NEXT: vmv.x.s a3, v16
; RV32-NEXT: add a0, a0, a1