[llvm] e340e9e - [RISCV][NFC] Reuse getDeinterleaveViaVNSRL to lower deinterleave intrinsics

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 23 08:23:18 PST 2023


Author: Luke Lau
Date: 2023-02-23T16:23:05Z
New Revision: e340e9e632124f541d8225ccf0a5c55de402fb3c

URL: https://github.com/llvm/llvm-project/commit/e340e9e632124f541d8225ccf0a5c55de402fb3c
DIFF: https://github.com/llvm/llvm-project/commit/e340e9e632124f541d8225ccf0a5c55de402fb3c.diff

LOG: [RISCV][NFC] Reuse getDeinterleaveViaVNSRL to lower deinterleave intrinsics

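This modifies it to work on both scalable and fixed-length vectors, so that lowerVECTOR_DEINTERLEAVE can call it instead of duplicating the vnsrl-based lowering.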

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D144584

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e97ffc97df97..cb081b7fce3b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3113,27 +3113,36 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
 }
 
 // Lower a deinterleave shuffle to vnsrl.
-static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT,
-                                       MVT ContainerVT,
-                                       SDValue Src, bool EvenElts,
-                                       SDValue TrueMask, SDValue VL,
+// [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true)
+//                          -> [p, q, r, s] (EvenElts == false)
+// VT is the type of the vector to return, <[vscale x ]n x ty>
+// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
+static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
+                                       bool EvenElts,
                                        const RISCVSubtarget &Subtarget,
                                        SelectionDAG &DAG) {
-  // Convert the source using a container type with twice the elements. Since
-  // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
-  // safe to double.
-  MVT DoubleContainerVT =
-      MVT::getVectorVT(ContainerVT.getVectorElementType(),
-                       ContainerVT.getVectorElementCount() * 2);
-  Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);
-
-  // Convert the vector to a wider integer type with the original element
-  // count. This also converts FP to int.
+  // The result is a vector of type <m x n x ty>
+  MVT ContainerVT = VT;
+  // Convert fixed vectors to scalable if needed
+  if (ContainerVT.isFixedLengthVector()) {
+    assert(Src.getSimpleValueType().isFixedLengthVector());
+    ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
+
+    // The source is a vector of type <m x n*2 x ty>
+    MVT SrcContainerVT =
+        MVT::getVectorVT(ContainerVT.getVectorElementType(),
+                         ContainerVT.getVectorElementCount() * 2);
+    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
+  }
+
+  auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  // Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
+  // This also converts FP to int.
   unsigned EltBits = ContainerVT.getScalarSizeInBits();
-  MVT WideIntContainerVT =
-    MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
-                     ContainerVT.getVectorElementCount());
-  Src = DAG.getBitcast(WideIntContainerVT, Src);
+  MVT WideSrcContainerVT = MVT::getVectorVT(
+      MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
+  Src = DAG.getBitcast(WideSrcContainerVT, Src);
 
   // The integer version of the container type.
   MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
@@ -3150,7 +3159,9 @@ static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT,
   // Cast back to FP if needed.
   Res = DAG.getBitcast(ContainerVT, Res);
 
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  if (VT.isFixedLengthVector())
+    Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return Res;
 }
 
 static SDValue
@@ -3461,9 +3472,12 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
     return convertFromScalableVector(VT, Res, DAG, Subtarget);
   }
 
-  if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
-    return getDeinterleaveViaVNSRL(DL, VT, ContainerVT, V1.getOperand(0),
-                                   Mask[0] == 0, TrueMask, VL, Subtarget, DAG);
+  // If this is a deinterleave and we can widen the vector, then we can use
+  // vnsrl to deinterleave.
+  if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget)) {
+    return getDeinterleaveViaVNSRL(DL, VT, V1.getOperand(0), Mask[0] == 0,
+                                   Subtarget, DAG);
+  }
 
   // Detect an interleave shuffle and lower to
   // (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
@@ -6619,33 +6633,14 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
   auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
   SDValue Passthru = DAG.getUNDEF(ConcatVT);
 
-  // If the element type is smaller than ELEN, then we can deinterleave
-  // through vnsrl.wi
+  // We can deinterleave through vnsrl.wi if the element type is smaller than
+  // ELEN
   if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) {
-    // Bitcast the concatenated vector from <n x m x ty> -> <n x m / 2 x ty * 2>
-    // This is also casts FPs to ints
-    MVT WideVT = MVT::getVectorVT(
-        MVT::getIntegerVT(ConcatVT.getScalarSizeInBits() * 2),
-        ConcatVT.getVectorElementCount().divideCoefficientBy(2));
-    SDValue Wide = DAG.getBitcast(WideVT, Concat);
-
-    MVT NarrowVT = VecVT.changeVectorElementTypeToInteger();
-    SDValue Passthru = DAG.getUNDEF(VecVT);
-
-    SDValue Even = DAG.getNode(
-        RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
-        DAG.getSplatVector(NarrowVT, DL, DAG.getConstant(0, DL, XLenVT)),
-        Passthru, Mask, VL);
-    SDValue Odd = DAG.getNode(
-        RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
-        DAG.getSplatVector(
-            NarrowVT, DL,
-            DAG.getConstant(VecVT.getScalarSizeInBits(), DL, XLenVT)),
-        Passthru, Mask, VL);
-
-    // Bitcast the results back in case it was casted from an FP vector
-    return DAG.getMergeValues(
-        {DAG.getBitcast(VecVT, Even), DAG.getBitcast(VecVT, Odd)}, DL);
+    SDValue Even =
+        getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
+    SDValue Odd =
+        getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
+    return DAG.getMergeValues({Even, Odd}, DL);
   }
 
   // For the indices, use the same SEW to avoid an extra vsetvli


        


More information about the llvm-commits mailing list