[llvm] [RISCV] Generalize (sub zext, zext) -> (sext (sub zext, zext)) to add (PR #86248)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 24 20:57:20 PDT 2024


================
@@ -12899,6 +12899,56 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
 }
 
+// add (zext, zext) -> zext (add (zext, zext))
+// sub (zext, zext) -> sext (sub (zext, zext))
+//
+// where the sum of the extend widths match, and the range of the bin op
+// fits inside the width of the narrower bin op. (For profitability on rvv, we
+// use a power of two for both inner and outer extend.)
+//
+// TODO: Extend this to other binary ops
+static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget) {
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isVector() || !Subtarget.getTargetLowering()->isTypeLegal(VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
+    return SDValue();
+  if (!N0.hasOneUse() || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue Src0 = N0.getOperand(0);
+  SDValue Src1 = N1.getOperand(0);
+  EVT SrcVT = Src0.getValueType();
+  if (!Subtarget.getTargetLowering()->isTypeLegal(SrcVT) ||
+      SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
+      SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+    return SDValue();
+
+  LLVMContext &C = *DAG.getContext();
+  EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+  EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+
+  Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+  Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+
+  // Src0 and Src1 are zero extended, so they're always +ve if signed.
+  //
+  // sub can produce a -ve from two +ve operands, so it needs sign
+  // extended. Other nodes produce a +ve from two +ve operands, so zero extend
+  // instead.
+  unsigned OuterExtend =
+      N->getOpcode() == ISD::SUB ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
----------------
lukel97 wrote:

Out of all the binary ops we can do this transform with, it looks like sub is the only one that actually requires sign_extend (https://alive2.llvm.org/ce/z/PEPb9j), which is why I was just handling it as a specific case.

https://github.com/llvm/llvm-project/pull/86248


More information about the llvm-commits mailing list