[llvm] [RISCV] Extract subreg directly in lowerINSERT_SUBVECTOR (PR #81838)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 15 01:26:39 PST 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/81838

In order to shrink the VSLIDEUP's LMUL, we were manually computing an index to
extract a smaller vector from, relying on it resolving to an EXTRACT_SUBREG.

However the subreg index to extract is already given to us by
decomposeSubvectorInsertExtractToSubRegs. This changes it to emit an
EXTRACT_SUBREG directly with said index.

Note that an EXTRACT_SUBVECTOR of undef is folded away, but an EXTRACT_SUBREG of
undef is not. So to prevent a ta->tu (tail agnostic -> tail undisturbed)
regression, this also includes a change to RISCVInsertVSETVLI to treat
(EXTRACT_SUBREG undef) as undefined. That change is NFC on its own, so I'd be
happy to split it out if wanted.

It's almost NFC, but in one test case this seems to eliminate a copy.


>From 1b8ab4353d24c773dac90cd9a99607dc6eb0ac69 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 15 Feb 2024 15:55:05 +0800
Subject: [PATCH] [RISCV] Extract subreg directly in lowerINSERT_SUBVECTOR

In order to shrink the VSLIDEUP's LMUL, we were manually computing an index to
extract a smaller vector from, relying on it resolving to an EXTRACT_SUBREG.

However the subreg index to extract is already given to us by
decomposeSubvectorInsertExtractToSubRegs. This changes it to emit an
EXTRACT_SUBREG directly with said index.

Note that an EXTRACT_SUBVECTOR of undef is folded away, but an EXTRACT_SUBREG of
undef is not. So to prevent a ta->tu (tail agnostic -> tail undisturbed)
regression, this also includes a change to RISCVInsertVSETVLI to treat
(EXTRACT_SUBREG undef) as undefined. That change is NFC on its own, so I'd be
happy to split it out if wanted.

It's almost NFC, but in one test case this seems to eliminate a copy.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 11 +++------
 llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp  | 24 +++++++++++--------
 .../CodeGen/RISCV/rvv/extract-subvector.ll    |  9 ++++---
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8235b536c4e00a..f72b9ad6e948f1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -9671,13 +9671,10 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
   // size of the subvector.
   MVT InterSubVT = VecVT;
   SDValue AlignedExtract = Vec;
-  unsigned AlignedIdx = OrigIdx - RemIdx;
   if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
     InterSubVT = getLMUL1VT(VecVT);
-    // Extract a subvector equal to the nearest full vector register type. This
-    // should resolve to a EXTRACT_SUBREG instruction.
-    AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
-                                 DAG.getConstant(AlignedIdx, DL, XLenVT));
+    // Extract a subvector equal to the nearest full vector register type.
+    AlignedExtract = DAG.getTargetExtractSubreg(SubRegIdx, DL, InterSubVT, Vec);
   }
 
   SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InterSubVT,
@@ -9705,10 +9702,8 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
   }
 
   // If required, insert this subvector back into the correct vector register.
-  // This should resolve to an INSERT_SUBREG instruction.
   if (VecVT.bitsGT(InterSubVT))
-    SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, Vec, SubVec,
-                         DAG.getConstant(AlignedIdx, DL, XLenVT));
+    SubVec = DAG.getTargetInsertSubreg(SubRegIdx, DL, VecVT, Vec, SubVec);
 
   // We might have bitcast from a mask type: cast back to the original type if
   // required.
diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
index a14f9a28354737..0e8e040be99baa 100644
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -190,18 +190,22 @@ static bool hasUndefinedMergeOp(const MachineInstr &MI,
   if (UseMO.getReg() == RISCV::NoRegister)
     return true;
 
-  if (MachineInstr *UseMI = MRI.getVRegDef(UseMO.getReg())) {
-    if (UseMI->isImplicitDef())
-      return true;
+  Register SrcReg =
+      MRI.getTargetRegisterInfo()->lookThruCopyLike(UseMO.getReg(), &MRI);
+  if (SrcReg.isPhysical())
+    return false;
 
-    if (UseMI->isRegSequence()) {
-      for (unsigned i = 1, e = UseMI->getNumOperands(); i < e; i += 2) {
-        MachineInstr *SourceMI = MRI.getVRegDef(UseMI->getOperand(i).getReg());
-        if (!SourceMI || !SourceMI->isImplicitDef())
-          return false;
-      }
-      return true;
+  MachineInstr *UseMI = MRI.getUniqueVRegDef(SrcReg);
+  if (UseMI->isImplicitDef())
+    return true;
+
+  if (UseMI->isRegSequence()) {
+    for (unsigned i = 1, e = UseMI->getNumOperands(); i < e; i += 2) {
+      MachineInstr *SourceMI = MRI.getVRegDef(UseMI->getOperand(i).getReg());
+      if (!SourceMI || !SourceMI->isImplicitDef())
+        return false;
     }
+    return true;
   }
   return false;
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
index a2d02b6bb641b2..77ea6c0b26d0a0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -469,14 +469,13 @@ define <vscale x 6 x half> @extract_nxv6f16_nxv12f16_6(<vscale x 12 x half> %in)
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    csrr a0, vlenb
 ; CHECK-NEXT:    srli a0, a0, 2
-; CHECK-NEXT:    vsetvli zero, a0, e16, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vx v13, v10, a0
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vslidedown.vx v12, v9, a0
+; CHECK-NEXT:    vslidedown.vx v8, v9, a0
 ; CHECK-NEXT:    add a1, a0, a0
 ; CHECK-NEXT:    vsetvli zero, a1, e16, m1, tu, ma
-; CHECK-NEXT:    vslideup.vx v12, v10, a0
-; CHECK-NEXT:    vmv2r.v v8, v12
+; CHECK-NEXT:    vslideup.vx v8, v10, a0
+; CHECK-NEXT:    vsetvli zero, a0, e16, m1, tu, ma
+; CHECK-NEXT:    vslidedown.vx v9, v10, a0
 ; CHECK-NEXT:    ret
   %res = call <vscale x 6 x half> @llvm.vector.extract.nxv6f16.nxv12f16(<vscale x 12 x half> %in, i64 6)
   ret <vscale x 6 x half> %res



More information about the llvm-commits mailing list