[llvm] [RISCV] Generalize (sub zext, zext) -> (sext (sub zext, zext)) to add (PR #86248)
Wang Pengcheng via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 22 08:41:08 PDT 2024
================
@@ -12899,6 +12899,56 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
}
+// add (zext, zext) -> zext (add (zext, zext))
+// sub (zext, zext) -> sext (sub (zext, zext))
+//
+// where the sum of the extend widths matches, and the range of the bin op
+// fits inside the width of the narrower bin op. (For profitability on rvv, we
+// use a power of two for both inner and outer extend.)
+//
+// TODO: Extend this to other binary ops
+static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
+ EVT VT = N->getValueType(0);
+ if (!VT.isVector() || !Subtarget.getTargetLowering()->isTypeLegal(VT))
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+ if (!N0.hasOneUse() || !N1.hasOneUse())
+ return SDValue();
+
+ SDValue Src0 = N0.getOperand(0);
+ SDValue Src1 = N1.getOperand(0);
+ EVT SrcVT = Src0.getValueType();
+ if (!Subtarget.getTargetLowering()->isTypeLegal(SrcVT) ||
+ SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
+ SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+ return SDValue();
+
+ LLVMContext &C = *DAG.getContext();
+ EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+ EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+
+ Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+ Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+
+ // Src0 and Src1 are zero extended, so they're always +ve if signed.
+ //
+ // sub can produce a -ve from two +ve operands, so it needs sign
+ // extended. Other nodes produce a +ve from two +ve operands, so zero extend
+ // instead.
+ unsigned OuterExtend =
+ N->getOpcode() == ISD::SUB ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
----------------
wangpc-pp wrote:
If we are going to extend this combination to more opcodes, we can make `OuterExtend` a parameter.
https://github.com/llvm/llvm-project/pull/86248
More information about the llvm-commits
mailing list