[llvm] [RISCV] Work on subreg for insert_vector_elt when vlen is known (#72666) (PR #73680)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 28 10:19:10 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

If we have a constant index and a known vlen, then we can identify which register out of a register group is being accessed. Given this, we can reuse the (slightly generalized) existing handling for working on sub-register groups. This results in all constant index inserts with known vlen becoming m1 operations.

One bit of weirdness to highlight and explain: the existing code uses the VL from the original vector type, not the inner vector type. This is correct because the inner register group must be smaller than the original (possibly fixed length) vector type. Overall, this seems to be a reasonable codegen tradeoff as it biases us towards immediate AVLs, which avoids needing the vsetvli form which clobbers a GPR for no real purpose. The downside is that for large fixed length vectors, we end up materializing an immediate in a register for little value. We should probably generalize this idea and try to optimize the large fixed length vector case, but that can be done in separate work.

---
Full diff: https://github.com/llvm/llvm-project/pull/73680.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+36-12) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll (+18-22) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd5b1a879f32b9b..72b2e5e78c2991c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7739,17 +7739,41 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT OrigContainerVT = ContainerVT;
-  SDValue OrigVec = Vec;
   // If we know the index we're going to insert at, we can shrink Vec so that
   // we're performing the scalar inserts and slideup on a smaller LMUL.
-  if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
-    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, CIdx->getZExtValue(),
+  MVT OrigContainerVT = ContainerVT;
+  SDValue OrigVec = Vec;
+  SDValue AlignedIdx;
+  if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) {
+    const unsigned OrigIdx = IdxC->getZExtValue();
+    // Do we know an upper bound on LMUL?
+    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx,
                                               DL, DAG, Subtarget)) {
       ContainerVT = *ShrunkVT;
+      AlignedIdx = DAG.getVectorIdxConstant(0, DL);
+    }
+
+    // If we're compiling for an exact VLEN value, we can always perform
+    // the insert in m1 as we can determine the register corresponding to
+    // the index in the register group.
+    const unsigned MinVLen = Subtarget.getRealMinVLen();
+    const unsigned MaxVLen = Subtarget.getRealMaxVLen();
+    const MVT M1VT = getLMUL1VT(ContainerVT);
+    if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) {
+      EVT ElemVT = VecVT.getVectorElementType();
+      unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+      unsigned RemIdx = OrigIdx % ElemsPerVReg;
+      unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
+      unsigned ExtractIdx =
+          SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue();
+      AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL);
+      Idx = DAG.getVectorIdxConstant(RemIdx, DL);
+      ContainerVT = M1VT;
+    }
+
+    if (AlignedIdx)
       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
-                        DAG.getVectorIdxConstant(0, DL));
-    }
+                        AlignedIdx);
   }
 
   MVT XLenVT = Subtarget.getXLenVT();
@@ -7779,9 +7803,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
         Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val);
       Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL);
 
-      if (ContainerVT != OrigContainerVT)
+      if (AlignedIdx)
         Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Vec, DAG.getVectorIdxConstant(0, DL));
+                          Vec, AlignedIdx);
       if (!VecVT.isFixedLengthVector())
         return Vec;
       return convertFromScalableVector(VecVT, Vec, DAG, Subtarget);
@@ -7814,10 +7838,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
       // Bitcast back to the right container type.
       ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
 
-      if (ContainerVT != OrigContainerVT)
+      if (AlignedIdx)
         ValInVec =
             DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                        ValInVec, DAG.getVectorIdxConstant(0, DL));
+                        ValInVec, AlignedIdx);
       if (!VecVT.isFixedLengthVector())
         return ValInVec;
       return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget);
@@ -7849,9 +7873,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec,
                                 Idx, Mask, InsertVL, Policy);
 
-  if (ContainerVT != OrigContainerVT)
+  if (AlignedIdx)
     Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Slideup, DAG.getVectorIdxConstant(0, DL));
+                          Slideup, AlignedIdx);
   if (!VecVT.isFixedLengthVector())
     return Slideup;
   return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
index de5c4fbc0876439..a3f41fd842222cc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
@@ -614,9 +614,8 @@ define <16 x i32> @insertelt_c3_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_ra
 define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c12_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 13, e32, m4, tu, ma
-; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 12
+; CHECK-NEXT:    vsetivli zero, 16, e32, m1, tu, ma
+; CHECK-NEXT:    vmv.s.x v11, a0
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 12
   ret <16 x i32> %v
@@ -625,9 +624,9 @@ define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c13_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 14, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 13
+; CHECK-NEXT:    vslideup.vi v11, v12, 1
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 13
   ret <16 x i32> %v
@@ -636,9 +635,9 @@ define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c14_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 15, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 3, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 14
+; CHECK-NEXT:    vslideup.vi v11, v12, 2
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 14
   ret <16 x i32> %v
@@ -647,9 +646,9 @@ define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c15_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 15
+; CHECK-NEXT:    vslideup.vi v11, v12, 3
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 15
   ret <16 x i32> %v
@@ -658,18 +657,15 @@ define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c4_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
-; RV32-NEXT:    vslide1down.vx v12, v8, a0
-; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 4
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
+; RV32-NEXT:    vslide1down.vx v10, v10, a0
+; RV32-NEXT:    vslide1down.vx v10, v10, a1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c4_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 4
+; RV64-NEXT:    vsetivli zero, 8, e64, m1, tu, ma
+; RV64-NEXT:    vmv.s.x v10, a0
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 4
   ret <8 x i64> %v
@@ -678,18 +674,18 @@ define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range
 define <8 x i64> @insertelt_c5_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c5_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, ta, ma
 ; RV32-NEXT:    vslide1down.vx v12, v8, a0
 ; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 5
+; RV32-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
+; RV32-NEXT:    vslideup.vi v10, v12, 1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c5_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
+; RV64-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
 ; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 5
+; RV64-NEXT:    vslideup.vi v10, v12, 1
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 5
   ret <8 x i64> %v

``````````

</details>


https://github.com/llvm/llvm-project/pull/73680


More information about the llvm-commits mailing list