[llvm] cd79599 - [RISCV] Teach lowerScalarInsert to handle a scalar value that is the first element of a fixed vector.

Yeting Kuo via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 2 19:53:20 PDT 2023


Author: Yeting Kuo
Date: 2023-08-03T10:53:14+08:00
New Revision: cd7959930443f09403df7a9303634b50163a3d80

URL: https://github.com/llvm/llvm-project/commit/cd7959930443f09403df7a9303634b50163a3d80
DIFF: https://github.com/llvm/llvm-project/commit/cd7959930443f09403df7a9303634b50163a3d80.diff

LOG: [RISCV] Teach lowerScalarInsert to handle a scalar value that is the first element of a fixed vector.

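D155929 taught lowerScalarInsert to handle a start value of the form
(extractelement scalable_vector, 0), and specifically converted fixed extracted
vectors to scalable vectors when lowering vector reductions. That is not enough,
because there is another way to create (extractelement fixed_vector, 0) as a
start value of lowerScalarInsert, as in #64327. This patch instead converts a
fixed-length source vector to its scalable container type inside
lowerScalarInsert itself, so the getFirstElement helper is no longer needed.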

#64327: https://github.com/llvm/llvm-project/issues/64327.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D156863

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ddd06d129a08d2..220396b84d9619 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3466,11 +3466,19 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
 
   if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       isNullConstant(Scalar.getOperand(1))) {
-    MVT ExtractedVT = Scalar.getOperand(0).getSimpleValueType();
-    if (ExtractedVT.bitsLE(VT))
-      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru,
-                         Scalar.getOperand(0), DAG.getConstant(0, DL, XLenVT));
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Scalar.getOperand(0),
+    SDValue ExtractedVal = Scalar.getOperand(0);
+    MVT ExtractedVT = ExtractedVal.getSimpleValueType();
+    MVT ExtractedContainerVT = ExtractedVT;
+    if (ExtractedContainerVT.isFixedLengthVector()) {
+      ExtractedContainerVT = getContainerForFixedLengthVector(
+          DAG, ExtractedContainerVT, Subtarget);
+      ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal,
+                                             DAG, Subtarget);
+    }
+    if (ExtractedContainerVT.bitsLE(VT))
+      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal,
+                         DAG.getConstant(0, DL, XLenVT));
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
                        DAG.getConstant(0, DL, XLenVT));
   }
 
@@ -7722,25 +7730,6 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
                      DAG.getConstant(0, DL, XLenVT));
 }
 
-// Function to extract the first element of Vec. For fixed vector Vec, this
-// converts it to a scalable vector before extraction, so subsequent
-// optimizations don't have to handle fixed vectors.
-static SDValue getFirstElement(SDValue Vec, SelectionDAG &DAG,
-                               const RISCVSubtarget &Subtarget) {
-  SDLoc DL(Vec);
-  MVT XLenVT = Subtarget.getXLenVT();
-  MVT VecVT = Vec.getSimpleValueType();
-  MVT VecEltVT = VecVT.getVectorElementType();
-
-  MVT ContainerVT = VecVT;
-  if (VecVT.isFixedLengthVector()) {
-    ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
-    Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
-  }
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
-                     DAG.getConstant(0, DL, XLenVT));
-}
-
 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
                                             SelectionDAG &DAG) const {
   SDLoc DL(Op);
@@ -7783,7 +7772,9 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
   case ISD::UMIN:
   case ISD::SMAX:
   case ISD::SMIN:
-    StartV = getFirstElement(Vec, DAG, Subtarget);
+    MVT XLenVT = Subtarget.getXLenVT();
+    StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
+                         DAG.getConstant(0, DL, XLenVT));
   }
   return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
                            Mask, VL, DL, DAG, Subtarget);
@@ -7811,11 +7802,16 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
     return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1),
                            Op.getOperand(0));
   case ISD::VECREDUCE_FMIN:
-    return std::make_tuple(RISCVISD::VECREDUCE_FMIN_VL, Op.getOperand(0),
-                           getFirstElement(Op.getOperand(0), DAG, Subtarget));
-  case ISD::VECREDUCE_FMAX:
-    return std::make_tuple(RISCVISD::VECREDUCE_FMAX_VL, Op.getOperand(0),
-                           getFirstElement(Op.getOperand(0), DAG, Subtarget));
+  case ISD::VECREDUCE_FMAX: {
+    MVT XLenVT = Subtarget.getXLenVT();
+    SDValue Front =
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
+                    DAG.getConstant(0, DL, XLenVT));
+    unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN)
+                          ? RISCVISD::VECREDUCE_FMIN_VL
+                          : RISCVISD::VECREDUCE_FMAX_VL;
+    return std::make_tuple(RVVOpc, Op.getOperand(0), Front);
+  }
   }
 }
 

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index e9231b208c7ea2..3453545c9adab0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -1875,3 +1875,32 @@ define signext i8 @vpreduce_mul_v64i8(i8 signext %s, <64 x i8> %v, <64 x i1> %m,
   %r = call i8 @llvm.vp.reduce.mul.v64i8(i8 %s, <64 x i8> %v, <64 x i1> %m, i32 %evl)
   ret i8 %r
 }
+
+; Test start value is the first element of a vector.
+define zeroext i8 @front_ele_v4i8(<4 x i8> %v, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: front_ele_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vredand.vs v8, v8, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    andi a0, a0, 255
+; CHECK-NEXT:    ret
+  %s = extractelement <4 x i8> %v, i64 0
+  %r = call i8 @llvm.vp.reduce.and.v4i8(i8 %s, <4 x i8> %v, <4 x i1> %m, i32 %evl)
+  ret i8 %r
+}
+
+; Test start value is the first element of a vector that is longer than M1.
+declare i8 @llvm.vp.reduce.and.v32i8(i8, <32 x i8>, <32 x i1>, i32)
+define zeroext i8 @front_ele_v32i8(<32 x i8> %v, <32 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: front_ele_v32i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, m2, ta, ma
+; CHECK-NEXT:    vredand.vs v8, v8, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    andi a0, a0, 255
+; CHECK-NEXT:    ret
+  %s = extractelement <32 x i8> %v, i64 0
+  %r = call i8 @llvm.vp.reduce.and.v32i8(i8 %s, <32 x i8> %v, <32 x i1> %m, i32 %evl)
+  ret i8 %r
+}


        


More information about the llvm-commits mailing list