[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon Apr 1 08:56:20 PDT 2024


https://github.com/sun-jacobi created https://github.com/llvm/llvm-project/pull/87249

This patch teaches `combineBinOp_VLToVWBinOp_VL` to handle patterns like `(splat_vector (sext op))` and `(splat_vector (zext op))`, so that `vwadd.vx` and `vwadd.wx` can be selected for such cases. A sketch of the recognition logic follows the examples below.

### Source code
```
define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
  %sb = sext i32 %b to i64
  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
  %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
  %ve = add <vscale x 8 x i64> %vc, %splat
  ret <vscale x 8 x i64> %ve
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/sq191PsT4)
```
vwadd_vx_splat_sext:
  sext.w a0, a0
  vsetvli a1, zero, e64, m8, ta, ma
  vmv.v.x v16, a0
  vsetvli zero, zero, e32, m4, ta, ma
  vwadd.wv v16, v16, v8
  vmv8r.v v8, v16
  ret
```
### After this patch
```
vwadd_vx_splat_sext:
  vsetvli a1, zero, e32, m4, ta, ma
  vwadd.vx v16, v8, a0
  vmv8r.v v8, v16
  ret
```
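
How the splat is classified: in the new `ISD::SPLAT_VECTOR` case, the splatted scalar counts as sign-extended when it is an `ISD::SIGN_EXTEND_INREG`, and as zero-extended when it is an `and` whose mask keeps exactly the low half of the scalar's bits (`APInt::getBitsSet(ScalarSize, 0, NarrowSize)`). Below is a minimal, self-contained sketch of the zero-extension check in plain C++; the function name and the use of `uint64_t` in place of `APInt` are illustrative only, not the patch's actual code.

```
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the patch's APInt-based check: a splat of
// (and x, Mask) behaves like a zext from the half-width type exactly when
// Mask selects the low ScalarSize/2 bits and that narrow type is legal
// (i.e. at least 8 bits wide).
bool splatAndMaskIsZExt(uint64_t Mask, unsigned ScalarSize) {
  unsigned NarrowSize = ScalarSize / 2;
  if (NarrowSize < 8)
    return false; // the narrow element type would not be legal
  // Mask of the low NarrowSize bits, e.g. 0xffffffff for ScalarSize == 64.
  uint64_t LowBits = (uint64_t(1) << NarrowSize) - 1;
  return Mask == LowBits;
}

int main() {
  assert(splatAndMaskIsZExt(0xffffffffu, 64)); // zext i32 -> i64
  assert(splatAndMaskIsZExt(0xffffu, 32));     // zext i16 -> i32
  assert(!splatAndMaskIsZExt(0xffffu, 64));    // not a half-width mask
  assert(!splatAndMaskIsZExt(0x7u, 8));        // i4 would be illegal
}
```

The sign-extended case is simpler: the patch just checks whether the splatted scalar is an `ISD::SIGN_EXTEND_INREG` node.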

From 802663b0e123f5ee563fcad60ee4146730d0243c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Apr 2024 00:47:02 +0900
Subject: [PATCH] [RISCV] Use vwadd.vx for extended splat.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  33 +++++
 llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll | 153 ++++++++++++++++++++
 2 files changed, 186 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f693cbd3bea51e..f422ee53874e32 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13589,6 +13589,8 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
+    case ISD::SPLAT_VECTOR:
+      return OrigOperand.getOperand(0).getOperand(0);
     default:
       return OrigOperand;
     }
@@ -13640,6 +13642,8 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    case ISD::SPLAT_VECTOR:
+      return DAG.getSplat(NarrowVT, DL, Source);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -13817,6 +13821,35 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case ISD::SPLAT_VECTOR: {
+      SDValue ScalarOp = OrigOperand.getOperand(0);
+      unsigned ScalarOpc = ScalarOp.getOpcode();
+
+      MVT ScalarVT = ScalarOp.getSimpleValueType();
+      unsigned ScalarSize = ScalarVT.getScalarSizeInBits();
+      unsigned NarrowSize = ScalarSize / 2;
+
+      // Ensure the narrow element type is legal.
+      if (NarrowSize < 8)
+        break;
+
+      SupportsSExt = ScalarOpc == ISD::SIGN_EXTEND_INREG;
+
+      if (ScalarOpc == ISD::AND) {
+        if (ConstantSDNode *MaskNode =
+                dyn_cast<ConstantSDNode>(ScalarOp.getOperand(1)))
+          SupportsZExt = MaskNode->getAPIntValue() ==
+                         APInt::getBitsSet(ScalarSize, 0, NarrowSize);
+      }
+
+      EnforceOneUse = false;
+      CheckMask = false;
+
+      MVT VT = OrigOperand.getSimpleValueType();
+      SDLoc DL(Root);
+      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      break;
+    }
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..eeb29285594477 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1466,3 +1466,156 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
   ret <vscale x 2 x i32> %or
 }
+
+define <vscale x 8 x i64> @vwadd_vx_splat_zext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw zero, 12(sp)
+; RV32-NEXT:    sw a0, 8(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV32-NEXT:    vlse64.v v16, (a0), zero
+; RV32-NEXT:    vwaddu.wv v16, v16, v8
+; RV32-NEXT:    vmv8r.v v8, v16
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwaddu.vx v16, v8, a0
+; RV64-NEXT:    vmv8r.v v8, v16
+; RV64-NEXT:    ret
+  %sb = zext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = zext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+  %ve = add <vscale x 8 x i64> %vc, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_zext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext_i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT:    vmv.v.x v8, a0
+; RV32-NEXT:    vadd.vi v8, v8, 1, v0.t
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext_i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srli a0, a0, 48
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    vmerge.vim v8, v8, 1, v0
+; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = zext i16 %b to i32
+  %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = zext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+  %ve = add <vscale x 8 x i32> %vc, %splat
+  ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_zext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_zext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw zero, 12(sp)
+; RV32-NEXT:    sw a0, 8(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vlse64.v v16, (a0), zero
+; RV32-NEXT:    vadd.vv v8, v8, v16
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_wx_splat_zext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwaddu.wx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = zext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %ve = add <vscale x 8 x i64> %va, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vmv.v.x v16, a0
+; RV32-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; RV32-NEXT:    vwadd.wv v16, v16, v8
+; RV32-NEXT:    vmv8r.v v8, v16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwadd.vx v16, v8, a0
+; RV64-NEXT:    vmv8r.v v8, v16
+; RV64-NEXT:    ret
+  %sb = sext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+  %ve = add <vscale x 8 x i64> %vc, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_sext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext_i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srai a0, a0, 16
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT:    vmv.v.x v8, a0
+; RV32-NEXT:    li a0, 1
+; RV32-NEXT:    vsub.vx v8, v8, a0, v0.t
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext_i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srai a0, a0, 48
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    vmerge.vim v8, v8, 1, v0
+; RV64-NEXT:    vrsub.vx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = sext i16 %b to i32
+  %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = sext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+  %ve = add <vscale x 8 x i32> %vc, %splat
+  ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_sext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_sext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vadd.vx v8, v8, a0
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_wx_splat_sext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwadd.wx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = sext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %ve = add <vscale x 8 x i64> %va, %splat
+  ret <vscale x 8 x i64> %ve
+}


