[llvm] [RISCV] Generalize (sub zext, zext) -> (sext (sub zext, zext)) to add (PR #86248)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 24 21:54:37 PDT 2024


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/86248

>From cd4707b3d7b0326c488629f6c5e9bef028572d4b Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 7 Mar 2024 10:14:06 +0800
Subject: [PATCH 1/3] [RISCV] Generalize (sub zext, zext) -> (sext (sub zext,
 zext)) to add

This generalizes the combine added in #82455 to other binary ops, beginning
with adds in this patch.

Because the two zext operands are always non-negative when treated as signed, and we don't get any overflow since the add is carried out in at least N * 2 bits of the narrow type, the result of the add will always be non-negative. So we can use a zext for the outer extend, unlike sub which may produce a negative result from two non-negative operands.

Although we could still use sext for add, I plan to add support for other binary ops like mul in a later patch, but mul requires zext to be correct (because the maximum value will take up the full N * 2 bits). So I've opted to use zext here too for consistency.

Alive2 proof: https://alive2.llvm.org/ce/z/PRNsUM
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  76 ++++++++---
 .../CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll |  32 ++---
 llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll   | 128 ++++++++----------
 3 files changed, 126 insertions(+), 110 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a2fb0239e0af2..ded4490754b3b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12899,6 +12899,56 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
 }
 
+// add (zext, zext) -> zext (add (zext, zext))
+// sub (zext, zext) -> sext (sub (zext, zext))
+//
+// where the sum of the extend widths matches, and the range of the bin op
+// fits inside the width of the narrower bin op. (For profitability on rvv, we
+// use a power of two for both inner and outer extend.)
+//
+// TODO: Extend this to other binary ops
+static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget) {
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isVector() || !Subtarget.getTargetLowering()->isTypeLegal(VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
+    return SDValue();
+  if (!N0.hasOneUse() || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue Src0 = N0.getOperand(0);
+  SDValue Src1 = N1.getOperand(0);
+  EVT SrcVT = Src0.getValueType();
+  if (!Subtarget.getTargetLowering()->isTypeLegal(SrcVT) ||
+      SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
+      SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+    return SDValue();
+
+  LLVMContext &C = *DAG.getContext();
+  EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+  EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+
+  Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+  Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+
+  // Src0 and Src1 are zero extended, so they're always +ve if signed.
+  //
+  // sub can produce a -ve from two +ve operands, so it needs sign
+  // extended. Other nodes produce a +ve from two +ve operands, so zero extend
+  // instead.
+  unsigned OuterExtend =
+      N->getOpcode() == ISD::SUB ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+
+  return DAG.getNode(
+      OuterExtend, SDLoc(N), VT,
+      DAG.getNode(N->getOpcode(), SDLoc(N), NarrowVT, Src0, Src1));
+}
+
 // Try to turn (add (xor bool, 1) -1) into (neg bool).
 static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
   SDValue N0 = N->getOperand(0);
@@ -12936,6 +12986,8 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
     return V;
   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
     return V;
+  if (SDValue V = combineBinOpOfZExt(N, DAG, Subtarget))
+    return V;
 
   // fold (add (select lhs, rhs, cc, 0, y), x) ->
   //      (select lhs, rhs, cc, x, (add x, y))
@@ -13003,28 +13055,8 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // sub (zext, zext) -> sext (sub (zext, zext))
-  //   where the sum of the extend widths match, and the inner zexts
-  //   add at least one bit.  (For profitability on rvv, we use a
-  //   power of two for both inner and outer extend.)
-  if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
-      N0.getOpcode() == N1.getOpcode() && N0.getOpcode() == ISD::ZERO_EXTEND &&
-      N0.hasOneUse() && N1.hasOneUse()) {
-    SDValue Src0 = N0.getOperand(0);
-    SDValue Src1 = N1.getOperand(0);
-    EVT SrcVT = Src0.getValueType();
-    if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
-        SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
-        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
-      LLVMContext &C = *DAG.getContext();
-      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
-      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
-      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
-      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
-      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
-                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
-    }
-  }
+  if (SDValue V = combineBinOpOfZExt(N, DAG, Subtarget))
+    return V;
 
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index 57a72c639b334c..bc0bf5dd76ad45 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -385,12 +385,12 @@ define <32 x i64> @vwaddu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -912,12 +912,12 @@ define <4 x i64> @crash(<4 x i16> %x, <4 x i16> %y) {
 define <2 x i32> @vwaddu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -930,12 +930,12 @@ define <2 x i32> @vwaddu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwaddu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -948,12 +948,12 @@ define <2 x i64> @vwaddu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwaddu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66a7eea18be504..0a7051633a19ae 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -435,10 +435,10 @@ define <vscale x 1 x i64> @vwadd_vv_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, <vsc
 define <vscale x 1 x i64> @vwaddu_vv_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i16> %va to <vscale x 1 x i64>
   %vd = zext <vscale x 1 x i16> %vb to <vscale x 1 x i64>
@@ -468,11 +468,9 @@ define <vscale x 1 x i64> @vwaddu_vx_nxv1i64_nxv1i16(<vscale x 1 x i16> %va, i16
 ; CHECK-LABEL: vwaddu_vx_nxv1i64_nxv1i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vx v9, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 1 x i16> %head, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -555,10 +553,10 @@ define <vscale x 2 x i64> @vwadd_vv_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, <vsc
 define <vscale x 2 x i64> @vwaddu_vv_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 2 x i16> %va to <vscale x 2 x i64>
   %vd = zext <vscale x 2 x i16> %vb to <vscale x 2 x i64>
@@ -588,11 +586,9 @@ define <vscale x 2 x i64> @vwaddu_vx_nxv2i64_nxv2i16(<vscale x 2 x i16> %va, i16
 ; CHECK-LABEL: vwaddu_vx_nxv2i64_nxv2i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vx v10, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 2 x i16> %head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -675,10 +671,10 @@ define <vscale x 4 x i64> @vwadd_vv_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, <vsc
 define <vscale x 4 x i64> @vwaddu_vv_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv4i64_nxv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vwaddu.vv v12, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v12
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 4 x i16> %va to <vscale x 4 x i64>
   %vd = zext <vscale x 4 x i16> %vb to <vscale x 4 x i64>
@@ -708,11 +704,9 @@ define <vscale x 4 x i64> @vwaddu_vx_nxv4i64_nxv4i16(<vscale x 4 x i16> %va, i16
 ; CHECK-LABEL: vwaddu_vx_nxv4i64_nxv4i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vwaddu.vx v12, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v12
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 4 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 4 x i16> %head, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
@@ -795,10 +789,10 @@ define <vscale x 8 x i64> @vwadd_vv_nxv8i64_nxv8i16(<vscale x 8 x i16> %va, <vsc
 define <vscale x 8 x i64> @vwaddu_vv_nxv8i64_nxv8i16(<vscale x 8 x i16> %va, <vscale x 8 x i16> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv8i64_nxv8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf2 v16, v8
-; CHECK-NEXT:    vzext.vf2 v20, v10
-; CHECK-NEXT:    vwaddu.vv v8, v16, v20
+; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwaddu.vv v16, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v16
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 8 x i16> %va to <vscale x 8 x i64>
   %vd = zext <vscale x 8 x i16> %vb to <vscale x 8 x i64>
@@ -828,11 +822,9 @@ define <vscale x 8 x i64> @vwaddu_vx_nxv8i64_nxv8i16(<vscale x 8 x i16> %va, i16
 ; CHECK-LABEL: vwaddu_vx_nxv8i64_nxv8i16:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vmv.v.x v10, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf2 v16, v8
-; CHECK-NEXT:    vzext.vf2 v20, v10
-; CHECK-NEXT:    vwaddu.vv v8, v16, v20
+; CHECK-NEXT:    vwaddu.vx v16, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v16
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 8 x i16> poison, i16 %b, i16 0
   %splat = shufflevector <vscale x 8 x i16> %head, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer
@@ -915,10 +907,10 @@ define <vscale x 1 x i64> @vwadd_vv_nxv1i64_nxv1i8(<vscale x 1 x i8> %va, <vscal
 define <vscale x 1 x i64> @vwaddu_vv_nxv1i64_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv1i64_nxv1i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vsetvli a0, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 1 x i8> %va to <vscale x 1 x i64>
   %vd = zext <vscale x 1 x i8> %vb to <vscale x 1 x i64>
@@ -948,11 +940,9 @@ define <vscale x 1 x i64> @vwaddu_vx_nxv1i64_nxv1i8(<vscale x 1 x i8> %va, i8 %b
 ; CHECK-LABEL: vwaddu_vx_nxv1i64_nxv1i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vx v9, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v9
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 1 x i8> poison, i8 %b, i8 0
   %splat = shufflevector <vscale x 1 x i8> %head, <vscale x 1 x i8> poison, <vscale x 1 x i32> zeroinitializer
@@ -1035,10 +1025,10 @@ define <vscale x 2 x i64> @vwadd_vv_nxv2i64_nxv2i8(<vscale x 2 x i8> %va, <vscal
 define <vscale x 2 x i64> @vwaddu_vv_nxv2i64_nxv2i8(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv2i64_nxv2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vsetvli a0, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 2 x i8> %va to <vscale x 2 x i64>
   %vd = zext <vscale x 2 x i8> %vb to <vscale x 2 x i64>
@@ -1068,11 +1058,9 @@ define <vscale x 2 x i64> @vwaddu_vx_nxv2i64_nxv2i8(<vscale x 2 x i8> %va, i8 %b
 ; CHECK-LABEL: vwaddu_vx_nxv2i64_nxv2i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf4, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwaddu.vv v8, v10, v11
+; CHECK-NEXT:    vwaddu.vx v10, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 2 x i8> poison, i8 %b, i8 0
   %splat = shufflevector <vscale x 2 x i8> %head, <vscale x 2 x i8> poison, <vscale x 2 x i32> zeroinitializer
@@ -1155,10 +1143,10 @@ define <vscale x 4 x i64> @vwadd_vv_nxv4i64_nxv4i8(<vscale x 4 x i8> %va, <vscal
 define <vscale x 4 x i64> @vwaddu_vv_nxv4i64_nxv4i8(<vscale x 4 x i8> %va, <vscale x 4 x i8> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv4i64_nxv4i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v8
-; CHECK-NEXT:    vzext.vf4 v14, v9
-; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vsetvli a0, zero, e8, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.vv v12, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v12
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 4 x i8> %va to <vscale x 4 x i64>
   %vd = zext <vscale x 4 x i8> %vb to <vscale x 4 x i64>
@@ -1188,11 +1176,9 @@ define <vscale x 4 x i64> @vwaddu_vx_nxv4i64_nxv4i8(<vscale x 4 x i8> %va, i8 %b
 ; CHECK-LABEL: vwaddu_vx_nxv4i64_nxv4i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v8
-; CHECK-NEXT:    vzext.vf4 v14, v9
-; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vwaddu.vx v12, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v12
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 4 x i8> poison, i8 %b, i8 0
   %splat = shufflevector <vscale x 4 x i8> %head, <vscale x 4 x i8> poison, <vscale x 4 x i32> zeroinitializer
@@ -1275,10 +1261,10 @@ define <vscale x 8 x i64> @vwadd_vv_nxv8i64_nxv8i8(<vscale x 8 x i8> %va, <vscal
 define <vscale x 8 x i64> @vwaddu_vv_nxv8i64_nxv8i8(<vscale x 8 x i8> %va, <vscale x 8 x i8> %vb) {
 ; CHECK-LABEL: vwaddu_vv_nxv8i64_nxv8i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf4 v16, v8
-; CHECK-NEXT:    vzext.vf4 v20, v9
-; CHECK-NEXT:    vwaddu.vv v8, v16, v20
+; CHECK-NEXT:    vsetvli a0, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vwaddu.vv v16, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v16
 ; CHECK-NEXT:    ret
   %vc = zext <vscale x 8 x i8> %va to <vscale x 8 x i64>
   %vd = zext <vscale x 8 x i8> %vb to <vscale x 8 x i64>
@@ -1308,11 +1294,9 @@ define <vscale x 8 x i64> @vwaddu_vx_nxv8i64_nxv8i8(<vscale x 8 x i8> %va, i8 %b
 ; CHECK-LABEL: vwaddu_vx_nxv8i64_nxv8i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a1, zero, e8, m1, ta, ma
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf4 v16, v8
-; CHECK-NEXT:    vzext.vf4 v20, v9
-; CHECK-NEXT:    vwaddu.vv v8, v16, v20
+; CHECK-NEXT:    vwaddu.vx v16, v8, a0
+; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-NEXT:    vzext.vf4 v8, v16
 ; CHECK-NEXT:    ret
   %head = insertelement <vscale x 8 x i8> poison, i8 %b, i8 0
   %splat = shufflevector <vscale x 8 x i8> %head, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer

>From a985e26680dcbc7c54e2a2c6067447a4090cbeeb Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 25 Mar 2024 12:47:58 +0800
Subject: [PATCH 2/3] Expand +ve/-ve

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ded4490754b3b2..f5695ece7f2826 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12936,11 +12936,11 @@ static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
   Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
   Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
 
-  // Src0 and Src1 are zero extended, so they're always +ve if signed.
+  // Src0 and Src1 are zero extended, so they're always positive if signed.
   //
-  // sub can produce a -ve from two +ve operands, so it needs sign
-  // extended. Other nodes produce a +ve from two +ve operands, so zero extend
-  // instead.
+  // sub can produce a negative from two positive operands, so it needs sign
+  // extended. Other nodes produce a positive from two positive operands, so
+  // zero extend instead.
   unsigned OuterExtend =
       N->getOpcode() == ISD::SUB ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
 

>From ba3bbd9d52f6f606e2c2d71cf6beb5df61517980 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 25 Mar 2024 12:49:44 +0800
Subject: [PATCH 3/3] Use DAG.getTargetLoweringInfo()

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f5695ece7f2826..fad5e2bc057b78 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12907,11 +12907,10 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
 // use a power of two for both inner and outer extend.)
 //
 // TODO: Extend this to other binary ops
-static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
-                                  const RISCVSubtarget &Subtarget) {
+static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
 
   EVT VT = N->getValueType(0);
-  if (!VT.isVector() || !Subtarget.getTargetLowering()->isTypeLegal(VT))
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
@@ -12924,7 +12923,7 @@ static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG,
   SDValue Src0 = N0.getOperand(0);
   SDValue Src1 = N1.getOperand(0);
   EVT SrcVT = Src0.getValueType();
-  if (!Subtarget.getTargetLowering()->isTypeLegal(SrcVT) ||
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
       SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
       SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
     return SDValue();
@@ -12986,7 +12985,7 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
     return V;
   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
     return V;
-  if (SDValue V = combineBinOpOfZExt(N, DAG, Subtarget))
+  if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
 
   // fold (add (select lhs, rhs, cc, 0, y), x) ->
@@ -13055,7 +13054,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  if (SDValue V = combineBinOpOfZExt(N, DAG, Subtarget))
+  if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
 
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->



More information about the llvm-commits mailing list