[llvm] 84413e1 - [RISCV] Support fixed-length vector truncates

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 25 04:17:52 PST 2021


Author: Fraser Cormack
Date: 2021-02-25T12:11:34Z
New Revision: 84413e1947427a917a3e55abfc1f66c42adc751b

URL: https://github.com/llvm/llvm-project/commit/84413e1947427a917a3e55abfc1f66c42adc751b
DIFF: https://github.com/llvm/llvm-project/commit/84413e1947427a917a3e55abfc1f66c42adc751b.diff

LOG: [RISCV] Support fixed-length vector truncates

This patch extends our custom lowering of scalable-vector truncates to
also cover fixed-length vectors. It does this by co-opting the custom
RISCVISD::TRUNCATE_VECTOR node, adding mask and VL operands to it, and
renaming it to RISCVISD::TRUNCATE_VECTOR_VL. This avoids unnecessary
duplication of patterns and inflation of the ISel table.

Some truncates are lowered via CONCAT_VECTORS, which is currently handled
inefficiently because it is expanded through the stack (see the LMULMAX1
output of the trunc_v8i8_v8i32 test). This can be improved upon in the
future.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97202

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index abd9a3d6d8c4..f18966706639 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -446,7 +446,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
 
-      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);
 
@@ -526,6 +526,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         // By default everything must be expanded.
         for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
           setOperationAction(Op, VT, Expand);
+        for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
+          setTruncStoreAction(VT, OtherVT, Expand);
 
         // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
         setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -571,6 +573,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
         setOperationAction(ISD::VSELECT, VT, Custom);
 
+        setOperationAction(ISD::TRUNCATE, VT, Custom);
         setOperationAction(ISD::ANY_EXTEND, VT, Custom);
         setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
         setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1171,7 +1174,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   }
   case ISD::TRUNCATE: {
     SDLoc DL(Op);
-    EVT VT = Op.getValueType();
+    MVT VT = Op.getSimpleValueType();
     // Only custom-lower vector truncates
     if (!VT.isVector())
       return Op;
@@ -1181,28 +1184,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return lowerVectorMaskTrunc(Op, DAG);
 
     // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
-    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
+    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
     // truncate by one power of two at a time.
-    EVT DstEltVT = VT.getVectorElementType();
+    MVT DstEltVT = VT.getVectorElementType();
 
     SDValue Src = Op.getOperand(0);
-    EVT SrcVT = Src.getValueType();
-    EVT SrcEltVT = SrcVT.getVectorElementType();
+    MVT SrcVT = Src.getSimpleValueType();
+    MVT SrcEltVT = SrcVT.getVectorElementType();
 
     assert(DstEltVT.bitsLT(SrcEltVT) &&
            isPowerOf2_64(DstEltVT.getSizeInBits()) &&
            isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
            "Unexpected vector truncate lowering");
 
+    MVT ContainerVT = SrcVT;
+    if (SrcVT.isFixedLengthVector()) {
+      ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+          DAG, SrcVT, Subtarget);
+      Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+    }
+
     SDValue Result = Src;
+    SDValue Mask, VL;
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
     LLVMContext &Context = *DAG.getContext();
-    const ElementCount Count = SrcVT.getVectorElementCount();
+    const ElementCount Count = ContainerVT.getVectorElementCount();
     do {
-      SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
+      SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
       EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
+      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+                           Mask, VL);
     } while (SrcEltVT != DstEltVT);
 
+    if (SrcVT.isFixedLengthVector())
+      Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+
     return Result;
   }
   case ISD::ANY_EXTEND:
@@ -5437,7 +5454,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(SPLAT_VECTOR_I64)
   NODE_NAME_CASE(READ_VLENB)
-  NODE_NAME_CASE(TRUNCATE_VECTOR)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(VLEFF)
+  NODE_NAME_CASE(VLEFF_MASK)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
   NODE_NAME_CASE(VID_VL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index dc7e05ea6704..a75ebc38cc2e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -105,8 +105,12 @@ enum NodeType : unsigned {
   SPLAT_VECTOR_I64,
   // Read VLENB CSR
   READ_VLENB,
-  // Truncates a RVV integer vector by one power-of-two.
-  TRUNCATE_VECTOR,
+  // Truncates a RVV integer vector by one power-of-two. Carries both an extra
+  // mask and VL operand.
+  TRUNCATE_VECTOR_VL,
+  // Unit-stride fault-only-first load
+  VLEFF,
+  VLEFF_MASK,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the
   // XLenVT index (either constant or non-constant), the fourth is the mask

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index ea0e5f137359..c552865c6ec9 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -28,10 +28,6 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
 
 def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
 
-def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
-                                SDTypeProfile<1, 1,
-                                 [SDTCisVec<0>, SDTCisVec<1>]>>;
-
 // Give explicit Complexity to prefer simm5/uimm5.
 def SplatPat       : ComplexPattern<vAny, 1, "selectVSplat",      [splat_vector, rv32_splat_i64], [], 1>;
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -433,15 +429,6 @@ defm "" : VPatBinarySDNode_VV_VX_VI<shl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<srl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
 
-// 12.7. Vector Narrowing Integer Right Shift Instructions
-foreach vtiTofti = AllFractionableVF2IntVectors in {
-  defvar vti = vtiTofti.Vti;
-  defvar fti = vtiTofti.Fti;
-  def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
-            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
-                vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
-}
-
 // 12.8. Vector Integer Comparison Instructions
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETEQ,  "PseudoVMSEQ">;
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETNE,  "PseudoVMSNE">;

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 76eb5f68a0c4..2d5f8fa447fc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -148,6 +148,13 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
 def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
 def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
 
+def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
+                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                        SDTCisVec<1>,
+                                                        SDTCisSameNumEltsAs<0, 2>,
+                                                        SDTCVecEltisVT<2, i1>,
+                                                        SDTCisVT<3, XLenVT>]>>;
+
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
                         (riscv_vfmv_v_f_vl node:$op, srcvalue)>;
@@ -443,6 +450,17 @@ defm "" : VPatBinaryVL_VV_VX_VI<riscv_shl_vl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_srl_vl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_sra_vl, "PseudoVSRA", uimm5>;
 
+// 12.7. Vector Narrowing Integer Right Shift Instructions
+foreach vtiTofti = AllFractionableVF2IntVectors in {
+  defvar vti = vtiTofti.Vti;
+  defvar fti = vtiTofti.Fti;
+  def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
+                                               (vti.Mask true_mask),
+                                               (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
+                vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
+}
+
 // 12.8. Vector Integer Comparison Instructions
 foreach vti = AllIntegerVectors in {
   defm "" : VPatIntegerSetCCVL_VV<vti, "PseudoVMSEQ", SETEQ>;

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
index ad04c1ad7ba0..e4e033abc06f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
@@ -165,3 +165,80 @@ define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
   store <32 x i32> %b, <32 x i32>* %z
   ret void
 }
+
+define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
+; CHECK-LABEL: trunc_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; CHECK-NEXT:    vnsrl.wi v26, v25, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; CHECK-NEXT:    vnsrl.wi v25, v26, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; CHECK-NEXT:    vse8.v v25, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i32>, <4 x i32>* %x
+  %b = trunc <4 x i32> %a to <4 x i8>
+  store <4 x i8> %b, <4 x i8>* %z
+  ret void
+}
+
+define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
+; LMULMAX8-LABEL: trunc_v8i8_v8i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX8-NEXT:    vle32.v v26, (a0)
+; LMULMAX8-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX8-NEXT:    vse8.v v26, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: trunc_v8i8_v8i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vse8.v v26, (a1)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: trunc_v8i8_v8i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi sp, sp, -16
+; LMULMAX1-NEXT:    .cfi_def_cfa_offset 16
+; LMULMAX1-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    addi a2, a0, 16
+; LMULMAX1-NEXT:    vle32.v v25, (a2)
+; LMULMAX1-NEXT:    vle32.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v27, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v27, 0
+; LMULMAX1-NEXT:    addi a0, sp, 12
+; LMULMAX1-NEXT:    vsetivli a2, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vse8.v v25, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vse8.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vle8.v v25, (a0)
+; LMULMAX1-NEXT:    vse8.v v25, (a1)
+; LMULMAX1-NEXT:    addi sp, sp, 16
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x i32>, <8 x i32>* %x
+  %b = trunc <8 x i32> %a to <8 x i8>
+  store <8 x i8> %b, <8 x i8>* %z
+  ret void
+}


        


More information about the llvm-commits mailing list