[llvm] [RISCV] Shrink vslidedown when lowering fixed extract_subvector (PR #65598)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 11 09:18:59 PDT 2023


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/65598:

From e3534d365e3fb0bcc72715a3b955e0d84d71b4a9 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 7 Sep 2023 11:01:05 +0100
Subject: [PATCH 1/3] [RISCV] Shrink vslidedown when lowering fixed
 extract_subvector

As noted in
https://github.com/llvm/llvm-project/pull/65392#discussion_r1316259471, when
lowering an extract of a fixed length vector from another vector, we don't need
to perform the vslidedown on the full vector type. Instead we can extract the
smallest subregister that contains the subvector to be extracted and perform
the vslidedown with a smaller LMUL. E.g. with +Zvl128b:

v2i64 = extract_subvector nxv4i64, 2

is currently lowered as

vsetivli zero, 2, e64, m4, ta, ma
vslidedown.vi v8, v8, 2

This patch shrinks the vslidedown to LMUL=2:

vsetivli zero, 2, e64, m2, ta, ma
vslidedown.vi v8, v8, 2

This works because we know there are at least 128*2=256 bits in v8 at LMUL=2, and
we only need the first 256 bits to extract a v2i64 at index 2.
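
In other words (just the example above with the general formula spelled out):
the extracted subvector ends at bit (OrigIdx + NumElts) * EltSize = (2 + 2) * 64
= 256, and with Zvl128b an LMUL=2 register group is guaranteed to hold at least
2 * 128 = 256 bits, so everything needed for the extract fits.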

lowerEXTRACT_VECTOR_ELT already has this logic, so this patch pulls it out and
reuses it.
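
As a rough illustration, the kind of IR that reaches this lowering looks like
the sketch below (the function name is made up for this mail; the in-tree tests
in fixed-vectors-extract-subvector.ll exercise similar cases):

; Extract a fixed <2 x i64> starting at element 2 of a scalable <vscale x 4 x i64>.
; Only the first 256 bits of %vec matter, so the slidedown can be done at LMUL=2.
define <2 x i64> @extract_v2i64_nxv4i64_2(<vscale x 4 x i64> %vec) {
  %sub = call <2 x i64> @llvm.vector.extract.v2i64.nxv4i64(<vscale x 4 x i64> %vec, i64 2)
  ret <2 x i64> %sub
}
declare <2 x i64> @llvm.vector.extract.v2i64.nxv4i64(<vscale x 4 x i64>, i64)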

I've split this out into a separate PR rather than including it in #65392, in
the hope that we'll be able to generalize it later.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 33 +++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ad56bc757115f1d..68eeb584f79d000 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8751,6 +8751,39 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
       ContainerVT = getContainerForFixedLengthVector(VecVT);
       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
     }
+
+    // The minimum number of elements for a scalable vector type, e.g. nxv1i32
+    // is not legal on Zve32x.
+    const uint64_t MinLegalNumElts =
+        RISCV::RVVBitsPerBlock / Subtarget.getELen();
+    const uint64_t MinVscale =
+        Subtarget.getRealMinVLen() / RISCV::RVVBitsPerBlock;
+
+    // Even if we don't know the exact subregister the subvector is going to
+    // reside in, we know that the subvector is located within the first N bits
+    // of Vec:
+    //
+    // N = (OrigIdx + SubVecVT.getVectorNumElements()) * EltSizeInBits
+    //   = MinVscale * MinEltsNeeded * EltSizeInBits
+    //
+    // From this we can work out the smallest type that contains everything we
+    // need to extract, <vscale x MinEltsNeeded x Elt>
+    uint64_t MinEltsNeeded =
+        (OrigIdx + SubVecVT.getVectorNumElements()) / MinVscale;
+
+    // Round up the number of elements so it's a valid power of 2 scalable
+    // vector type, and make sure it's not less than smallest legal vector type.
+    MinEltsNeeded = std::max(MinLegalNumElts, PowerOf2Ceil(MinEltsNeeded));
+
+    assert(MinEltsNeeded <= ContainerVT.getVectorMinNumElements());
+
+    // Shrink down Vec so we're performing the slidedown on the smallest
+    // possible type.
+    ContainerVT = MVT::getScalableVectorVT(ContainerVT.getVectorElementType(),
+                                           MinEltsNeeded);
+    Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
+                      DAG.getVectorIdxConstant(0, DL));
+
     SDValue Mask =
         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
     // Set the vector length to only the number of elements we care about. This

From b788713c283e708bf0520b7ed5236581c11b778d Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 8 Sep 2023 03:14:00 +0100
Subject: [PATCH 2/3] Share logic from extract_vector_elt

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 83 +++++++++----------
 .../rvv/fixed-vectors-extract-subvector.ll    | 20 ++---
 2 files changed, 47 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68eeb584f79d000..5e7813b612ce32f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7466,6 +7466,32 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
 }
 
+// Given a scalable vector type and an index into it, returns the type for the
+// smallest subvector that the index fits in. This can be used to reduce LMUL
+// for operations like vslidedown.
+//
+// E.g. With Zvl128b, index 3 in a nxv4i32 fits within the first nxv2i32.
+static std::optional<MVT>
+getSmallestVTForIndex(MVT VecVT, unsigned MaxIdx, SDLoc DL, SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) {
+  assert(VecVT.isScalableVector());
+  const unsigned EltSize = VecVT.getScalarSizeInBits();
+  const unsigned VectorBitsMin = Subtarget.getRealMinVLen();
+  const unsigned MinVLMAX = VectorBitsMin / EltSize;
+  MVT SmallerVT;
+  if (MaxIdx < MinVLMAX)
+    SmallerVT = getLMUL1VT(VecVT);
+  else if (MaxIdx < MinVLMAX * 2)
+    SmallerVT = getLMUL1VT(VecVT).getDoubleNumVectorElementsVT();
+  else if (MaxIdx < MinVLMAX * 4)
+    SmallerVT = getLMUL1VT(VecVT)
+                    .getDoubleNumVectorElementsVT()
+                    .getDoubleNumVectorElementsVT();
+  if (!SmallerVT.isValid() || !VecVT.bitsGT(SmallerVT))
+    return std::nullopt;
+  return SmallerVT;
+}
+
 // Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
 // extract the first element: (extractelt (slidedown vec, idx), 0). For integer
 // types this is done using VMV_X_S to allow us to glean information about the
@@ -7554,21 +7580,9 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
   if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx))
     MaxIdx = IdxC->getZExtValue();
   if (MaxIdx) {
-    const unsigned EltSize = ContainerVT.getScalarSizeInBits();
-    const unsigned VectorBitsMin = Subtarget.getRealMinVLen();
-    const unsigned MinVLMAX = VectorBitsMin/EltSize;
-    MVT SmallerVT;
-    if (*MaxIdx < MinVLMAX)
-      SmallerVT = getLMUL1VT(ContainerVT);
-    else if (*MaxIdx < MinVLMAX * 2)
-      SmallerVT = getLMUL1VT(ContainerVT)
-        .getDoubleNumVectorElementsVT();
-    else if (*MaxIdx < MinVLMAX * 4)
-      SmallerVT = getLMUL1VT(ContainerVT)
-        .getDoubleNumVectorElementsVT()
-        .getDoubleNumVectorElementsVT();
-    if (SmallerVT.isValid() && ContainerVT.bitsGT(SmallerVT)) {
-      ContainerVT = SmallerVT;
+    if (auto SmallerVT =
+            getSmallestVTForIndex(ContainerVT, *MaxIdx, DL, DAG, Subtarget)) {
+      ContainerVT = *SmallerVT;
       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
                         DAG.getConstant(0, DL, XLenVT));
     }
@@ -8752,37 +8766,14 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
     }
 
-    // The minimum number of elements for a scalable vector type, e.g. nxv1i32
-    // is not legal on Zve32x.
-    const uint64_t MinLegalNumElts =
-        RISCV::RVVBitsPerBlock / Subtarget.getELen();
-    const uint64_t MinVscale =
-        Subtarget.getRealMinVLen() / RISCV::RVVBitsPerBlock;
-
-    // Even if we don't know the exact subregister the subvector is going to
-    // reside in, we know that the subvector is located within the first N bits
-    // of Vec:
-    //
-    // N = (OrigIdx + SubVecVT.getVectorNumElements()) * EltSizeInBits
-    //   = MinVscale * MinEltsNeeded * EltSizeInBits
-    //
-    // From this we can work out the smallest type that contains everything we
-    // need to extract, <vscale x MinEltsNeeded x Elt>
-    uint64_t MinEltsNeeded =
-        (OrigIdx + SubVecVT.getVectorNumElements()) / MinVscale;
-
-    // Round up the number of elements so it's a valid power of 2 scalable
-    // vector type, and make sure it's not less than smallest legal vector type.
-    MinEltsNeeded = std::max(MinLegalNumElts, PowerOf2Ceil(MinEltsNeeded));
-
-    assert(MinEltsNeeded <= ContainerVT.getVectorMinNumElements());
-
-    // Shrink down Vec so we're performing the slidedown on the smallest
-    // possible type.
-    ContainerVT = MVT::getScalableVectorVT(ContainerVT.getVectorElementType(),
-                                           MinEltsNeeded);
-    Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
-                      DAG.getVectorIdxConstant(0, DL));
+    // Shrink down Vec so we're performing the slidedown a smaller LMUL.
+    unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1;
+    if (auto ShrunkVT =
+            getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) {
+      ContainerVT = *ShrunkVT;
+      Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
+                        DAG.getVectorIdxConstant(0, DL));
+    }
 
     SDValue Mask =
         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract-subvector.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract-subvector.ll
index fa23b9a1b76fd3a..b4260b04604cd06 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract-subvector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract-subvector.ll
@@ -113,7 +113,7 @@ define void @extract_v2i32_v8i32_2(ptr %x, ptr %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
 ; CHECK-NEXT:    vle32.v v8, (a0)
-; CHECK-NEXT:    vsetivli zero, 2, e32, m2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m1, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 2
 ; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a1)
@@ -171,7 +171,7 @@ define void @extract_v2i32_nxv16i32_0(<vscale x 16 x i32> %x, ptr %y) {
 define void @extract_v2i32_nxv16i32_2(<vscale x 16 x i32> %x, ptr %y) {
 ; CHECK-LABEL: extract_v2i32_nxv16i32_2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m1, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 2
 ; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0)
@@ -184,7 +184,7 @@ define void @extract_v2i32_nxv16i32_2(<vscale x 16 x i32> %x, ptr %y) {
 define void @extract_v2i32_nxv16i32_4(<vscale x 16 x i32> %x, ptr %y) {
 ; CHECK-LABEL: extract_v2i32_nxv16i32_4:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m2, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 4
 ; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0)
@@ -197,7 +197,7 @@ define void @extract_v2i32_nxv16i32_4(<vscale x 16 x i32> %x, ptr %y) {
 define void @extract_v2i32_nxv16i32_6(<vscale x 16 x i32> %x, ptr %y) {
 ; CHECK-LABEL: extract_v2i32_nxv16i32_6:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m2, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 6
 ; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0)
@@ -210,7 +210,7 @@ define void @extract_v2i32_nxv16i32_6(<vscale x 16 x i32> %x, ptr %y) {
 define void @extract_v2i32_nxv16i32_8(<vscale x 16 x i32> %x, ptr %y) {
 ; CHECK-LABEL: extract_v2i32_nxv16i32_8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 8
 ; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0)
@@ -273,7 +273,7 @@ define void @extract_v2i8_nxv2i8_6(<vscale x 2 x i8> %x, ptr %y) {
 define void @extract_v8i32_nxv16i32_8(<vscale x 16 x i32> %x, ptr %y) {
 ; CHECK-LABEL: extract_v8i32_nxv16i32_8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e32, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 8, e32, m4, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 8
 ; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
 ; CHECK-NEXT:    vse32.v v8, (a0)
@@ -437,7 +437,7 @@ define void @extract_v2i1_v64i1_2(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vlm.v v0, (a0)
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
-; CHECK-NEXT:    vsetivli zero, 2, e8, m4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, m1, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 2
 ; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
@@ -555,7 +555,7 @@ define void @extract_v2i1_nxv64i1_2(<vscale x 64 x i1> %x, ptr %y) {
 ; CHECK-NEXT:    vsetvli a1, zero, e8, m8, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
-; CHECK-NEXT:    vsetivli zero, 2, e8, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, m1, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 2
 ; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
@@ -581,7 +581,7 @@ define void @extract_v2i1_nxv64i1_42(<vscale x 64 x i1> %x, ptr %y) {
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
 ; CHECK-NEXT:    li a1, 42
-; CHECK-NEXT:    vsetivli zero, 2, e8, m8, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, m4, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v8, v8, a1
 ; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0
@@ -606,7 +606,7 @@ define void @extract_v2i1_nxv32i1_26(<vscale x 32 x i1> %x, ptr %y) {
 ; CHECK-NEXT:    vsetvli a1, zero, e8, m4, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vmerge.vim v8, v8, 1, v0
-; CHECK-NEXT:    vsetivli zero, 2, e8, m4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, m2, ta, ma
 ; CHECK-NEXT:    vslidedown.vi v8, v8, 26
 ; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vmsne.vi v0, v8, 0

From 07e5fed3d95bef11992867ba46294cc237ec35ad Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 8 Sep 2023 17:50:26 +0100
Subject: [PATCH 3/3] Fix typo in comment

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5e7813b612ce32f..cc45f175f492d3d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8766,7 +8766,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
       Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
     }
 
-    // Shrink down Vec so we're performing the slidedown a smaller LMUL.
+    // Shrink down Vec so we're performing the slidedown on a smaller LMUL.
     unsigned LastIdx = OrigIdx + SubVecVT.getVectorNumElements() - 1;
     if (auto ShrunkVT =
             getSmallestVTForIndex(ContainerVT, LastIdx, DL, DAG, Subtarget)) {


