[llvm] [RISCV] Add DAG combine to convert (iN reduce.add (zext (vXi1 A to vXiN)) into vcpop.m (PR #127497)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 18 12:15:36 PST 2025
================
@@ -18100,25 +18100,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
// (iX ctpop (bitcast (vXi1 A)))
// ->
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
+// and
+// (iN reduce.add (zext (vXi1 A to vXiN))
+// ->
+// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
// FIXME: It's complicated to match all the variations of this after type
// legalization so we only handle the pre-type legalization pattern, but that
// requires the fixed vector type to be legal.
-static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
+static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ unsigned Opc = N->getOpcode();
+ assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) &&
+ "Unexpected opcode");
EVT VT = N->getValueType(0);
if (!VT.isScalarInteger())
return SDValue();
SDValue Src = N->getOperand(0);
- // Peek through zero_extend. It doesn't change the count.
- if (Src.getOpcode() == ISD::ZERO_EXTEND)
- Src = Src.getOperand(0);
+ if (Opc == ISD::CTPOP) {
+ // Peek through zero_extend. It doesn't change the count.
+ if (Src.getOpcode() == ISD::ZERO_EXTEND)
+ Src = Src.getOperand(0);
- if (Src.getOpcode() != ISD::BITCAST)
- return SDValue();
+ if (Src.getOpcode() != ISD::BITCAST)
+ return SDValue();
+ Src = Src.getOperand(0);
+ } else if (Opc == ISD::VECREDUCE_ADD) {
+ if (Src.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
----------------
preames wrote:
There's a subtle, nasty bug here.
Consider the case where the pattern is: (iN reduce.add (zext (vXi1 A) to vXi4))
If runtime VLENB is such that the number of mask bits is 16 or more, this is *not* equal to the vcpop - due to the wrapping behavior on the add reduce (an i4 element can only hold counts up to 15). You need to prove that the intermediate type is sufficiently wide to hold the element count of the mask source without overflow.
https://github.com/llvm/llvm-project/pull/127497
More information about the llvm-commits
mailing list