[llvm] [RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) (PR #82455)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 20 18:42:46 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

This is legal as long as the inner zext still widens the source by at least one bit, so that the extreme sub overflow case (0 - UINT_MAX of the source type) remains representable.  Alive2 proof: https://alive2.llvm.org/ce/z/BKeV3W
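
For illustration, this is the IR-level equivalent of the rewrite for a `<2 x i8>` source (a hand-written sketch with made-up value names; the combine itself operates on SelectionDAG nodes):

```llvm
; Before: both operands are zero-extended all the way to the result
; width, so the subtract runs at the full i32 element size.
%a32 = zext <2 x i8> %a to <2 x i32>
%b32 = zext <2 x i8> %b to <2 x i32>
%sub = sub <2 x i32> %a32, %b32

; After: subtract at i16, which keeps a spare bit beyond the i8
; sources (even 0 - 255 = -255 fits), then sign-extend the result.
%a16 = zext <2 x i8> %a to <2 x i16>
%b16 = zext <2 x i8> %b to <2 x i16>
%d16 = sub <2 x i16> %a16, %b16
%res = sext <2 x i16> %d16 to <2 x i32>
```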

For RVV, restrict this to power-of-two sizes with the operation type being at least e8, to stick to legal extends.  We could arguably handle i1 source types with some care if we wanted to.

This is likely profitable because it may allow us to perform the sub instruction at a narrower LMUL (equivalently, in fewer DLEN-sized pieces) before widening for the user.  We could arguably avoid narrowing below DLEN, but even then the transform should at worst introduce one extra extend and one extra vsetvli toggle, in the case where the source could previously have been handled via loads with an explicit EEW.

---
Full diff: https://github.com/llvm/llvm-project/pull/82455.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+23-1) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (+16-16) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 874c851cd9147a..64e06e2648dc23 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12846,6 +12846,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineSubOfBoolean(N, DAG))
     return V;
 
+  EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
@@ -12853,7 +12854,6 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
       isNullConstant(N1.getOperand(1))) {
     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
     if (CCVal == ISD::SETLT) {
-      EVT VT = N->getValueType(0);
       SDLoc DL(N);
       unsigned ShAmt = N0.getValueSizeInBits() - 1;
       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
@@ -12861,6 +12861,28 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // sub (zext, zext) -> sext (sub (zext, zext))
+  //   where the sum of the extend widths match, and the inner zexts
+  //   add at least one bit.  (For profitability on rvv, we use a
+  //   power of two for both inner and outer extend.)
+  if (VT.isVector() && N0.getOpcode() == N1.getOpcode() && N0.hasOneUse() &&
+      N1.hasOneUse() && N0.getOpcode() == ISD::ZERO_EXTEND) {
+    SDValue Src0 = N0.getOperand(0);
+    SDValue Src1 = N1.getOperand(0);
+    EVT SrcVT = Src0.getValueType();
+    if (SrcVT == Src1.getValueType() &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2 &&
+        SrcVT.getScalarSizeInBits() >= 8) {
+      LLVMContext &C = *DAG.getContext();
+      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+    }
+  }
+
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index 574c2652ccfacd..a084b5383b4030 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -385,12 +385,12 @@ define <32 x i64> @vwsubu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -899,12 +899,12 @@ define <2 x i64> @vwsubu_vx_v2i64_i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -917,12 +917,12 @@ define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -935,12 +935,12 @@ define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y

``````````

</details>


https://github.com/llvm/llvm-project/pull/82455

