[llvm] 3e1317f - [RISCV] Support extraction of misaligned subvectors

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 20 07:51:00 PST 2021


Author: Fraser Cormack
Date: 2021-02-20T15:43:54Z
New Revision: 3e1317fd323bf92c6adaf67598697049b08bb373

URL: https://github.com/llvm/llvm-project/commit/3e1317fd323bf92c6adaf67598697049b08bb373
DIFF: https://github.com/llvm/llvm-project/commit/3e1317fd323bf92c6adaf67598697049b08bb373.diff

LOG: [RISCV] Support extraction of misaligned subvectors

This patch extends support for RVV EXTRACT_SUBVECTOR to cover extractions
which don't align to a vector register boundary. It accomplishes this by
extracting the nearest register-sized subvector (a subregister operation),
sliding that vector down with VSLIDEDOWN, and then extracting the desired
subvector from the first position (a COPY operation).

Since this procedure involves VSCALE and multiplication, these operations
are handled during custom lowering rather than instruction selection, which
simplifies the implementation and lets DAG combining optimize the result.
This necessitated moving some helper functions from RISCVISelDAGToDAG to
RISCVTargetLowering.
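
For example, one of the new tests added below (extract_nxv16i32_nxv1i32_1
in extract-subvector.ll) extracts <vscale x 1 x i32> at index 1 from
<vscale x 16 x i32> and now compiles to:

  csrr a0, vlenb
  srli a0, a0, 3
  vsetvli a1, zero, e32,m1,ta,mu
  vslidedown.vx v8, v8, a0
  ret

Here a0 holds the slide amount: the remaining element index multiplied by
vscale (vlenb/8), with the multiply-by-one folded away by DAG combining.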

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96959

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index adaf84959770..9b797bd509b0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -13,6 +13,7 @@
 #include "RISCVISelDAGToDAG.h"
 #include "MCTargetDesc/RISCVMCTargetDesc.h"
 #include "MCTargetDesc/RISCVMatInt.h"
+#include "RISCVISelLowering.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/Support/Alignment.h"
@@ -62,64 +63,6 @@ static SDNode *selectImm(SelectionDAG *CurDAG, const SDLoc &DL, int64_t Imm,
   return Result;
 }
 
-static RISCVVLMUL getLMUL(MVT VT) {
-  switch (VT.getSizeInBits().getKnownMinValue() / 8) {
-  default:
-    llvm_unreachable("Invalid LMUL.");
-  case 1:
-    return RISCVVLMUL::LMUL_F8;
-  case 2:
-    return RISCVVLMUL::LMUL_F4;
-  case 4:
-    return RISCVVLMUL::LMUL_F2;
-  case 8:
-    return RISCVVLMUL::LMUL_1;
-  case 16:
-    return RISCVVLMUL::LMUL_2;
-  case 32:
-    return RISCVVLMUL::LMUL_4;
-  case 64:
-    return RISCVVLMUL::LMUL_8;
-  }
-}
-
-static unsigned getRegClassIDForLMUL(RISCVVLMUL LMul) {
-  switch (LMul) {
-  default:
-    llvm_unreachable("Invalid LMUL.");
-  case RISCVVLMUL::LMUL_F8:
-  case RISCVVLMUL::LMUL_F4:
-  case RISCVVLMUL::LMUL_F2:
-  case RISCVVLMUL::LMUL_1:
-    return RISCV::VRRegClassID;
-  case RISCVVLMUL::LMUL_2:
-    return RISCV::VRM2RegClassID;
-  case RISCVVLMUL::LMUL_4:
-    return RISCV::VRM4RegClassID;
-  case RISCVVLMUL::LMUL_8:
-    return RISCV::VRM8RegClassID;
-  }
-}
-
-static unsigned getSubregIndexByMVT(MVT VT, unsigned Index) {
-  RISCVVLMUL LMUL = getLMUL(VT);
-  if (LMUL == RISCVVLMUL::LMUL_F8 || LMUL == RISCVVLMUL::LMUL_F4 ||
-      LMUL == RISCVVLMUL::LMUL_F2 || LMUL == RISCVVLMUL::LMUL_1) {
-    static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm1_0 + Index;
-  } else if (LMUL == RISCVVLMUL::LMUL_2) {
-    static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm2_0 + Index;
-  } else if (LMUL == RISCVVLMUL::LMUL_4) {
-    static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm4_0 + Index;
-  }
-  llvm_unreachable("Invalid vector type.");
-}
-
 static SDValue createTupleImpl(SelectionDAG &CurDAG, ArrayRef<SDValue> Regs,
                                unsigned RegClassID, unsigned SubReg0) {
   assert(Regs.size() >= 2 && Regs.size() <= 8);
@@ -187,7 +130,7 @@ void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, bool IsMasked,
   MVT VT = Node->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   unsigned CurOp = 2;
   SmallVector<SDValue, 7> Operands;
@@ -218,10 +161,11 @@ void RISCVDAGToDAGISel::selectVLSEG(SDNode *Node, bool IsMasked,
     CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
   CurDAG->RemoveDeadNode(Node);
@@ -233,7 +177,7 @@ void RISCVDAGToDAGISel::selectVLSEGFF(SDNode *Node, bool IsMasked) {
   MVT VT = Node->getSimpleValueType(0);
   MVT XLenVT = Subtarget->getXLenVT();
   unsigned ScalarSize = VT.getScalarSizeInBits();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
 
   unsigned CurOp = 2;
@@ -265,10 +209,11 @@ void RISCVDAGToDAGISel::selectVLSEGFF(SDNode *Node, bool IsMasked) {
     CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(ReadVL, 0));   // VL
   ReplaceUses(SDValue(Node, NF + 1), SDValue(Load, 1)); // Chain
@@ -282,7 +227,7 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
   MVT VT = Node->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   unsigned CurOp = 2;
   SmallVector<SDValue, 7> Operands;
@@ -307,7 +252,7 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
          "Element count mismatch");
 
-  RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+  RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
   unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
   const RISCV::VLXSEGPseudo *P = RISCV::getVLXSEGPseudo(
       NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -319,10 +264,11 @@ void RISCVDAGToDAGISel::selectVLXSEG(SDNode *Node, bool IsMasked,
     CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
   CurDAG->RemoveDeadNode(Node);
@@ -339,7 +285,7 @@ void RISCVDAGToDAGISel::selectVSSEG(SDNode *Node, bool IsMasked,
   MVT VT = Node->getOperand(2)->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
   SDValue StoreVal = createTuple(*CurDAG, Regs, NF, LMUL);
@@ -376,7 +322,7 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
   MVT VT = Node->getOperand(2)->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   SmallVector<SDValue, 7> Operands;
   SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
@@ -397,7 +343,7 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
          "Element count mismatch");
 
-  RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+  RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
   unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
   const RISCV::VSXSEGPseudo *P = RISCV::getVSXSEGPseudo(
       NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -411,47 +357,6 @@ void RISCVDAGToDAGISel::selectVSXSEG(SDNode *Node, bool IsMasked,
   ReplaceNode(Node, Store);
 }
 
-static unsigned getRegClassIDForVecVT(MVT VT) {
-  if (VT.getVectorElementType() == MVT::i1)
-    return RISCV::VRRegClassID;
-  return getRegClassIDForLMUL(getLMUL(VT));
-}
-
-// Attempt to decompose a subvector insert/extract between VecVT and
-// SubVecVT via subregister indices. Returns the subregister index that
-// can perform the subvector insert/extract with the given element index, as
-// well as the index corresponding to any leftover subvectors that must be
-// further inserted/extracted within the register class for SubVecVT.
-static std::pair<unsigned, unsigned>
-decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
-                                         unsigned InsertExtractIdx,
-                                         const RISCVRegisterInfo *TRI) {
-  static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
-                 RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
-                 RISCV::VRM2RegClassID > RISCV::VRRegClassID),
-                "Register classes not ordered");
-  unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
-  unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
-  // Try to compose a subregister index that takes us from the incoming
-  // LMUL>1 register class down to the outgoing one. At each step we half
-  // the LMUL:
-  //   nxv16i32 at 12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
-  // Note that this is not guaranteed to find a subregister index, such as
-  // when we are extracting from one VR type to another.
-  unsigned SubRegIdx = RISCV::NoSubRegister;
-  for (const unsigned RCID :
-       {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
-    if (VecRegClassID > RCID && SubRegClassID <= RCID) {
-      VecVT = VecVT.getHalfNumVectorElementsVT();
-      bool IsHi =
-          InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
-      SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
-                                            getSubregIndexByMVT(VecVT, IsHi));
-      if (IsHi)
-        InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
-    }
-  return {SubRegIdx, InsertExtractIdx};
-}
 
 void RISCVDAGToDAGISel::Select(SDNode *Node) {
   // If we have a custom node, we have already selected.
@@ -726,8 +631,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
              "Element count mismatch");
 
-      RISCVVLMUL LMUL = getLMUL(VT);
-      RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+      RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
+      RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
       unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
       const RISCV::VLX_VSXPseudo *P = RISCV::getVLXPseudo(
           IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -855,8 +760,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
              "Element count mismatch");
 
-      RISCVVLMUL LMUL = getLMUL(VT);
-      RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+      RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
+      RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
       unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
       const RISCV::VLX_VSXPseudo *P = RISCV::getVSXPseudo(
           IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -895,7 +800,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     // For now, keep the two paths separate.
     if (VT.isScalableVector() && SubVecVT.isScalableVector()) {
       bool IsFullVecReg = false;
-      switch (getLMUL(SubVecVT)) {
+      switch (RISCVTargetLowering::getLMUL(SubVecVT)) {
       default:
         break;
       case RISCVVLMUL::LMUL_1:
@@ -915,10 +820,11 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       const auto *TRI = Subtarget->getRegisterInfo();
       unsigned SubRegIdx;
       std::tie(SubRegIdx, Idx) =
-          decomposeSubvectorInsertExtractToSubRegs(VT, SubVecVT, Idx, TRI);
+          RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+              VT, SubVecVT, Idx, TRI);
 
       // If the Idx hasn't been completely eliminated then this is a subvector
-      // extract which doesn't naturally align to a vector register. These must
+      // insert which doesn't naturally align to a vector register. These must
       // be handled using instructions to manipulate the vector registers.
       if (Idx != 0)
         break;
@@ -936,7 +842,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       if (!Node->getOperand(0).isUndef())
         break;
 
-      unsigned RegClassID = getRegClassIDForVecVT(VT);
+      unsigned RegClassID = RISCVTargetLowering::getRegClassIDForVecVT(VT);
 
       SDValue RC =
           CurDAG->getTargetConstant(RegClassID, DL, Subtarget->getXLenVT());
@@ -961,7 +867,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       const auto *TRI = Subtarget->getRegisterInfo();
       unsigned SubRegIdx;
       std::tie(SubRegIdx, Idx) =
-          decomposeSubvectorInsertExtractToSubRegs(InVT, VT, Idx, TRI);
+          RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+              InVT, VT, Idx, TRI);
 
       // If the Idx hasn't been completely eliminated then this is a subvector
       // extract which doesn't naturally align to a vector register. These must
@@ -972,8 +879,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       // If we haven't set a SubRegIdx, then we must be going between LMUL<=1
       // types (VR -> VR). This can be done as a copy.
       if (SubRegIdx == RISCV::NoSubRegister) {
-        unsigned InRegClassID = getRegClassIDForVecVT(InVT);
-        assert(getRegClassIDForVecVT(VT) == RISCV::VRRegClassID &&
+        unsigned InRegClassID =
+            RISCVTargetLowering::getRegClassIDForVecVT(InVT);
+        assert(RISCVTargetLowering::getRegClassIDForVecVT(VT) ==
+                   RISCV::VRRegClassID &&
                InRegClassID == RISCV::VRRegClassID &&
                "Unexpected subvector extraction");
         SDValue RC =
@@ -993,7 +902,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       if (Idx != 0)
         break;
 
-      unsigned InRegClassID = getRegClassIDForVecVT(InVT);
+      unsigned InRegClassID = RISCVTargetLowering::getRegClassIDForVecVT(InVT);
 
       SDValue RC =
           CurDAG->getTargetConstant(InRegClassID, DL, Subtarget->getXLenVT());

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8c07f4b1a28a..761f15f31ecb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -464,6 +464,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+
+      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
     }
 
     // Expand various CCs to best match the RVV ISA, which natively supports UNE
@@ -498,6 +500,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
+
+      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
     };
 
     if (Subtarget.hasStdExtZfh())
@@ -800,6 +804,108 @@ static unsigned getBranchOpcodeForIntCondCode(ISD::CondCode CC) {
   }
 }
 
+RISCVVLMUL RISCVTargetLowering::getLMUL(MVT VT) {
+  switch (VT.getSizeInBits().getKnownMinValue() / 8) {
+  default:
+    llvm_unreachable("Invalid LMUL.");
+  case 1:
+    return RISCVVLMUL::LMUL_F8;
+  case 2:
+    return RISCVVLMUL::LMUL_F4;
+  case 4:
+    return RISCVVLMUL::LMUL_F2;
+  case 8:
+    return RISCVVLMUL::LMUL_1;
+  case 16:
+    return RISCVVLMUL::LMUL_2;
+  case 32:
+    return RISCVVLMUL::LMUL_4;
+  case 64:
+    return RISCVVLMUL::LMUL_8;
+  }
+}
+
+unsigned RISCVTargetLowering::getRegClassIDForLMUL(RISCVVLMUL LMul) {
+  switch (LMul) {
+  default:
+    llvm_unreachable("Invalid LMUL.");
+  case RISCVVLMUL::LMUL_F8:
+  case RISCVVLMUL::LMUL_F4:
+  case RISCVVLMUL::LMUL_F2:
+  case RISCVVLMUL::LMUL_1:
+    return RISCV::VRRegClassID;
+  case RISCVVLMUL::LMUL_2:
+    return RISCV::VRM2RegClassID;
+  case RISCVVLMUL::LMUL_4:
+    return RISCV::VRM4RegClassID;
+  case RISCVVLMUL::LMUL_8:
+    return RISCV::VRM8RegClassID;
+  }
+}
+
+unsigned RISCVTargetLowering::getSubregIndexByMVT(MVT VT, unsigned Index) {
+  RISCVVLMUL LMUL = getLMUL(VT);
+  if (LMUL == RISCVVLMUL::LMUL_F8 || LMUL == RISCVVLMUL::LMUL_F4 ||
+      LMUL == RISCVVLMUL::LMUL_F2 || LMUL == RISCVVLMUL::LMUL_1) {
+    static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm1_0 + Index;
+  }
+  if (LMUL == RISCVVLMUL::LMUL_2) {
+    static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm2_0 + Index;
+  }
+  if (LMUL == RISCVVLMUL::LMUL_4) {
+    static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm4_0 + Index;
+  }
+  llvm_unreachable("Invalid vector type.");
+}
+
+unsigned RISCVTargetLowering::getRegClassIDForVecVT(MVT VT) {
+  if (VT.getVectorElementType() == MVT::i1)
+    return RISCV::VRRegClassID;
+  return getRegClassIDForLMUL(getLMUL(VT));
+}
+
+// Attempt to decompose a subvector insert/extract between VecVT and
+// SubVecVT via subregister indices. Returns the subregister index that
+// can perform the subvector insert/extract with the given element index, as
+// well as the index corresponding to any leftover subvectors that must be
+// further inserted/extracted within the register class for SubVecVT.
+std::pair<unsigned, unsigned>
+RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+    MVT VecVT, MVT SubVecVT, unsigned InsertExtractIdx,
+    const RISCVRegisterInfo *TRI) {
+  static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
+                 RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
+                 RISCV::VRM2RegClassID > RISCV::VRRegClassID),
+                "Register classes not ordered");
+  unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
+  unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
+  // Try to compose a subregister index that takes us from the incoming
+  // LMUL>1 register class down to the outgoing one. At each step we half
+  // the LMUL:
+  //   nxv16i32 at 12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
+  // Note that this is not guaranteed to find a subregister index, such as
+  // when we are extracting from one VR type to another.
+  unsigned SubRegIdx = RISCV::NoSubRegister;
+  for (const unsigned RCID :
+       {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
+    if (VecRegClassID > RCID && SubRegClassID <= RCID) {
+      VecVT = VecVT.getHalfNumVectorElementsVT();
+      bool IsHi =
+          InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
+      SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
+                                            getSubregIndexByMVT(VecVT, IsHi));
+      if (IsHi)
+        InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
+    }
+  return {SubRegIdx, InsertExtractIdx};
+}
+
 // Return the largest legal scalable vector type that matches VT's element type.
 static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
                                             const RISCVSubtarget &Subtarget) {
@@ -1206,6 +1312,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::VECREDUCE_FADD:
   case ISD::VECREDUCE_SEQ_FADD:
     return lowerFPVECREDUCE(Op, DAG);
+  case ISD::EXTRACT_SUBVECTOR:
+    return lowerEXTRACT_SUBVECTOR(Op, DAG);
   case ISD::BUILD_VECTOR:
     return lowerBUILD_VECTOR(Op, DAG, Subtarget);
   case ISD::VECTOR_SHUFFLE:
@@ -2129,6 +2237,70 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
                      DAG.getConstant(0, DL, Subtarget.getXLenVT()));
 }
 
+static MVT getLMUL1VT(MVT VT) {
+  assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
+         "Unexpected vector MVT");
+  return MVT::getScalableVectorVT(
+      VT.getVectorElementType(),
+      RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
+}
+
+SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDValue Vec = Op.getOperand(0);
+  MVT SubVecVT = Op.getSimpleValueType();
+  MVT VecVT = Vec.getSimpleValueType();
+
+  // TODO: Only handle scalable->scalable extracts for now, and revisit this
+  // for fixed-length vectors later.
+  if (!SubVecVT.isScalableVector() || !VecVT.isScalableVector())
+    return Op;
+
+  SDLoc DL(Op);
+  unsigned OrigIdx = Op.getConstantOperandVal(1);
+  const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
+
+  unsigned SubRegIdx, RemIdx;
+  std::tie(SubRegIdx, RemIdx) =
+      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+          VecVT, SubVecVT, OrigIdx, TRI);
+
+  // If the Idx has been completely eliminated then this is a subvector extract
+  // which naturally aligns to a vector register. These can easily be handled
+  // using subregister manipulation.
+  if (RemIdx == 0)
+    return Op;
+
+  // Else we must shift our vector register directly to extract the subvector.
+  // Do this using VSLIDEDOWN.
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  // Extract a subvector equal to the nearest full vector register type. This
+  // should resolve to a EXTRACT_SUBREG instruction.
+  unsigned AlignedIdx = OrigIdx - RemIdx;
+  MVT InterSubVT = getLMUL1VT(VecVT);
+  SDValue AlignedExtract =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
+                  DAG.getConstant(AlignedIdx, DL, XLenVT));
+
+  // Slide this vector register down by the desired number of elements in order
+  // to place the desired subvector starting at element 0.
+  SDValue SlidedownAmt = DAG.getConstant(RemIdx, DL, XLenVT);
+  // For scalable vectors this must be further multiplied by vscale.
+  SlidedownAmt = DAG.getNode(ISD::VSCALE, DL, XLenVT, SlidedownAmt);
+
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
+  SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, InterSubVT,
+                                  DAG.getUNDEF(InterSubVT), AlignedExtract,
+                                  SlidedownAmt, Mask, VL);
+
+  // Now the vector is in the right position, extract our final subvector. This
+  // should resolve to a COPY.
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
+                     DAG.getConstant(0, DL, XLenVT));
+}
+
 SDValue
 RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
                                                      SelectionDAG &DAG) const {

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 65c25e3e9719..f612d81136ec 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -20,6 +20,7 @@
 
 namespace llvm {
 class RISCVSubtarget;
+struct RISCVRegisterInfo;
 namespace RISCVISD {
 enum NodeType : unsigned {
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
@@ -374,6 +375,15 @@ class RISCVTargetLowering : public TargetLowering {
       MachineMemOperand::Flags Flags = MachineMemOperand::MONone,
       bool *Fast = nullptr) const override;
 
+  static RISCVVLMUL getLMUL(MVT VT);
+  static unsigned getRegClassIDForLMUL(RISCVVLMUL LMul);
+  static unsigned getSubregIndexByMVT(MVT VT, unsigned Index);
+  static unsigned getRegClassIDForVecVT(MVT VT);
+  static std::pair<unsigned, unsigned>
+  decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
+                                           unsigned InsertExtractIdx,
+                                           const RISCVRegisterInfo *TRI);
+
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
                         const SmallVectorImpl<ISD::InputArg> &Ins,
@@ -410,6 +420,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFPVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;

diff --git a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
index c14abab5440f..a28fc4c849bf 100644
--- a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple riscv64 -mattr=+experimental-v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple riscv64 -mattr=+m,+d,+experimental-zfh,+experimental-v -verify-machineinstrs < %s | FileCheck %s
 
 define <vscale x 4 x i32> @extract_nxv8i32_nxv4i32_0(<vscale x 8 x i32> %vec) {
 ; CHECK-LABEL: extract_nxv8i32_nxv4i32_0:
@@ -190,13 +190,41 @@ define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_0(<vscale x 16 x i32> %vec)
   ret <vscale x 1 x i32> %c
 }
 
-; TODO: Extracts that don't align to a vector register are not yet supported.
-; In this case we want to extract the upper half of the lowest VR subregister
-; in the LMUL group.
-; define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_1(<vscale x 16 x i32> %vec) {
-;   %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 1)
-;   ret <vscale x 1 x i32> %c
-; }
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_1(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 1)
+  ret <vscale x 1 x i32> %c
+}
+
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_3(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v9, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 3)
+  ret <vscale x 1 x i32> %c
+}
+
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_15(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_15:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v15, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 15)
+  ret <vscale x 1 x i32> %c
+}
 
 define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_2(<vscale x 16 x i32> %vec) {
 ; CHECK-LABEL: extract_nxv16i32_nxv1i32_2:
@@ -215,6 +243,124 @@ define <vscale x 1 x i32> @extract_nxv2i32_nxv1i32_0(<vscale x 2 x i32> %vec) {
   ret <vscale x 1 x i32> %c
 }
 
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_0(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_0:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $v8 killed $v8 killed $v8m4
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 0)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_2(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 2
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 2)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_4(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 4)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_6(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    addi a1, zero, 6
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 6)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_8(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 8)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_22(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_22:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    addi a1, zero, 6
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v10, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 22)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 1 x i8> @extract_nxv8i8_nxv1i8_7(<vscale x 8 x i8> %vec) {
+; CHECK-LABEL: extract_nxv8i8_nxv1i8_7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    slli a1, a0, 3
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i8> @llvm.experimental.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8> %vec, i64 7)
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_0(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_0:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $v8 killed $v8 killed $v8m4
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 0)
+  ret <vscale x 2 x half> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_2(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 2
+; CHECK-NEXT:    vsetvli a1, zero, e16,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 2)
+  ret <vscale x 2 x half> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_4(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 4)
+  ret <vscale x 2 x half> %c
+}
+
+declare <vscale x 1 x i8> @llvm.experimental.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8> %vec, i64 %idx)
+
+declare <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 %idx)
+
 declare <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv2i32(<vscale x 2 x i32> %vec, i64 %idx)
 
 declare <vscale x 2 x i32> @llvm.experimental.vector.extract.nxv2i32.nxv8i32(<vscale x 8 x i32> %vec, i64 %idx)
@@ -224,3 +370,5 @@ declare <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<v
 declare <vscale x 2 x i32> @llvm.experimental.vector.extract.nxv2i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
 declare <vscale x 4 x i32> @llvm.experimental.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
 declare <vscale x 8 x i32> @llvm.experimental.vector.extract.nxv8i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
+
+declare <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 %idx)

