[llvm] [RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) (PR #82455)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 10:33:33 PST 2024


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/82455

>From f17b2647e29e42d356199d82f300d8ae536a96b6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 20 Feb 2024 18:15:41 -0800
Subject: [PATCH 1/2] [RISCV] Vector sub (zext, zext) -> sext (sub (zext,
 zext))

This is legal as long as the inner zext widens by at least one bit,
so that the most negative difference (0 - UINT_MAX of the source
type) remains representable. Alive2 proof: https://alive2.llvm.org/ce/z/BKeV3W

For RVV, restrict this to power-of-two sizes, with the source element
type being at least e8, so that only legal extends are formed.  We
could arguably handle i1 source types with some care if we wanted to.

This is likely profitable because it may allow us to perform the sub
instruction at a narrower LMUL (equivalently, in fewer DLEN-sized
pieces) before widening for the user.  We could arguably avoid
narrowing below DLEN, but the transform should at worst introduce
one extra extend and one extra vsetvli toggle if the source
could previously have been handled via loads with an explicit EEW.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 24 +++++++++++++-
 .../CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll | 32 +++++++++----------
 2 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 874c851cd9147a..64e06e2648dc23 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12846,6 +12846,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineSubOfBoolean(N, DAG))
     return V;
 
+  EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
@@ -12853,7 +12854,6 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
       isNullConstant(N1.getOperand(1))) {
     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
     if (CCVal == ISD::SETLT) {
-      EVT VT = N->getValueType(0);
       SDLoc DL(N);
       unsigned ShAmt = N0.getValueSizeInBits() - 1;
       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
@@ -12861,6 +12861,28 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // sub (zext, zext) -> sext (sub (zext, zext))
+  //   where the sum of the extend widths match, and the inner zexts
+  //   add at least one bit.  (For profitability on rvv, we use a
+  //   power of two for both inner and outer extend.)
+  if (VT.isVector() && N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() &&
+      N1.hasOneUse() && N0.getOpcode() == ISD::ZERO_EXTEND) {
+    SDValue Src0 = N0.getOperand(0);
+    SDValue Src1 = N1.getOperand(0);
+    EVT SrcVT = Src0.getValueType();
+    if (SrcVT == Src1.getValueType() &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2 &&
+        SrcVT.getScalarSizeInBits() >= 8) {
+      LLVMContext &C = *DAG.getContext();
+      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+    }
+  }
+
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index 574c2652ccfacd..a084b5383b4030 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -385,12 +385,12 @@ define <32 x i64> @vwsubu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -899,12 +899,12 @@ define <2 x i64> @vwsubu_vx_v2i64_i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -917,12 +917,12 @@ define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -935,12 +935,12 @@ define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y

>From 1fd007533adb52839d56a96f7fbcf3de8d257ebf Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 22 Feb 2024 10:29:26 -0800
Subject: [PATCH 2/2] Add isTypeLegal checks

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 64e06e2648dc23..1aa57c98ad5d96 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12865,14 +12865,15 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   //   where the sum of the extend widths match, and the inner zexts
   //   add at least one bit.  (For profitability on rvv, we use a
   //   power of two for both inner and outer extend.)
-  if (VT.isVector() && N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() &&
-      N1.hasOneUse() && N0.getOpcode() == ISD::ZERO_EXTEND) {
+  if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
+      N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() && N1.hasOneUse() &&
+      N0.getOpcode() == ISD::ZERO_EXTEND) {
     SDValue Src0 = N0.getOperand(0);
     SDValue Src1 = N1.getOperand(0);
     EVT SrcVT = Src0.getValueType();
-    if (SrcVT == Src1.getValueType() &&
-        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2 &&
-        SrcVT.getScalarSizeInBits() >= 8) {
+    if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
+        SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
       LLVMContext &C = *DAG.getContext();
       EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
       EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());



More information about the llvm-commits mailing list