[llvm] 3bc5ed3 - [RISCV] Support fixed-length vector sign/zero extension

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 25 04:11:31 PST 2021


Author: Fraser Cormack
Date: 2021-02-25T12:05:17Z
New Revision: 3bc5ed38750c6a6daff39ad524b75e40c8c09183

URL: https://github.com/llvm/llvm-project/commit/3bc5ed38750c6a6daff39ad524b75e40c8c09183
DIFF: https://github.com/llvm/llvm-project/commit/3bc5ed38750c6a6daff39ad524b75e40c8c09183.diff

LOG: [RISCV] Support fixed-length vector sign/zero extension

This patch adds support for custom lowering of the sign- and zero-extension
of fixed-length vector types. It does so through new custom nodes (VSEXT_VL
and VZEXT_VL). Since the source and destination types are necessarily of
different sizes, it is possible that the source type is legal whilst the
larger destination type isn't. In that case the legalization makes heavy use
of EXTRACT_SUBVECTOR.
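
As a concrete illustration, drawn from the fixed-vectors-int-exttrunc.ll test
added below: when both the source and destination types are legal, the
extension lowers to a single vsext.vf4/vzext.vf4, e.g.

  define void @sext_v4i8_v4i32(<4 x i8>* %x, <4 x i32>* %z) {
    %a = load <4 x i8>, <4 x i8>* %x
    %b = sext <4 x i8> %a to <4 x i32>
    store <4 x i32> %b, <4 x i32>* %z
    ret void
  }

When the destination type would exceed the configured LMUL limit (e.g. <32 x i8>
extended to <32 x i32> at -riscv-v-fixed-length-vector-lmul-max=1), type
legalization splits the result and the source pieces are taken with
EXTRACT_SUBVECTOR, visible as the vslidedown.vi instructions in the
sext_v32i8_v32i32 checks below.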

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97194

Added: 
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2254d93f9acc..abd9a3d6d8c4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1207,9 +1207,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   }
   case ISD::ANY_EXTEND:
   case ISD::ZERO_EXTEND:
-    return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
+    if (Op.getOperand(0).getValueType().isVector() &&
+        Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
+      return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ 1);
+    return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VZEXT_VL);
   case ISD::SIGN_EXTEND:
-    return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
+    if (Op.getOperand(0).getValueType().isVector() &&
+        Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
+      return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
+    return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
   case ISD::SPLAT_VECTOR:
     return lowerSPLATVECTOR(Op, DAG);
   case ISD::INSERT_VECTOR_ELT:
@@ -1885,9 +1891,8 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
   MVT VecVT = Op.getSimpleValueType();
   SDValue Src = Op.getOperand(0);
   // Only custom-lower extensions from mask types
-  if (!Src.getValueType().isVector() ||
-      Src.getValueType().getVectorElementType() != MVT::i1)
-    return Op;
+  assert(Src.getValueType().isVector() &&
+         Src.getValueType().getVectorElementType() == MVT::i1);
 
   MVT XLenVT = Subtarget.getXLenVT();
   SDValue SplatZero = DAG.getConstant(0, DL, XLenVT);
@@ -1932,6 +1937,35 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
   return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
 }
 
+SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
+    SDValue Op, SelectionDAG &DAG, unsigned ExtendOpc) const {
+  MVT ExtVT = Op.getSimpleValueType();
+  // Only custom-lower extensions from fixed-length vector types.
+  if (!ExtVT.isFixedLengthVector())
+    return Op;
+  MVT VT = Op.getOperand(0).getSimpleValueType();
+  // Grab the canonical container type for the extended type. Infer the smaller
+  // type from that to ensure the same number of vector elements, as we know
+  // the LMUL will be sufficient to hold the smaller type.
+  MVT ContainerExtVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+      DAG, ExtVT, Subtarget);
+  // Get the extended container type manually to ensure the same number of
+  // vector elements between source and dest.
+  MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
+                                     ContainerExtVT.getVectorElementCount());
+
+  SDValue Op1 =
+      convertToScalableVector(ContainerVT, Op.getOperand(0), DAG, Subtarget);
+
+  SDLoc DL(Op);
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  SDValue Ext = DAG.getNode(ExtendOpc, DL, ContainerExtVT, Op1, Mask, VL);
+
+  return convertFromScalableVector(ExtVT, Ext, DAG, Subtarget);
+}
+
 // Custom-lower truncations from vectors to mask vectors by using a mask and a
 // setcc operation:
 //   (vXi1 = trunc vXiN vec) -> (vXi1 = setcc (and vec, 1), 0, ne)
@@ -5453,6 +5487,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VMCLR_VL)
   NODE_NAME_CASE(VMSET_VL)
   NODE_NAME_CASE(VRGATHER_VX_VL)
+  NODE_NAME_CASE(VSEXT_VL)
+  NODE_NAME_CASE(VZEXT_VL)
   NODE_NAME_CASE(VLE_VL)
   NODE_NAME_CASE(VSE_VL)
   }

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index cbfc56192374..dc7e05ea6704 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -191,6 +191,10 @@ enum NodeType : unsigned {
   // Matches the semantics of vrgather.vx with an extra operand for VL.
   VRGATHER_VX_VL,
 
+  // Vector sign/zero extend with additional mask & VL operands.
+  VSEXT_VL,
+  VZEXT_VL,
+
   // Memory opcodes start here.
   VLE_VL = ISD::FIRST_TARGET_MEMORY_OPCODE,
   VSE_VL,
@@ -431,6 +435,8 @@ class RISCVTargetLowering : public TargetLowering {
                                             SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
                             bool HasMask = true) const;
+  SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
+                                            unsigned ExtendOpc) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index b7c08d5b6cbd..76eb5f68a0c4 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -140,6 +140,14 @@ def true_mask : PatLeaf<(riscv_vmset_vl (XLenVT srcvalue))>;
 def riscv_vmnot_vl : PatFrag<(ops node:$rs, node:$vl),
                              (riscv_vmxor_vl node:$rs, true_mask, node:$vl)>;
 
+def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                               SDTCisSameNumEltsAs<0, 1>,
+                                               SDTCisSameNumEltsAs<1, 2>,
+                                               SDTCVecEltisVT<2, i1>,
+                                               SDTCisVT<3, XLenVT>]>;
+def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
+def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
+
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
                         (riscv_vfmv_v_f_vl node:$op, srcvalue)>;
@@ -352,6 +360,18 @@ multiclass VPatFPSetCCVL_VV_VF_FV<CondCode cc,
   }
 }
 
+multiclass VPatExtendSDNode_V_VL<SDNode vop, string inst_name, string suffix,
+                                 list <VTypeInfoToFraction> fraction_list> {
+  foreach vtiTofti = fraction_list in {
+    defvar vti = vtiTofti.Vti;
+    defvar fti = vtiTofti.Fti;
+    def : Pat<(vti.Vector (vop (fti.Vector fti.RegClass:$rs2),
+                               true_mask, (XLenVT (VLOp GPR:$vl)))),
+              (!cast<Instruction>(inst_name#"_"#suffix#"_"#vti.LMul.MX)
+                  fti.RegClass:$rs2, GPR:$vl, vti.SEW)>;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns.
 //===----------------------------------------------------------------------===//
@@ -399,6 +419,20 @@ foreach vti = AllIntegerVectors in {
                  vti.RegClass:$rs1, simm5:$rs2, GPR:$vl, vti.SEW)>;
 }
 
+// 12.3. Vector Integer Extension
+defm "" : VPatExtendSDNode_V_VL<riscv_zext_vl, "PseudoVZEXT", "VF2",
+                                AllFractionableVF2IntVectors>;
+defm "" : VPatExtendSDNode_V_VL<riscv_sext_vl, "PseudoVSEXT", "VF2",
+                                AllFractionableVF2IntVectors>;
+defm "" : VPatExtendSDNode_V_VL<riscv_zext_vl, "PseudoVZEXT", "VF4",
+                                AllFractionableVF4IntVectors>;
+defm "" : VPatExtendSDNode_V_VL<riscv_sext_vl, "PseudoVSEXT", "VF4",
+                                AllFractionableVF4IntVectors>;
+defm "" : VPatExtendSDNode_V_VL<riscv_zext_vl, "PseudoVZEXT", "VF8",
+                                AllFractionableVF8IntVectors>;
+defm "" : VPatExtendSDNode_V_VL<riscv_sext_vl, "PseudoVSEXT", "VF8",
+                                AllFractionableVF8IntVectors>;
+
 // 12.5. Vector Bitwise Logical Instructions
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_and_vl, "PseudoVAND">;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_or_vl,  "PseudoVOR">;

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
new file mode 100644
index 000000000000..ad04c1ad7ba0
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
@@ -0,0 +1,167 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=8 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX8
+; RUN: llc -mtriple=riscv64 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=8 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX8
+; RUN: llc -mtriple=riscv32 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=2 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX2
+; RUN: llc -mtriple=riscv64 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=2 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX2
+; RUN: llc -mtriple=riscv32 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=1 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX1
+; RUN: llc -mtriple=riscv64 -mattr=+m,+experimental-v -verify-machineinstrs -riscv-v-vector-bits-min=128 -riscv-v-fixed-length-vector-lmul-max=1 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,LMULMAX1
+
+define void @sext_v4i8_v4i32(<4 x i8>* %x, <4 x i32>* %z) {
+; CHECK-LABEL: sext_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli a2, 4, e8,m1,ta,mu
+; CHECK-NEXT:    vle8.v v25, (a0)
+; CHECK-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; CHECK-NEXT:    vsext.vf4 v26, v25
+; CHECK-NEXT:    vse32.v v26, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i8>, <4 x i8>* %x
+  %b = sext <4 x i8> %a to <4 x i32>
+  store <4 x i32> %b, <4 x i32>* %z
+  ret void
+}
+
+define void @zext_v4i8_v4i32(<4 x i8>* %x, <4 x i32>* %z) {
+; CHECK-LABEL: zext_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli a2, 4, e8,m1,ta,mu
+; CHECK-NEXT:    vle8.v v25, (a0)
+; CHECK-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; CHECK-NEXT:    vzext.vf4 v26, v25
+; CHECK-NEXT:    vse32.v v26, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i8>, <4 x i8>* %x
+  %b = zext <4 x i8> %a to <4 x i32>
+  store <4 x i32> %b, <4 x i32>* %z
+  ret void
+}
+
+define void @sext_v8i8_v8i32(<8 x i8>* %x, <8 x i32>* %z) {
+; LMULMAX8-LABEL: sext_v8i8_v8i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    vsetivli a2, 8, e8,m1,ta,mu
+; LMULMAX8-NEXT:    vle8.v v25, (a0)
+; LMULMAX8-NEXT:    vsetivli a0, 8, e32,m2,ta,mu
+; LMULMAX8-NEXT:    vsext.vf4 v26, v25
+; LMULMAX8-NEXT:    vse32.v v26, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: sext_v8i8_v8i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    vsetivli a2, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vle8.v v25, (a0)
+; LMULMAX2-NEXT:    vsetivli a0, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vsext.vf4 v26, v25
+; LMULMAX2-NEXT:    vse32.v v26, (a1)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: sext_v8i8_v8i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    vsetivli a2, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vle8.v v25, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v26, v25
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v25, v25, 4
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v27, v25
+; LMULMAX1-NEXT:    addi a0, a1, 16
+; LMULMAX1-NEXT:    vse32.v v27, (a0)
+; LMULMAX1-NEXT:    vse32.v v26, (a1)
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x i8>, <8 x i8>* %x
+  %b = sext <8 x i8> %a to <8 x i32>
+  store <8 x i32> %b, <8 x i32>* %z
+  ret void
+}
+
+define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
+; LMULMAX8-LABEL: sext_v32i8_v32i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    addi a2, zero, 32
+; LMULMAX8-NEXT:    vsetvli a3, a2, e8,m2,ta,mu
+; LMULMAX8-NEXT:    vle8.v v26, (a0)
+; LMULMAX8-NEXT:    vsetvli a0, a2, e32,m8,ta,mu
+; LMULMAX8-NEXT:    vsext.vf4 v8, v26
+; LMULMAX8-NEXT:    vse32.v v8, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: sext_v32i8_v32i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a2, zero, 32
+; LMULMAX2-NEXT:    vsetvli a2, a2, e8,m2,ta,mu
+; LMULMAX2-NEXT:    vle8.v v26, (a0)
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vslidedown.vi v25, v26, 8
+; LMULMAX2-NEXT:    vsetivli a0, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vsext.vf4 v28, v25
+; LMULMAX2-NEXT:    vsetivli a0, 16, e8,m2,ta,mu
+; LMULMAX2-NEXT:    vslidedown.vi v30, v26, 16
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vslidedown.vi v25, v30, 8
+; LMULMAX2-NEXT:    vsetivli a0, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vsext.vf4 v8, v25
+; LMULMAX2-NEXT:    vsext.vf4 v10, v26
+; LMULMAX2-NEXT:    vsext.vf4 v26, v30
+; LMULMAX2-NEXT:    addi a0, a1, 64
+; LMULMAX2-NEXT:    vse32.v v26, (a0)
+; LMULMAX2-NEXT:    vse32.v v10, (a1)
+; LMULMAX2-NEXT:    addi a0, a1, 96
+; LMULMAX2-NEXT:    vse32.v v8, (a0)
+; LMULMAX2-NEXT:    addi a0, a1, 32
+; LMULMAX2-NEXT:    vse32.v v28, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: sext_v32i8_v32i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    vsetivli a2, 16, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a2, a0, 16
+; LMULMAX1-NEXT:    vle8.v v25, (a2)
+; LMULMAX1-NEXT:    vle8.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v27, v25, 4
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v28, v27
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v27, v26, 4
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v29, v27
+; LMULMAX1-NEXT:    vsext.vf4 v27, v25
+; LMULMAX1-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v25, v25, 8
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v30, v25
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v25, v25, 4
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v31, v25
+; LMULMAX1-NEXT:    vsext.vf4 v25, v26
+; LMULMAX1-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v26, v26, 8
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v8, v26
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vslidedown.vi v26, v26, 4
+; LMULMAX1-NEXT:    vsetivli a0, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vsext.vf4 v9, v26
+; LMULMAX1-NEXT:    addi a0, a1, 48
+; LMULMAX1-NEXT:    vse32.v v9, (a0)
+; LMULMAX1-NEXT:    addi a0, a1, 32
+; LMULMAX1-NEXT:    vse32.v v8, (a0)
+; LMULMAX1-NEXT:    vse32.v v25, (a1)
+; LMULMAX1-NEXT:    addi a0, a1, 112
+; LMULMAX1-NEXT:    vse32.v v31, (a0)
+; LMULMAX1-NEXT:    addi a0, a1, 96
+; LMULMAX1-NEXT:    vse32.v v30, (a0)
+; LMULMAX1-NEXT:    addi a0, a1, 64
+; LMULMAX1-NEXT:    vse32.v v27, (a0)
+; LMULMAX1-NEXT:    addi a0, a1, 16
+; LMULMAX1-NEXT:    vse32.v v29, (a0)
+; LMULMAX1-NEXT:    addi a0, a1, 80
+; LMULMAX1-NEXT:    vse32.v v28, (a0)
+; LMULMAX1-NEXT:    ret
+  %a = load <32 x i8>, <32 x i8>* %x
+  %b = sext <32 x i8> %a to <32 x i32>
+  store <32 x i32> %b, <32 x i32>* %z
+  ret void
+}

More information about the llvm-commits mailing list