[llvm] [RISCV] Teach combineBinOpOfZExt to narrow based on known bits (PR #86680)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 08:28:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

This extends the existing narrowing transform for binop (zext, zext) to use known zero bits from the source of the zext when the zext is not greater than 2x in size. This is essentially a generic narrowing for vector binops (currently add/sub) whose operands are known to be positive in the half-bitwidth, with a restriction to the case where we eliminate a source zext.

This patch is currently slightly WIP.  I want to add a few more tests, and will rebase.  I went ahead and posted it now as it seems to expose the same basic widening op matching issue as https://github.com/llvm/llvm-project/pull/86465.

---
Full diff: https://github.com/llvm/llvm-project/pull/86680.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+27-9) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (+21-24) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll (+4-3) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (+3-1) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e6814c5f71a09b..507f5a600f51ab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12944,32 +12944,50 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
 static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
 
   EVT VT = N->getValueType(0);
-  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      VT.getScalarSizeInBits() <= 8)
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
     return SDValue();
+  // TODO: Can relax these checks when we're not needing to insert a new extend
+  // on one side or the other..
   if (!N0.hasOneUse() || !N1.hasOneUse())
     return SDValue();
 
   SDValue Src0 = N0.getOperand(0);
   SDValue Src1 = N1.getOperand(0);
-  EVT SrcVT = Src0.getValueType();
-  if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
-      SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
-      SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+  EVT Src0VT = Src0.getValueType();
+  EVT Src1VT = Src0.getValueType();
+
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(Src0VT) ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
     return SDValue();
 
+  unsigned HalfBitWidth =  VT.getScalarSizeInBits() / 2;
+  if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src0);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+  if (Src1VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src0);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+
   LLVMContext &C = *DAG.getContext();
-  EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+  EVT ElemVT = EVT::getIntegerVT(C, HalfBitWidth);
   EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
 
-  Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
-  Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+  if (Src0VT != NarrowVT)
+    Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+  if (Src1VT != NarrowVT)
+    Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
 
-  // Src0 and Src1 are zero extended, so they're always positive if signed.
+  // Src0 and Src1 are always positive if signed.
   //
   // sub can produce a negative from two positive operands, so it needs sign
   // extended. Other nodes produce a positive from two positive operands, so
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index a4ab67f41595d4..19ade65db59f43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -98,41 +98,38 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vminu.vv v10, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vsub.vv v8, v8, v10
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v10, (a0)
-; CHECK-NEXT:    vle8.v v11, (a1)
-; CHECK-NEXT:    vminu.vv v12, v8, v9
-; CHECK-NEXT:    vmaxu.vv v8, v8, v9
-; CHECK-NEXT:    vsub.vv v8, v8, v12
-; CHECK-NEXT:    vminu.vv v9, v10, v11
+; CHECK-NEXT:    vle8.v v9, (a0)
+; CHECK-NEXT:    vle8.v v10, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vle8.v v12, (a1)
+; CHECK-NEXT:    vminu.vv v13, v9, v10
+; CHECK-NEXT:    vmaxu.vv v9, v9, v10
+; CHECK-NEXT:    vsub.vv v9, v9, v13
+; CHECK-NEXT:    vminu.vv v10, v11, v12
+; CHECK-NEXT:    vmaxu.vv v11, v11, v12
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
 ; CHECK-NEXT:    vle8.v v12, (a0)
 ; CHECK-NEXT:    vle8.v v13, (a1)
-; CHECK-NEXT:    vmaxu.vv v10, v10, v11
-; CHECK-NEXT:    vsub.vv v9, v10, v9
-; CHECK-NEXT:    vwaddu.vv v10, v9, v8
+; CHECK-NEXT:    vsub.vv v10, v11, v10
+; CHECK-NEXT:    vwaddu.vv v14, v9, v8
+; CHECK-NEXT:    vwaddu.wv v14, v14, v10
 ; CHECK-NEXT:    vminu.vv v8, v12, v13
 ; CHECK-NEXT:    vmaxu.vv v9, v12, v13
 ; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v9, (a0)
-; CHECK-NEXT:    vle8.v v12, (a1)
-; CHECK-NEXT:    vzext.vf2 v14, v8
-; CHECK-NEXT:    vwaddu.vv v16, v14, v10
-; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
-; CHECK-NEXT:    vminu.vv v8, v9, v12
-; CHECK-NEXT:    vmaxu.vv v9, v9, v12
-; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.wv v16, v16, v10
+; CHECK-NEXT:    vwaddu.wv v14, v14, v8
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v16, v8
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwredsumu.vs v8, v14, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index bc0bf5dd76ad45..ccf76f97ac8b6b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -403,11 +403,12 @@ define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
 define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v4i32_v4i8_v4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.vv v8, v10, v9
+; CHECK-NEXT:    vwaddu.wv v9, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index a084b5383b4030..7c53577309b576 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -407,7 +407,9 @@ define <4 x i32> @vwsubu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwsubu.vv v8, v10, v9
+; CHECK-NEXT:    vsub.vv v9, v10, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y

``````````

</details>


https://github.com/llvm/llvm-project/pull/86680


More information about the llvm-commits mailing list