[llvm] cd79599 - [RISCV] Teach lowerScalarInsert to handle scalar value is the first element of a fixed vector.
Yeting Kuo via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 2 19:53:20 PDT 2023
Author: Yeting Kuo
Date: 2023-08-03T10:53:14+08:00
New Revision: cd7959930443f09403df7a9303634b50163a3d80
URL: https://github.com/llvm/llvm-project/commit/cd7959930443f09403df7a9303634b50163a3d80
DIFF: https://github.com/llvm/llvm-project/commit/cd7959930443f09403df7a9303634b50163a3d80.diff
LOG: [RISCV] Teach lowerScalarInsert to handle scalar value is the first element of a fixed vector.
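D155929 taught lowerScalarInsert to handle a start value of the form
(extractelement scalable_vector, 0), and specifically converted fixed-length
source vectors to scalable vectors when lowering vector reductions. That is not
enough, because there are other ways to produce (extractelement fixed_vector, 0)
as the start value passed to lowerScalarInsert, such as the case in #64327.
This patch therefore performs the fixed-to-scalable conversion inside
lowerScalarInsert itself and removes the now-unneeded getFirstElement helper.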
#64327: https://github.com/llvm/llvm-project/issues/64327.
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D156863
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ddd06d129a08d2..220396b84d9619 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3466,11 +3466,19 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
isNullConstant(Scalar.getOperand(1))) {
- MVT ExtractedVT = Scalar.getOperand(0).getSimpleValueType();
- if (ExtractedVT.bitsLE(VT))
- return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru,
- Scalar.getOperand(0), DAG.getConstant(0, DL, XLenVT));
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Scalar.getOperand(0),
+ SDValue ExtractedVal = Scalar.getOperand(0);
+ MVT ExtractedVT = ExtractedVal.getSimpleValueType();
+ MVT ExtractedContainerVT = ExtractedVT;
+ if (ExtractedContainerVT.isFixedLengthVector()) {
+ ExtractedContainerVT = getContainerForFixedLengthVector(
+ DAG, ExtractedContainerVT, Subtarget);
+ ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal,
+ DAG, Subtarget);
+ }
+ if (ExtractedContainerVT.bitsLE(VT))
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal,
+ DAG.getConstant(0, DL, XLenVT));
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
DAG.getConstant(0, DL, XLenVT));
}
@@ -7722,25 +7730,6 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
DAG.getConstant(0, DL, XLenVT));
}
-// Function to extract the first element of Vec. For fixed vector Vec, this
-// converts it to a scalable vector before extraction, so subsequent
-// optimizations don't have to handle fixed vectors.
-static SDValue getFirstElement(SDValue Vec, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- SDLoc DL(Vec);
- MVT XLenVT = Subtarget.getXLenVT();
- MVT VecVT = Vec.getSimpleValueType();
- MVT VecEltVT = VecVT.getVectorElementType();
-
- MVT ContainerVT = VecVT;
- if (VecVT.isFixedLengthVector()) {
- ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
- Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
- }
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
- DAG.getConstant(0, DL, XLenVT));
-}
-
SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
@@ -7783,7 +7772,9 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
case ISD::UMIN:
case ISD::SMAX:
case ISD::SMIN:
- StartV = getFirstElement(Vec, DAG, Subtarget);
+ MVT XLenVT = Subtarget.getXLenVT();
+ StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
+ DAG.getConstant(0, DL, XLenVT));
}
return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
Mask, VL, DL, DAG, Subtarget);
@@ -7811,11 +7802,16 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1),
Op.getOperand(0));
case ISD::VECREDUCE_FMIN:
- return std::make_tuple(RISCVISD::VECREDUCE_FMIN_VL, Op.getOperand(0),
- getFirstElement(Op.getOperand(0), DAG, Subtarget));
- case ISD::VECREDUCE_FMAX:
- return std::make_tuple(RISCVISD::VECREDUCE_FMAX_VL, Op.getOperand(0),
- getFirstElement(Op.getOperand(0), DAG, Subtarget));
+ case ISD::VECREDUCE_FMAX: {
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDValue Front =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
+ DAG.getConstant(0, DL, XLenVT));
+ unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN)
+ ? RISCVISD::VECREDUCE_FMIN_VL
+ : RISCVISD::VECREDUCE_FMAX_VL;
+ return std::make_tuple(RVVOpc, Op.getOperand(0), Front);
+ }
}
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index e9231b208c7ea2..3453545c9adab0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -1875,3 +1875,32 @@ define signext i8 @vpreduce_mul_v64i8(i8 signext %s, <64 x i8> %v, <64 x i1> %m,
%r = call i8 @llvm.vp.reduce.mul.v64i8(i8 %s, <64 x i8> %v, <64 x i1> %m, i32 %evl)
ret i8 %r
}
+
+; Test start value is the first element of a vector.
+define zeroext i8 @front_ele_v4i8(<4 x i8> %v, <4 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: front_ele_v4i8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT: vredand.vs v8, v8, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: andi a0, a0, 255
+; CHECK-NEXT: ret
+ %s = extractelement <4 x i8> %v, i64 0
+ %r = call i8 @llvm.vp.reduce.and.v4i8(i8 %s, <4 x i8> %v, <4 x i1> %m, i32 %evl)
+ ret i8 %r
+}
+
+; Test start value is the first element of a vector which is longer than M1.
+declare i8 @llvm.vp.reduce.and.v32i8(i8, <32 x i8>, <32 x i1>, i32)
+define zeroext i8 @front_ele_v32i8(<32 x i8> %v, <32 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: front_ele_v32i8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e8, m2, ta, ma
+; CHECK-NEXT: vredand.vs v8, v8, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: andi a0, a0, 255
+; CHECK-NEXT: ret
+ %s = extractelement <32 x i8> %v, i64 0
+ %r = call i8 @llvm.vp.reduce.and.v32i8(i8 %s, <32 x i8> %v, <32 x i1> %m, i32 %evl)
+ ret i8 %r
+}
More information about the llvm-commits
mailing list