[llvm] 84413e1 - [RISCV] Support fixed-length vector truncates
Fraser Cormack via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 25 04:17:52 PST 2021
Author: Fraser Cormack
Date: 2021-02-25T12:11:34Z
New Revision: 84413e1947427a917a3e55abfc1f66c42adc751b
URL: https://github.com/llvm/llvm-project/commit/84413e1947427a917a3e55abfc1f66c42adc751b
DIFF: https://github.com/llvm/llvm-project/commit/84413e1947427a917a3e55abfc1f66c42adc751b.diff
LOG: [RISCV] Support fixed-length vector truncates
This patch extends support for our custom lowering of scalable-vector
truncates to include those of fixed-length vectors. It does this by
co-opting the custom RISCVISD::TRUNCATE_VECTOR node and adding mask and
VL operands. This avoids unnecessary duplication of patterns and
inflation of the ISel table.
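For illustration, a minimal fixed-length case this lowering now handles
(mirroring the new tests at the bottom of the diff): a <4 x i32> to
<4 x i8> truncate is emitted as two SEW-halving
RISCVISD::TRUNCATE_VECTOR_VL nodes (i32 -> i16, then i16 -> i8), each of
which selects to a vnsrl.wi with a shift amount of 0. The function name
below is illustrative, not part of the patch:

define <4 x i8> @trunc_example(<4 x i32> %v) {
  ; Lowered as v4i32 -> v4i16 -> v4i8, one power of two per step.
  %t = trunc <4 x i32> %v to <4 x i8>
  ret <4 x i8> %t
}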
Some truncates are lowered through CONCAT_VECTORS, which currently isn't
handled efficiently: the concatenation goes through the stack. This can
be improved upon in the future; a sketch of the affected case follows.
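This corresponds to the LMULMAX1 configuration of the trunc_v8i8_v8i32
test below: when the <8 x i32> source must be split across vector
registers, the truncated halves are recombined via CONCAT_VECTORS, which
currently round-trips through a stack slot. The function name below is
illustrative, not part of the patch:

define <8 x i8> @trunc_split_example(<8 x i32> %v) {
  ; With a maximum LMUL of 1 the source is split into two <4 x i32>
  ; halves; joining the truncated halves goes through the stack today.
  %t = trunc <8 x i32> %v to <8 x i8>
  ret <8 x i8> %t
}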
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D97202
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index abd9a3d6d8c4..f18966706639 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -446,7 +446,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
- // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+ // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
setOperationAction(ISD::TRUNCATE, VT, Custom);
@@ -526,6 +526,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// By default everything must be expanded.
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
setOperationAction(Op, VT, Expand);
+ for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
+ setTruncStoreAction(VT, OtherVT, Expand);
// We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -571,6 +573,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VSELECT, VT, Custom);
+ setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1171,7 +1174,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
case ISD::TRUNCATE: {
SDLoc DL(Op);
- EVT VT = Op.getValueType();
+ MVT VT = Op.getSimpleValueType();
// Only custom-lower vector truncates
if (!VT.isVector())
return Op;
@@ -1181,28 +1184,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerVectorMaskTrunc(Op, DAG);
// RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
- // truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
+ // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
// truncate by one power of two at a time.
- EVT DstEltVT = VT.getVectorElementType();
+ MVT DstEltVT = VT.getVectorElementType();
SDValue Src = Op.getOperand(0);
- EVT SrcVT = Src.getValueType();
- EVT SrcEltVT = SrcVT.getVectorElementType();
+ MVT SrcVT = Src.getSimpleValueType();
+ MVT SrcEltVT = SrcVT.getVectorElementType();
assert(DstEltVT.bitsLT(SrcEltVT) &&
isPowerOf2_64(DstEltVT.getSizeInBits()) &&
isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
"Unexpected vector truncate lowering");
+ MVT ContainerVT = SrcVT;
+ if (SrcVT.isFixedLengthVector()) {
+ ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+ DAG, SrcVT, Subtarget);
+ Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+ }
+
SDValue Result = Src;
+ SDValue Mask, VL;
+ std::tie(Mask, VL) =
+ getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
LLVMContext &Context = *DAG.getContext();
- const ElementCount Count = SrcVT.getVectorElementCount();
+ const ElementCount Count = ContainerVT.getVectorElementCount();
do {
- SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
+ SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
- Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
+ Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+ Mask, VL);
} while (SrcEltVT != DstEltVT);
+ if (SrcVT.isFixedLengthVector())
+ Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
return Result;
}
case ISD::ANY_EXTEND:
@@ -5437,7 +5454,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VMV_X_S)
NODE_NAME_CASE(SPLAT_VECTOR_I64)
NODE_NAME_CASE(READ_VLENB)
- NODE_NAME_CASE(TRUNCATE_VECTOR)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+ NODE_NAME_CASE(VLEFF)
+ NODE_NAME_CASE(VLEFF_MASK)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
NODE_NAME_CASE(VID_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index dc7e05ea6704..a75ebc38cc2e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -105,8 +105,12 @@ enum NodeType : unsigned {
SPLAT_VECTOR_I64,
// Read VLENB CSR
READ_VLENB,
- // Truncates a RVV integer vector by one power-of-two.
- TRUNCATE_VECTOR,
+ // Truncates a RVV integer vector by one power-of-two. Carries both an extra
+ // mask and VL operand.
+ TRUNCATE_VECTOR_VL,
+ // Unit-stride fault-only-first load
+ VLEFF,
+ VLEFF_MASK,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the
// XLenVT index (either constant or non-constant), the fourth is the mask
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index ea0e5f137359..c552865c6ec9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -28,10 +28,6 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
-def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
- SDTypeProfile<1, 1,
- [SDTCisVec<0>, SDTCisVec<1>]>>;
-
// Give explicit Complexity to prefer simm5/uimm5.
def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -433,15 +429,6 @@ defm "" : VPatBinarySDNode_VV_VX_VI<shl, "PseudoVSLL", uimm5>;
defm "" : VPatBinarySDNode_VV_VX_VI<srl, "PseudoVSRL", uimm5>;
defm "" : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
-// 12.7. Vector Narrowing Integer Right Shift Instructions
-foreach vtiTofti = AllFractionableVF2IntVectors in {
- defvar vti = vtiTofti.Vti;
- defvar fti = vtiTofti.Fti;
- def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
- (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
- vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
-}
-
// 12.8. Vector Integer Comparison Instructions
defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETEQ, "PseudoVMSEQ">;
defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETNE, "PseudoVMSNE">;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 76eb5f68a0c4..2d5f8fa447fc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -148,6 +148,13 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
+def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
+ SDTypeProfile<1, 3, [SDTCisVec<0>,
+ SDTCisVec<1>,
+ SDTCisSameNumEltsAs<0, 2>,
+ SDTCVecEltisVT<2, i1>,
+ SDTCisVT<3, XLenVT>]>>;
+
// Ignore the vl operand.
def SplatFPOp : PatFrag<(ops node:$op),
(riscv_vfmv_v_f_vl node:$op, srcvalue)>;
@@ -443,6 +450,17 @@ defm "" : VPatBinaryVL_VV_VX_VI<riscv_shl_vl, "PseudoVSLL", uimm5>;
defm "" : VPatBinaryVL_VV_VX_VI<riscv_srl_vl, "PseudoVSRL", uimm5>;
defm "" : VPatBinaryVL_VV_VX_VI<riscv_sra_vl, "PseudoVSRA", uimm5>;
+// 12.7. Vector Narrowing Integer Right Shift Instructions
+foreach vtiTofti = AllFractionableVF2IntVectors in {
+ defvar vti = vtiTofti.Vti;
+ defvar fti = vtiTofti.Fti;
+ def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
+ (vti.Mask true_mask),
+ (XLenVT (VLOp GPR:$vl)))),
+ (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
+ vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
+}
+
// 12.8. Vector Integer Comparison Instructions
foreach vti = AllIntegerVectors in {
defm "" : VPatIntegerSetCCVL_VV<vti, "PseudoVMSEQ", SETEQ>;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
index ad04c1ad7ba0..e4e033abc06f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
@@ -165,3 +165,80 @@ define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
store <32 x i32> %b, <32 x i32>* %z
ret void
}
+
+define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
+; CHECK-LABEL: trunc_v4i8_v4i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli a2, 4, e32,m1,ta,mu
+; CHECK-NEXT: vle32.v v25, (a0)
+; CHECK-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
+; CHECK-NEXT: vnsrl.wi v26, v25, 0
+; CHECK-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
+; CHECK-NEXT: vnsrl.wi v25, v26, 0
+; CHECK-NEXT: vsetivli a0, 4, e8,m1,ta,mu
+; CHECK-NEXT: vse8.v v25, (a1)
+; CHECK-NEXT: ret
+ %a = load <4 x i32>, <4 x i32>* %x
+ %b = trunc <4 x i32> %a to <4 x i8>
+ store <4 x i8> %b, <4 x i8>* %z
+ ret void
+}
+
+define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
+; LMULMAX8-LABEL: trunc_v8i8_v8i32:
+; LMULMAX8: # %bb.0:
+; LMULMAX8-NEXT: vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX8-NEXT: vle32.v v26, (a0)
+; LMULMAX8-NEXT: vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX8-NEXT: vnsrl.wi v25, v26, 0
+; LMULMAX8-NEXT: vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX8-NEXT: vnsrl.wi v26, v25, 0
+; LMULMAX8-NEXT: vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX8-NEXT: vse8.v v26, (a1)
+; LMULMAX8-NEXT: ret
+;
+; LMULMAX2-LABEL: trunc_v8i8_v8i32:
+; LMULMAX2: # %bb.0:
+; LMULMAX2-NEXT: vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT: vle32.v v26, (a0)
+; LMULMAX2-NEXT: vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX2-NEXT: vnsrl.wi v25, v26, 0
+; LMULMAX2-NEXT: vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX2-NEXT: vnsrl.wi v26, v25, 0
+; LMULMAX2-NEXT: vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT: vse8.v v26, (a1)
+; LMULMAX2-NEXT: ret
+;
+; LMULMAX1-LABEL: trunc_v8i8_v8i32:
+; LMULMAX1: # %bb.0:
+; LMULMAX1-NEXT: addi sp, sp, -16
+; LMULMAX1-NEXT: .cfi_def_cfa_offset 16
+; LMULMAX1-NEXT: vsetivli a2, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT: addi a2, a0, 16
+; LMULMAX1-NEXT: vle32.v v25, (a2)
+; LMULMAX1-NEXT: vle32.v v26, (a0)
+; LMULMAX1-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT: vnsrl.wi v27, v25, 0
+; LMULMAX1-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT: vnsrl.wi v25, v27, 0
+; LMULMAX1-NEXT: addi a0, sp, 12
+; LMULMAX1-NEXT: vsetivli a2, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT: vse8.v v25, (a0)
+; LMULMAX1-NEXT: vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT: vnsrl.wi v25, v26, 0
+; LMULMAX1-NEXT: vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT: vnsrl.wi v26, v25, 0
+; LMULMAX1-NEXT: vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT: addi a0, sp, 8
+; LMULMAX1-NEXT: vse8.v v26, (a0)
+; LMULMAX1-NEXT: vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT: addi a0, sp, 8
+; LMULMAX1-NEXT: vle8.v v25, (a0)
+; LMULMAX1-NEXT: vse8.v v25, (a1)
+; LMULMAX1-NEXT: addi sp, sp, 16
+; LMULMAX1-NEXT: ret
+ %a = load <8 x i32>, <8 x i32>* %x
+ %b = trunc <8 x i32> %a to <8 x i8>
+ store <8 x i8> %b, <8 x i8>* %z
+ ret void
+}