[llvm] 16d3a82 - [RISCV] Add merge operand to RISCVISD::VRGATHER*_VL nodes.
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 20 19:16:46 PDT 2022
Author: Craig Topper
Date: 2022-06-20T18:58:24-07:00
New Revision: 16d3a82de53dab4bb5ed468aff92df276f8a6e39
URL: https://github.com/llvm/llvm-project/commit/16d3a82de53dab4bb5ed468aff92df276f8a6e39
DIFF: https://github.com/llvm/llvm-project/commit/16d3a82de53dab4bb5ed468aff92df276f8a6e39.diff
LOG: [RISCV] Add merge operand to RISCVISD::VRGATHER*_VL nodes.
Use it in place of VSELECT_VL+VRGATHER*_VL.
This simplifies the isel patterns.
Overall, I think trying to match select+op to create masked instructions
in isel doesn't scale. We either need to do it in DAG combine, pre-isel
peephole, or post-isel peephole. I don't yet know which is the right
answer, but for this case it seemed best to be able to request the
masked form directly from lowering.
Reviewed By: frasercrmck
Differential Revision: https://reviews.llvm.org/D128023
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b80dc298890c5..60e490fa2ce93 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1957,8 +1957,8 @@ static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
SDValue Mask, VL;
std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
- SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
- Idx, Mask, VL);
+ SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT,
+ DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
if (!VT.isFixedLengthVector())
return Gather;
@@ -2581,9 +2581,9 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget);
assert(Lane < (int)NumElts && "Unexpected lane!");
- SDValue Gather =
- DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, V1,
- DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL);
+ SDValue Gather = DAG.getNode(
+ RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
+ V1, DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL);
return convertFromScalableVector(VT, Gather, DAG, Subtarget);
}
}
@@ -2793,16 +2793,17 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
// that's beneficial.
if (LHSIndexCounts.size() == 1) {
int SplatIndex = LHSIndexCounts.begin()->getFirst();
- Gather =
- DAG.getNode(GatherVXOpc, DL, ContainerVT, V1,
- DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
+ Gather = DAG.getNode(
+ GatherVXOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V1,
+ DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
} else {
SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS);
LHSIndices =
convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget);
- Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices,
- TrueMask, VL);
+ Gather =
+ DAG.getNode(GatherVVOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
+ V1, LHSIndices, TrueMask, VL);
}
}
@@ -2810,27 +2811,26 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
// additional vrgather.
if (!V2.isUndef()) {
V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget);
+
+ MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
+ SelectMask =
+ convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
+
// If only one index is used, we can use a "splat" vrgather.
// TODO: We can splat the most-common index and fix-up any stragglers, if
// that's beneficial.
if (RHSIndexCounts.size() == 1) {
int SplatIndex = RHSIndexCounts.begin()->getFirst();
- V2 = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2,
- DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL);
+ Gather =
+ DAG.getNode(GatherVXOpc, DL, ContainerVT, Gather, V2,
+ DAG.getConstant(SplatIndex, DL, XLenVT), SelectMask, VL);
} else {
SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS);
RHSIndices =
convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget);
- V2 = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, TrueMask,
- VL);
+ Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, Gather, V2, RHSIndices,
+ SelectMask, VL);
}
-
- MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
- SelectMask =
- convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget);
-
- Gather = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, SelectMask, V2,
- Gather, VL);
}
return convertFromScalableVector(VT, Gather, DAG, Subtarget);
@@ -5691,7 +5691,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
SDValue Indices =
DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL);
- return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, VL);
+ return DAG.getNode(GatherOpc, DL, VecVT, DAG.getUNDEF(VecVT),
+ Op.getOperand(0), Indices, Mask, VL);
}
SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 91fdd5e887482..672086da46d82 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -286,8 +286,8 @@ enum NodeType : unsigned {
VMCLR_VL,
VMSET_VL,
- // Matches the semantics of vrgather.vx and vrgather.vv with an extra operand
- // for VL.
+ // Matches the semantics of vrgather.vx and vrgather.vv with extra operands
+ // for passthru and VL. First operand is the passthru operand.
VRGATHER_VX_VL,
VRGATHER_VV_VL,
VRGATHEREI16_VV_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 78c81557aea97..5b1ce03984a2d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -143,30 +143,33 @@ def riscv_setcc_vl : SDNode<"RISCVISD::SETCC_VL",
SDTCisVT<5, XLenVT>]>>;
def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL",
- SDTypeProfile<1, 4, [SDTCisVec<0>,
+ SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
- SDTCisVT<2, XLenVT>,
- SDTCVecEltisVT<3, i1>,
- SDTCisSameNumEltsAs<0, 3>,
- SDTCisVT<4, XLenVT>]>>;
+ SDTCisSameAs<0, 2>,
+ SDTCisVT<3, XLenVT>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCisVT<5, XLenVT>]>>;
def riscv_vrgather_vv_vl : SDNode<"RISCVISD::VRGATHER_VV_VL",
- SDTypeProfile<1, 4, [SDTCisVec<0>,
+ SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
- SDTCisInt<2>,
- SDTCisSameNumEltsAs<0, 2>,
- SDTCisSameSizeAs<0, 2>,
- SDTCVecEltisVT<3, i1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisInt<3>,
SDTCisSameNumEltsAs<0, 3>,
- SDTCisVT<4, XLenVT>]>>;
+ SDTCisSameSizeAs<0, 3>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCisVT<5, XLenVT>]>>;
def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
- SDTypeProfile<1, 4, [SDTCisVec<0>,
+ SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
- SDTCisInt<2>,
- SDTCVecEltisVT<2, i16>,
- SDTCisSameNumEltsAs<0, 2>,
- SDTCVecEltisVT<3, i1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisInt<3>,
+ SDTCVecEltisVT<3, i16>,
SDTCisSameNumEltsAs<0, 3>,
- SDTCisVT<4, XLenVT>]>>;
+ SDTCVecEltisVT<4, i1>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCisVT<5, XLenVT>]>>;
def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
@@ -1835,43 +1838,40 @@ foreach vti = AllIntegerVectors in {
(!cast<Instruction>("PseudoVMV_S_X_"#vti.LMul.MX)
vti.RegClass:$merge,
(vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
+ def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2,
(vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
+ def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VX_"# vti.LMul.MX)
vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm,
+ def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2, uimm5:$imm,
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgather_vv_vl
- vti.RegClass:$rs2,
- vti.RegClass:$rs1,
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ vti.RegClass:$rs1,
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgather_vx_vl
- vti.RegClass:$rs2,
- uimm5:$imm,
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ uimm5:$imm,
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1884,21 +1884,20 @@ foreach vti = AllIntegerVectors in {
defvar emul_str = octuple_to_str<octuple_emul>.ret;
defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
- def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
+ def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2,
(ivti.Vector ivti.RegClass:$rs1),
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>(inst)
vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgatherei16_vv_vl
- vti.RegClass:$rs2,
- (ivti.Vector ivti.RegClass:$rs1),
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector
+ (riscv_vrgatherei16_vv_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ (ivti.Vector ivti.RegClass:$rs1),
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>(inst#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1923,43 +1922,42 @@ foreach vti = AllFloatVectors in {
vti.RegClass:$merge,
(vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>;
defvar ivti = GetIntVTypeInfo<vti>.Vti;
- def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2,
+ def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2,
(ivti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX)
vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1,
+ def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VX_"# vti.LMul.MX)
vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm,
+ def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2, uimm5:$imm,
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX)
vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgather_vv_vl
- vti.RegClass:$rs2,
- (ivti.Vector vti.RegClass:$rs1),
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector
+ (riscv_vrgather_vv_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ (ivti.Vector vti.RegClass:$rs1),
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgather_vx_vl
- vti.RegClass:$rs2,
- uimm5:$imm,
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector
+ (riscv_vrgather_vx_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ uimm5:$imm,
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1971,21 +1969,20 @@ foreach vti = AllFloatVectors in {
defvar emul_str = octuple_to_str<octuple_emul>.ret;
defvar ivti = !cast<VTypeInfo>("VI16" # emul_str);
defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str;
- def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2,
+ def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue),
+ vti.RegClass:$rs2,
(ivti.Vector ivti.RegClass:$rs1),
(vti.Mask true_mask),
VLOpFrag)),
(!cast<Instruction>(inst)
vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
- def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
- (riscv_vrgatherei16_vv_vl
- vti.RegClass:$rs2,
- (ivti.Vector ivti.RegClass:$rs1),
- (vti.Mask true_mask),
- VLOpFrag),
- vti.RegClass:$merge,
- VLOpFrag)),
+ def : Pat<(vti.Vector
+ (riscv_vrgatherei16_vv_vl vti.RegClass:$merge,
+ vti.RegClass:$rs2,
+ (ivti.Vector ivti.RegClass:$rs1),
+ (vti.Mask V0),
+ VLOpFrag)),
(!cast<Instruction>(inst#"_MASK")
vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
More information about the llvm-commits
mailing list