[llvm] [RISCV] Remove RISCVISD::VNSRL_VL and adjust deinterleave lowering to match (PR #118391)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 2 11:41:17 PST 2024
https://github.com/preames created https://github.com/llvm/llvm-project/pull/118391
Instead of directly lowering to vnsrl_vl and having custom pattern matching for that case, we can just lower to a (legal) shift and truncate, and let generic pattern matching produce the vnsrl.
The major motivation for this is that I'm going to reuse this logic to handle e.g. deinterleave4 w/ i8 result.
The test changes aren't particularly interesting. They're minor code improvements - I think because we do slightly better with the insert_subvector patterns, but that's mostly irrelevant.
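For intuition, here's a minimal standalone sketch (plain C++, illustration only - not LLVM code, and the names Src/Even/Odd are made up for the example) of the shift-and-truncate trick the patch lowers to: view the 2*n-element source as n elements of double width, logical-shift-right by 0 (even elements) or by the element width (odd elements), then truncate back down. On RISC-V, generic pattern matching then turns the truncate(srl(x, c)) shape into a vnsrl, as described above.

#include <cstdint>
#include <cstdio>

int main() {
  // Source "vector" <8 x i8>, to be deinterleaved into evens and odds.
  uint8_t Src[8] = {10, 11, 12, 13, 14, 15, 16, 17};
  uint8_t Even[4], Odd[4];
  for (int i = 0; i < 4; ++i) {
    // Bitcast <8 x i8> -> <4 x i16> (little-endian lanes): each wide lane
    // now holds an {even, odd} pair.
    uint16_t Wide = (uint16_t)Src[2 * i] | ((uint16_t)Src[2 * i + 1] << 8);
    Even[i] = (uint8_t)(Wide >> 0); // srl by 0, then truncate
    Odd[i]  = (uint8_t)(Wide >> 8); // srl by EltBits, then truncate
  }
  printf("even: %d %d %d %d\n", Even[0], Even[1], Even[2], Even[3]);
  printf("odd:  %d %d %d %d\n", Odd[0], Odd[1], Odd[2], Odd[3]);
  return 0;
}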
From 1aa53646afdb6d8393d0a0dfbb2d263b29a72222 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 2 Dec 2024 11:15:33 -0800
Subject: [PATCH] [RISCV] Remove RISCVISD::VNSRL_VL and adjust deinterleave
lowering to match
Instead of directly lowering to vnsrl_vl and having custom pattern matching
for that case, we can just lower to a (legal) shift and truncate, and let
generic pattern matching produce the vnsrl.
The major motivation for this is that I'm going to reuse this logic to
handle e.g. deinterleave4 w/ i8 result.
The test changes aren't particularly interesting. They're minor code
improvements - I think because we do slightly better with the
insert_subvector patterns, but that's mostly irrelevant.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 66 +++++++------------
llvm/lib/Target/RISCV/RISCVISelLowering.h | 4 --
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 36 ----------
.../fixed-vectors-shuffle-changes-length.ll | 29 ++++----
4 files changed, 36 insertions(+), 99 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 78dc3cb27a6988..c68abcd5916049 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4618,51 +4618,32 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
// VT is the type of the vector to return, <[vscale x ]n x ty>
// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
- bool EvenElts,
- const RISCVSubtarget &Subtarget,
- SelectionDAG &DAG) {
- // The result is a vector of type <m x n x ty>
- MVT ContainerVT = VT;
- // Convert fixed vectors to scalable if needed
- if (ContainerVT.isFixedLengthVector()) {
- assert(Src.getSimpleValueType().isFixedLengthVector());
- ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
-
- // The source is a vector of type <m x n*2 x ty> (For the single source
- // case, the high half is undefined)
- MVT SrcContainerVT =
- MVT::getVectorVT(ContainerVT.getVectorElementType(),
- ContainerVT.getVectorElementCount() * 2);
- Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
+ bool EvenElts, SelectionDAG &DAG) {
+ // The result is a vector of type <m x n x ty>. The source is a vector of
+ // type <m x n*2 x ty> (For the single source case, the high half is undef)
+ if (Src.getValueType() == VT) {
+ EVT WideVT = VT.getDoubleNumVectorElementsVT();
+ Src = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT, DAG.getUNDEF(WideVT),
+ Src, DAG.getVectorIdxConstant(0, DL));
}
- auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
// Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
// This also converts FP to int.
- unsigned EltBits = ContainerVT.getScalarSizeInBits();
- MVT WideSrcContainerVT = MVT::getVectorVT(
- MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
- Src = DAG.getBitcast(WideSrcContainerVT, Src);
+ unsigned EltBits = VT.getScalarSizeInBits();
+ MVT WideSrcVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
+ VT.getVectorElementCount());
+ Src = DAG.getBitcast(WideSrcVT, Src);
- // The integer version of the container type.
- MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
+ MVT IntVT = VT.changeVectorElementTypeToInteger();
// If we want even elements, then the shift amount is 0. Otherwise, shift by
// the original element size.
unsigned Shift = EvenElts ? 0 : EltBits;
- SDValue SplatShift = DAG.getNode(
- RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
- DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
- SDValue Res =
- DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
- DAG.getUNDEF(IntContainerVT), TrueMask, VL);
- // Cast back to FP if needed.
- Res = DAG.getBitcast(ContainerVT, Res);
-
- if (VT.isFixedLengthVector())
- Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
- return Res;
+ SDValue Res = DAG.getNode(ISD::SRL, DL, WideSrcVT, Src,
+ DAG.getConstant(Shift, DL, WideSrcVT));
+ Res = DAG.getNode(ISD::TRUNCATE, DL, IntVT, Res);
+ return DAG.getBitcast(VT, Res);
}
// Lower the following shuffle to vslidedown.
@@ -5356,7 +5337,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
// vnsrl to deinterleave.
if (SDValue Src =
isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
- return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, Subtarget, DAG);
+ return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, DAG);
if (SDValue V =
lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -6258,7 +6239,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 127 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6284,7 +6265,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 127 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -10763,10 +10744,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
// We can deinterleave through vnsrl.wi if the element type is smaller than
// ELEN
if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
- SDValue Even =
- getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
- SDValue Odd =
- getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
+ SDValue Even = getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, DAG);
+ SDValue Odd = getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, DAG);
return DAG.getMergeValues({Even, Odd}, DL);
}
@@ -20494,7 +20473,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWMACC_VL)
NODE_NAME_CASE(VWMACCU_VL)
NODE_NAME_CASE(VWMACCSU_VL)
- NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VMERGE_VL)
NODE_NAME_CASE(VMAND_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 7ada941563c1ff..c753469562ebac 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -369,10 +369,6 @@ enum NodeType : unsigned {
VWMACCU_VL,
VWMACCSU_VL,
- // Narrowing logical shift right.
- // Operands are (source, shift, passthru, mask, vl)
- VNSRL_VL,
-
// Vector compare producing a mask. Fourth operand is input mask. Fifth
// operand is VL.
SETCC_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 4b938fc734e5c1..e48a6f9309294b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -459,17 +459,6 @@ def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNP
def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>;
-def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
- SDTCisInt<1>,
- SDTCisSameNumEltsAs<0, 1>,
- SDTCisOpSmallerThanOp<0, 1>,
- SDTCisSameAs<0, 2>,
- SDTCisSameAs<0, 3>,
- SDTCisSameNumEltsAs<0, 4>,
- SDTCVecEltisVT<4, i1>,
- SDTCisVT<5, XLenVT>]>;
-def riscv_vnsrl_vl : SDNode<"RISCVISD::VNSRL_VL", SDT_RISCVVNIntBinOp_VL>;
-
def SDT_RISCVVWIntBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisSameAs<0, 1>,
SDTCisInt<2>,
@@ -885,29 +874,6 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
}
}
-multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name> {
- foreach VtiToWti = AllWidenableIntVectors in {
- defvar vti = VtiToWti.Vti;
- defvar wti = VtiToWti.Wti;
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
- def : VPatBinaryVL_V<vop, instruction_name, "WV",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
- vti.RegClass>;
- def : VPatBinaryVL_XI<vop, instruction_name, "WX",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
- SplatPat, GPR>;
- def : VPatBinaryVL_XI<vop, instruction_name, "WI",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
- !cast<ComplexPattern>(SplatPat#_#uimm5),
- uimm5>;
- }
- }
-}
-
class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
@@ -2166,8 +2132,6 @@ defm : VPatNarrowShiftSplatExt_WX<riscv_srl_vl, riscv_zext_vl_oneuse, "PseudoVNS
defm : VPatNarrowShiftVL_WV<riscv_srl_vl, "PseudoVNSRL">;
defm : VPatNarrowShiftVL_WV<riscv_sra_vl, "PseudoVNSRA">;
-defm : VPatBinaryNVL_WV_WX_WI<riscv_vnsrl_vl, "PseudoVNSRL">;
-
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
index 8b18be908089f2..9d2c722334b080 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
@@ -97,33 +97,31 @@ define <4 x i32> @v4i32_v8i32(<8 x i32>) {
define <4 x i32> @v4i32_v16i32(<16 x i32>) {
; RV32-LABEL: v4i32_v16i32:
; RV32: # %bb.0:
-; RV32-NEXT: vsetivli zero, 8, e32, m4, ta, ma
-; RV32-NEXT: vslidedown.vi v16, v8, 8
-; RV32-NEXT: vmv4r.v v20, v8
; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT: vmv.v.i v8, 1
-; RV32-NEXT: vmv2r.v v22, v12
-; RV32-NEXT: vmv.v.i v10, 6
+; RV32-NEXT: vmv.v.i v12, 1
+; RV32-NEXT: vmv.v.i v14, 6
; RV32-NEXT: li a0, 32
; RV32-NEXT: vmv.v.i v0, 10
; RV32-NEXT: vsetivli zero, 2, e16, m1, tu, ma
-; RV32-NEXT: vslideup.vi v10, v8, 1
+; RV32-NEXT: vslideup.vi v14, v12, 1
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vnsrl.wx v12, v8, a0
+; RV32-NEXT: vsetivli zero, 8, e32, m4, ta, ma
+; RV32-NEXT: vslidedown.vi v8, v8, 8
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, mu
-; RV32-NEXT: vnsrl.wx v8, v20, a0
-; RV32-NEXT: vrgatherei16.vv v8, v16, v10, v0.t
+; RV32-NEXT: vrgatherei16.vv v12, v8, v14, v0.t
+; RV32-NEXT: vmv1r.v v8, v12
; RV32-NEXT: ret
;
; RV64-LABEL: v4i32_v16i32:
; RV64: # %bb.0:
-; RV64-NEXT: vsetivli zero, 8, e32, m4, ta, ma
-; RV64-NEXT: vslidedown.vi v16, v8, 8
-; RV64-NEXT: vmv4r.v v20, v8
; RV64-NEXT: li a0, 32
-; RV64-NEXT: vmv2r.v v22, v12
; RV64-NEXT: vsetivli zero, 1, e8, mf8, ta, ma
; RV64-NEXT: vmv.v.i v0, 10
; RV64-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; RV64-NEXT: vnsrl.wx v8, v20, a0
+; RV64-NEXT: vnsrl.wx v12, v8, a0
+; RV64-NEXT: vsetivli zero, 8, e32, m4, ta, ma
+; RV64-NEXT: vslidedown.vi v8, v8, 8
; RV64-NEXT: li a0, 3
; RV64-NEXT: slli a0, a0, 33
; RV64-NEXT: addi a0, a0, 1
@@ -131,7 +129,8 @@ define <4 x i32> @v4i32_v16i32(<16 x i32>) {
; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; RV64-NEXT: vmv.v.x v10, a0
; RV64-NEXT: vsetivli zero, 8, e32, m2, ta, mu
-; RV64-NEXT: vrgatherei16.vv v8, v16, v10, v0.t
+; RV64-NEXT: vrgatherei16.vv v12, v8, v10, v0.t
+; RV64-NEXT: vmv1r.v v8, v12
; RV64-NEXT: ret
%2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 1, i32 9, i32 5, i32 14>
ret <4 x i32> %2