[llvm] [RISCV] Teach combineBinOpOfZExt to narrow based on known bits (PR #86680)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 08:28:32 PDT 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/86680

This extends the existing narrowing transform for binop (zext, zext) to use known zero bits from the source of the zext when the zext is not a greater-than-2x extension (i.e. the source is at least half the width of the result).  This is essentially a generic narrowing for vector binops (currently add/sub) with operands known to be positive in the half bitwidth, restricted to the case where we eliminate a source zext.

This patch is still slightly WIP: I want to add a few more tests and will rebase.  I went ahead and posted it now as it seems to expose the same basic widening op matching issue as https://github.com/llvm/llvm-project/pull/86465.
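
To make the pattern concrete, here is the shape of input involved, based on the vwaddu_v4i32_v4i8_v4i16 test updated below.  Only the loads appear verbatim in the diff, so the zext and add lines are a reconstruction of what that test presumably contains:

define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
  %a = load <4 x i8>, ptr %x
  %b = load <4 x i16>, ptr %y
  ; The two zext sources have different types, and the i16 source is exactly
  ; half the width of the i32 result, so the existing transform (which
  ; requires matching source types strictly narrower than half the result
  ; width) does not fire on this add.
  %c = zext <4 x i8> %a to <4 x i32>
  %d = zext <4 x i16> %b to <4 x i32>
  %e = add <4 x i32> %c, %d
  ret <4 x i32> %e
}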

From 232dbc9e2becc7bcaff67753168704dccbc82378 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 26 Mar 2024 08:05:44 -0700
Subject: [PATCH] [RISCV] Teach combineBinOpOfZExt to narrow based on known
 bits

This extends the existing narrowing transform for binop (zext, zext)
to use known zero bits from the source of the zext when the zext is
not a greater-than-2x extension (i.e. the source is at least half the
width of the result).  This is essentially a generic narrowing for
vector binops (currently add/sub) with operands known to be positive
in the half bitwidth, restricted to the case where we eliminate a
source zext.

This patch is still slightly WIP: I want to add a few more tests and
will rebase.  I went ahead and posted it now as it seems to expose
the same basic widening op matching issue as https://github.com/llvm/llvm-project/pull/86465.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 36 +++++++++++----
 .../CodeGen/RISCV/rvv/fixed-vectors-sad.ll    | 45 +++++++++----------
 .../CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll |  7 +--
 .../CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll |  4 +-
 4 files changed, 55 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e6814c5f71a09b..507f5a600f51ab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12944,32 +12944,50 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
 static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
 
   EVT VT = N->getValueType(0);
-  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      VT.getScalarSizeInBits() <= 8)
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
     return SDValue();
+  // TODO: Can relax these checks when we don't need to insert a new extend
+  // on one side or the other.
   if (!N0.hasOneUse() || !N1.hasOneUse())
     return SDValue();
 
   SDValue Src0 = N0.getOperand(0);
   SDValue Src1 = N1.getOperand(0);
-  EVT SrcVT = Src0.getValueType();
-  if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
-      SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
-      SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
+  EVT Src0VT = Src0.getValueType();
+  EVT Src1VT = Src1.getValueType();
+
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(Src0VT) ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
     return SDValue();
 
+  unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
+  if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src0);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+  if (Src1VT.getScalarSizeInBits() >= HalfBitWidth) {
+    KnownBits Known = DAG.computeKnownBits(Src1);
+    if (Known.countMinLeadingZeros() <= HalfBitWidth)
+      return SDValue();
+  }
+
   LLVMContext &C = *DAG.getContext();
-  EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
+  EVT ElemVT = EVT::getIntegerVT(C, HalfBitWidth);
   EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
 
-  Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
-  Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+  if (Src0VT != NarrowVT)
+    Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+  if (Src1VT != NarrowVT)
+    Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
 
-  // Src0 and Src1 are zero extended, so they're always positive if signed.
+  // Src0 and Src1 are always positive if signed.
   //
   // sub can produce a negative from two positive operands, so it needs sign
   // extended. Other nodes produce a positive from two positive operands, so
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index a4ab67f41595d4..19ade65db59f43 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -98,41 +98,38 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vminu.vv v10, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vsub.vv v8, v8, v10
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v10, (a0)
-; CHECK-NEXT:    vle8.v v11, (a1)
-; CHECK-NEXT:    vminu.vv v12, v8, v9
-; CHECK-NEXT:    vmaxu.vv v8, v8, v9
-; CHECK-NEXT:    vsub.vv v8, v8, v12
-; CHECK-NEXT:    vminu.vv v9, v10, v11
+; CHECK-NEXT:    vle8.v v9, (a0)
+; CHECK-NEXT:    vle8.v v10, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vle8.v v12, (a1)
+; CHECK-NEXT:    vminu.vv v13, v9, v10
+; CHECK-NEXT:    vmaxu.vv v9, v9, v10
+; CHECK-NEXT:    vsub.vv v9, v9, v13
+; CHECK-NEXT:    vminu.vv v10, v11, v12
+; CHECK-NEXT:    vmaxu.vv v11, v11, v12
 ; CHECK-NEXT:    add a0, a0, a2
 ; CHECK-NEXT:    add a1, a1, a3
 ; CHECK-NEXT:    vle8.v v12, (a0)
 ; CHECK-NEXT:    vle8.v v13, (a1)
-; CHECK-NEXT:    vmaxu.vv v10, v10, v11
-; CHECK-NEXT:    vsub.vv v9, v10, v9
-; CHECK-NEXT:    vwaddu.vv v10, v9, v8
+; CHECK-NEXT:    vsub.vv v10, v11, v10
+; CHECK-NEXT:    vwaddu.vv v14, v9, v8
+; CHECK-NEXT:    vwaddu.wv v14, v14, v10
 ; CHECK-NEXT:    vminu.vv v8, v12, v13
 ; CHECK-NEXT:    vmaxu.vv v9, v12, v13
 ; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    add a1, a1, a3
-; CHECK-NEXT:    vle8.v v9, (a0)
-; CHECK-NEXT:    vle8.v v12, (a1)
-; CHECK-NEXT:    vzext.vf2 v14, v8
-; CHECK-NEXT:    vwaddu.vv v16, v14, v10
-; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
-; CHECK-NEXT:    vminu.vv v8, v9, v12
-; CHECK-NEXT:    vmaxu.vv v9, v9, v12
-; CHECK-NEXT:    vsub.vv v8, v9, v8
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.wv v16, v16, v10
+; CHECK-NEXT:    vwaddu.wv v14, v14, v8
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v16, v8
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwredsumu.vs v8, v14, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
 ; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index bc0bf5dd76ad45..ccf76f97ac8b6b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -403,11 +403,12 @@ define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
 define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwaddu_v4i32_v4i8_v4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwaddu.vv v8, v10, v9
+; CHECK-NEXT:    vwaddu.wv v9, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
index a084b5383b4030..7c53577309b576 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -407,7 +407,9 @@ define <4 x i32> @vwsubu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vwsubu.vv v8, v10, v9
+; CHECK-NEXT:    vsub.vv v9, v10, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i8>, ptr %x
   %b = load <4 x i16>, ptr %y


