[llvm] [RISCV] Add DAG combine for forming VAADDU_VL from VP intrinsics. (PR #124848)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 15:33:59 PST 2025


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/124848

>From 8185a93a6b8b46be983a4b404ca0127550cd709d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 16 Aug 2024 13:41:14 -0700
Subject: [PATCH 1/2] [RISCV] Add DAG combine for forming VAADDU_VL from VP
 intrinsics.

This adds a VP version of an existing DAG combine. I've put it in the RISCV
backend since we would otherwise need to add an ISD::VP_AVGCEIL opcode.

This pattern appears in 525.x264_r.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 119 ++++++++++++--
 llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll    | 169 ++++++++++++++++++++
 2 files changed, 276 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8d09e534b1858b..f004a00b19d205 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1526,18 +1526,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN,     ISD::MGATHER,
-                         ISD::MSCATTER,      ISD::VP_GATHER,
-                         ISD::VP_SCATTER,    ISD::SRA,
-                         ISD::SRL,           ISD::SHL,
-                         ISD::STORE,         ISD::SPLAT_VECTOR,
-                         ISD::BUILD_VECTOR,  ISD::CONCAT_VECTORS,
-                         ISD::VP_STORE,      ISD::EXPERIMENTAL_VP_REVERSE,
-                         ISD::MUL,           ISD::SDIV,
-                         ISD::UDIV,          ISD::SREM,
-                         ISD::UREM,          ISD::INSERT_VECTOR_ELT,
-                         ISD::ABS,           ISD::CTPOP,
-                         ISD::VECTOR_SHUFFLE, ISD::VSELECT});
+    setTargetDAGCombine(
+        {ISD::FCOPYSIGN,    ISD::MGATHER,      ISD::MSCATTER,
+         ISD::VP_GATHER,    ISD::VP_SCATTER,   ISD::SRA,
+         ISD::SRL,          ISD::SHL,          ISD::STORE,
+         ISD::SPLAT_VECTOR, ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
+         ISD::VP_STORE,     ISD::VP_TRUNCATE,  ISD::EXPERIMENTAL_VP_REVERSE,
+         ISD::MUL,          ISD::SDIV,         ISD::UDIV,
+         ISD::SREM,         ISD::UREM,         ISD::INSERT_VECTOR_ELT,
+         ISD::ABS,          ISD::CTPOP,        ISD::VECTOR_SHUFFLE,
+         ISD::VSELECT});
 
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -16373,6 +16371,101 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
       VPStore->isTruncatingStore(), VPStore->isCompressingStore());
 }
 
+// Peephole avgceil pattern.
+//   %1 = zext <N x i8> %a to <N x i32>
+//   %2 = zext <N x i8> %b to <N x i32>
+//   %3 = add nuw nsw <N x i32> %1, splat (i32 1)
+//   %4 = add nuw nsw <N x i32> %3, %2
+//   %5 = lshr <N x i32> %N, <i32 1 x N>
+//   %6 = trunc <N x i32> %5 to <N x i8>
+static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Ignore fixed vectors.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!VT.isScalableVector() || !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue In = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  // Input should be a vp_srl with same mask and VL.
+  if (In.getOpcode() != ISD::VP_SRL || In.getOperand(2) != Mask ||
+      In.getOperand(3) != VL)
+    return SDValue();
+
+  // Shift amount should be 1.
+  if (!isOneOrOneSplat(In.getOperand(1)))
+    return SDValue();
+
+  // Shifted value should be a vp_add with same mask and VL.
+  SDValue LHS = In.getOperand(0);
+  if (LHS.getOpcode() != ISD::VP_ADD || LHS.getOperand(2) != Mask ||
+      LHS.getOperand(3) != VL)
+    return SDValue();
+
+  SDValue Operands[3];
+  Operands[0] = LHS.getOperand(0);
+  Operands[1] = LHS.getOperand(1);
+
+  // Matches another VP_ADD with same VL and Mask.
+  auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
+    if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
+        V.getOperand(3) != VL)
+      return false;
+
+    Op0 = V.getOperand(0);
+    Op1 = V.getOperand(1);
+    return true;
+  };
+
+  // We need to find another VP_ADD in one of the operands.
+  SDValue Op0, Op1;
+  if (FindAdd(Operands[0], Op0, Op1))
+    Operands[0] = Operands[1];
+  else if (!FindAdd(Operands[1], Op0, Op1))
+    return SDValue();
+  Operands[2] = Op0;
+  Operands[1] = Op1;
+
+  // Now we have three operands of two additions. Check that one of them is a
+  // constant vector with ones.
+  auto I = llvm::find_if(Operands,
+                         [](const SDValue &Op) { return isOneOrOneSplat(Op); });
+  if (I == std::end(Operands))
+    return SDValue();
+  // We found a vector with ones, move if it to the end of the Operands array.
+  std::swap(Operands[I - std::begin(Operands)], Operands[2]);
+
+  // Make sure the other 2 operands can be promoted from the result type.
+  for (int i = 0; i < 2; ++i) {
+    if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
+        Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
+      return SDValue();
+    // Input must be smaller than our result.
+    if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
+        VT.getScalarSizeInBits())
+      return SDValue();
+  }
+
+  // Pattern is detected.
+  Op0 = Operands[0].getOperand(0);
+  Op1 = Operands[1].getOperand(0);
+  // Rebuild the zero extends if the inputs are smaller than our result.
+  if (Op0.getValueType() != VT)
+    Op0 =
+        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL);
+  if (Op1.getValueType() != VT)
+    Op1 =
+        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, Op1, Mask, VL);
+  // Build a VAADDU with RNU rounding mode.
+  SDLoc DL(N);
+  return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
+                     {Op0, Op1, DAG.getUNDEF(VT), Mask, VL});
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -17930,6 +18023,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
     return combineTruncToVnclip(N, DAG, Subtarget);
+  case ISD::VP_TRUNCATE:
+    return performVP_TRUNCATECombine(N, DAG, Subtarget);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll
new file mode 100644
index 00000000000000..42f8e38d5ac163
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-vaaddu.ll
@@ -0,0 +1,169 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+declare <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.trunc.nxv2i16.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i16>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
+
+define <vscale x 2 x i8> @vaaddu_1(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_2(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_3(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_4(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %xz, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %yz, <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_5(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i16> %xz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %a, <vscale x 2 x i16> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+define <vscale x 2 x i8> @vaaddu_6(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v9, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i16> @llvm.vp.zext.nxv2i16.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i16> %xz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i16> @llvm.vp.add.nxv2i16(<vscale x 2 x i16> %yz, <vscale x 2 x i16> %a, <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i16> @llvm.vp.lshr.nxv2i16(<vscale x 2 x i16> %b, <vscale x 2 x i16> shufflevector (<vscale x 2 x i16> insertelement (<vscale x 2 x i16> poison, i16 1, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i16(<vscale x 2 x i16> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+; Test where the size is reduced by 4x instead of 2x.
+define <vscale x 2 x i8> @vaaddu_7(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vaaddu.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}
+
+; Test where the zext can't be completely removed.
+define <vscale x 2 x i16> @vaaddu_8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vzext.vf2 v10, v8, v0.t
+; CHECK-NEXT:    csrwi vxrm, 0
+; CHECK-NEXT:    vzext.vf2 v8, v9, v0.t
+; CHECK-NEXT:    vaaddu.vv v8, v10, v8, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i16> @llvm.vp.trunc.nxv2i16.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i16> %d
+}
+
+; Negative test. The truncate has a smaller type than the zero extend.
+; TODO: Could still handle this by truncating after an i16 vaaddu.
+define <vscale x 2 x i8> @vaaddu_9(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: vaaddu_9:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vadd.vi v8, v10, 1, v0.t
+; CHECK-NEXT:    vsrl.vi v8, v8, 1, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
+; CHECK-NEXT:    ret
+  %xz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %vl)
+  %yz = call <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %vl)
+  %a = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %xz, <vscale x 2 x i32> %yz, <vscale x 2 x i1> %m, i32 %vl)
+  %b = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %a, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %c = call <vscale x 2 x i32> @llvm.vp.lshr.nxv2i32(<vscale x 2 x i32> %b, <vscale x 2 x i32> shufflevector (<vscale x 2 x i32> insertelement (<vscale x 2 x i32> poison, i32 1, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %m, i32 %vl)
+  %d = call <vscale x 2 x i8> @llvm.vp.trunc.nxv2i8.nxv2i32(<vscale x 2 x i32> %c, <vscale x 2 x i1> %m, i32 %vl)
+  ret <vscale x 2 x i8> %d
+}

>From ffa8ae006ff1a7d86110e6fa86edf4dece3b1f41 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 29 Jan 2025 14:56:40 -0800
Subject: [PATCH 2/2] fixup! address review comments

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 47 +++++++++------------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f004a00b19d205..5b2e8ceede53e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16376,7 +16376,7 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
 //   %2 = zext <N x i8> %b to <N x i32>
 //   %3 = add nuw nsw <N x i32> %1, splat (i32 1)
 //   %4 = add nuw nsw <N x i32> %3, %2
-//   %5 = lshr <N x i32> %N, <i32 1 x N>
+//   %5 = lshr <N x i32> %4, splat (i32 1)
 //   %6 = trunc <N x i32> %5 to <N x i8>
 static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                          const RISCVSubtarget &Subtarget) {
@@ -16407,28 +16407,24 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDValue Operands[3];
-  Operands[0] = LHS.getOperand(0);
-  Operands[1] = LHS.getOperand(1);
 
   // Matches another VP_ADD with same VL and Mask.
-  auto FindAdd = [&](SDValue V, SDValue &Op0, SDValue &Op1) {
+  auto FindAdd = [&](SDValue V, SDValue Other) {
     if (V.getOpcode() != ISD::VP_ADD || V.getOperand(2) != Mask ||
         V.getOperand(3) != VL)
       return false;
 
-    Op0 = V.getOperand(0);
-    Op1 = V.getOperand(1);
+    Operands[0] = Other;
+    Operands[1] = V.getOperand(1);
+    Operands[2] = V.getOperand(0);
     return true;
   };
 
   // We need to find another VP_ADD in one of the operands.
-  SDValue Op0, Op1;
-  if (FindAdd(Operands[0], Op0, Op1))
-    Operands[0] = Operands[1];
-  else if (!FindAdd(Operands[1], Op0, Op1))
+  SDValue LHS0 = LHS.getOperand(0);
+  SDValue LHS1 = LHS.getOperand(1);
+  if (!FindAdd(LHS0, LHS1) && !FindAdd(LHS1, LHS0))
     return SDValue();
-  Operands[2] = Op0;
-  Operands[1] = Op1;
 
   // Now we have three operands of two additions. Check that one of them is a
   // constant vector with ones.
@@ -16437,33 +16433,28 @@ static SDValue performVP_TRUNCATECombine(SDNode *N, SelectionDAG &DAG,
   if (I == std::end(Operands))
     return SDValue();
   // We found a vector with ones, move if it to the end of the Operands array.
-  std::swap(Operands[I - std::begin(Operands)], Operands[2]);
+  std::swap(*I, Operands[2]);
 
   // Make sure the other 2 operands can be promoted from the result type.
-  for (int i = 0; i < 2; ++i) {
-    if (Operands[i].getOpcode() != ISD::VP_ZERO_EXTEND ||
-        Operands[i].getOperand(1) != Mask || Operands[i].getOperand(2) != VL)
+  for (SDValue Op : drop_end(Operands)) {
+    if (Op.getOpcode() != ISD::VP_ZERO_EXTEND || Op.getOperand(1) != Mask ||
+        Op.getOperand(2) != VL)
       return SDValue();
     // Input must be smaller than our result.
-    if (Operands[i].getOperand(0).getScalarValueSizeInBits() >
-        VT.getScalarSizeInBits())
+    if (Op.getOperand(0).getScalarValueSizeInBits() > VT.getScalarSizeInBits())
       return SDValue();
   }
 
   // Pattern is detected.
-  Op0 = Operands[0].getOperand(0);
-  Op1 = Operands[1].getOperand(0);
-  // Rebuild the zero extends if the inputs are smaller than our result.
-  if (Op0.getValueType() != VT)
-    Op0 =
-        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT, Op0, Mask, VL);
-  if (Op1.getValueType() != VT)
-    Op1 =
-        DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT, Op1, Mask, VL);
+  // Rebuild the zero extends in case the inputs are smaller than our result.
+  SDValue NewOp0 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[0]), VT,
+                               Operands[0].getOperand(0), Mask, VL);
+  SDValue NewOp1 = DAG.getNode(ISD::VP_ZERO_EXTEND, SDLoc(Operands[1]), VT,
+                               Operands[1].getOperand(0), Mask, VL);
   // Build a VAADDU with RNU rounding mode.
   SDLoc DL(N);
   return DAG.getNode(RISCVISD::AVGCEILU_VL, DL, VT,
-                     {Op0, Op1, DAG.getUNDEF(VT), Mask, VL});
+                     {NewOp0, NewOp1, DAG.getUNDEF(VT), Mask, VL});
 }
 
 // Convert from one FMA opcode to another based on whether we are negating the



More information about the llvm-commits mailing list