[llvm] ac518c7 - [RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) (#82455)

via llvm-commits <llvm-commits at lists.llvm.org>
Thu Feb 22 16:17:52 PST 2024


Author: Philip Reames
Date: 2024-02-22T16:17:48-08:00
New Revision: ac518c7c9916a6fde1d898b8c53b74298fd00d5f

URL: https://github.com/llvm/llvm-project/commit/ac518c7c9916a6fde1d898b8c53b74298fd00d5f
DIFF: https://github.com/llvm/llvm-project/commit/ac518c7c9916a6fde1d898b8c53b74298fd00d5f.diff

LOG: [RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) (#82455)

This is legal as long as the inner zext adds at least one extra bit of
width, so that the most negative difference (0 - UINT_MAX of the source
type) can still be represented. Alive2 proof:
https://alive2.llvm.org/ce/z/BKeV3W
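
As a quick sanity check of that claim, the following standalone C++
sketch (not part of this patch) exhaustively compares the two forms
for i8 sources: a sub over i32 zexts versus a sext of a sub over i16
zexts. The i16 inner type supplies the one spare bit needed to hold
the full difference range [-255, 255]:

  #include <cassert>
  #include <cstdint>

  int main() {
    for (uint32_t A = 0; A <= 0xFF; ++A) {
      for (uint32_t B = 0; B <= 0xFF; ++B) {
        // Original form: sub (zext i8 A to i32), (zext i8 B to i32).
        int32_t Wide = static_cast<int32_t>(A - B);
        // Transformed form:
        //   sext (sub (zext i8 A to i16), (zext i8 B to i16)) to i32.
        uint16_t Narrow =
            static_cast<uint16_t>(A) - static_cast<uint16_t>(B);
        int32_t Transformed = static_cast<int16_t>(Narrow);
        assert(Wide == Transformed);
      }
    }
    return 0;
  }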

For RVV, restrict this to power-of-two sizes, with the operation type
being at least e8, to stick to legal extends. We could arguably handle
i1 source types with some care if we wanted to.

This is likely profitable because it may allow us to perform the sub
instruction at a narrower LMUL (equivalently, in fewer DLEN-sized
pieces) before widening for the user. We could arguably avoid narrowing
below DLEN, but even then the transform should at worst introduce one
extra extend and one extra vsetvli toggle in cases where the source
could previously be handled via loads with an explicit EEW.
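
To put a rough number on the DLEN point: a vector op at element width
SEW and vector length VL moves VL*SEW bits per operand, and a core with
a DLEN-bit datapath retires that in about ceil(VL*SEW/DLEN) beats. The
toy model below (illustrative assumptions only; not a hardware model
and not part of this patch) shows the sub halving its beat count once
it runs on the original e8 operands instead of e16 zexts:

  #include <cstdio>

  // ceil(VL * SEW / DLEN): beats a DLEN-bit datapath spends per operand.
  static unsigned beats(unsigned VL, unsigned SEW, unsigned DLEN) {
    return (VL * SEW + DLEN - 1) / DLEN;
  }

  int main() {
    const unsigned VL = 16, DLEN = 64; // illustrative numbers only
    // Before: the widening sub consumed e16 operands (after vzext.vf2).
    std::printf("sub over e16 operands: %u beats\n", beats(VL, 16, DLEN));
    // After: the widening sub consumes the original e8 operands.
    std::printf("sub over e8 operands:  %u beats\n", beats(VL, 8, DLEN));
    return 0;
  }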

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6bf02cf8c0f875..5c67aaf6785669 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12887,6 +12887,7 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineSubOfBoolean(N, DAG))
     return V;
 
+  EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
@@ -12894,7 +12895,6 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
       isNullConstant(N1.getOperand(1))) {
     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
     if (CCVal == ISD::SETLT) {
-      EVT VT = N->getValueType(0);
       SDLoc DL(N);
       unsigned ShAmt = N0.getValueSizeInBits() - 1;
       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
@@ -12902,6 +12902,29 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  // sub (zext, zext) -> sext (sub (zext, zext))
+  //   where the sum of the extend widths match, and the inner zexts
+  //   add at least one bit.  (For profitability on rvv, we use a
+  //   power of two for both inner and outer extend.)
+  if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
+      N0.getOpcode() == N1.getOpcode() && N0.getOpcode() == ISD::ZERO_EXTEND &&
+      N0.hasOneUse() && N1.hasOneUse()) {
+    SDValue Src0 = N0.getOperand(0);
+    SDValue Src1 = N1.getOperand(0);
+    EVT SrcVT = Src0.getValueType();
+    if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
+        SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
+      LLVMContext &C = *DAG.getContext();
+      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+    }
+  }
+
   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index 574c2652ccfacd..a084b5383b4030 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -385,12 +385,12 @@ define <32 x i64> @vwsubu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -899,12 +899,12 @@ define <2 x i64> @vwsubu_vx_v2i64_i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -917,12 +917,12 @@ define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -935,12 +935,12 @@ define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y
