[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Apr 6 03:13:03 PDT 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/87249
>From 49745cee2f3151f68013a42d96992d51a338ddef Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sat, 6 Apr 2024 19:12:36 +0900
Subject: [PATCH] [RISCV] Use vwadd.vx for extended splat.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 98 ++++++++-----
llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll | 155 ++++++++++++++++++++
2 files changed, 214 insertions(+), 39 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 279d8a435a04ca..124dcbc1067fe4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13597,6 +13597,7 @@ struct NodeExtensionHelper {
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
+ // ISD::SPLAT_VECTOR is deliberately not listed: a splat is its own source.
return OrigOperand.getOperand(0);
default:
return OrigOperand;
@@ -13605,7 +13606,8 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
- return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+ unsigned Opc = OrigOperand.getOpcode();
+ return Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL;
}
/// Get the extended opcode.
@@ -13649,6 +13651,8 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+ case ISD::SPLAT_VECTOR:
+ return DAG.getSplat(NarrowVT, DL, OrigOperand.getOperand(0));
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -13781,6 +13785,57 @@ struct NodeExtensionHelper {
/// Check if this node needs to be fully folded or extended for all users.
bool needToPromoteOtherUsers() const { return EnforceOneUse; }
+ /// Populate the extension-support fields for an operand that is a splat
+ /// (ISD::SPLAT_VECTOR or RISCVISD::VMV_V_X_VL): it can feed a widening op
+ /// when its scalar fits in half the element width (sign- or zero-extended).
+ void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ unsigned Opc = OrigOperand.getOpcode();
+ MVT VT = OrigOperand.getSimpleValueType();
+ assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
+ "Unexpected Opcode");
+
+ // getDefaultScalableVLOps (used below) requires a scalable vector type.
+ if (Opc == ISD::SPLAT_VECTOR && !VT.isScalableVector())
+ return;
+ // The passthru must be undef for tail agnostic.
+ if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
+ return;
+ // Get the scalar value.
+ SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
+ : OrigOperand.getOperand(1);
+
+ // See if we have enough sign bits or zero bits in the scalar to use a
+ // widening opcode by splatting to smaller element size.
+ unsigned EltBits = VT.getScalarSizeInBits();
+ unsigned ScalarBits = Op.getValueSizeInBits();
+ // Make sure we're getting all element bits from the scalar register.
+ // FIXME: Support implicit sign extension of vmv.v.x?
+ if (ScalarBits < EltBits)
+ return;
+
+ unsigned NarrowSize = EltBits / 2;
+ // If the narrow type cannot be expressed with a legal VMV,
+ // this is not a valid candidate.
+ if (NarrowSize < 8)
+ return;
+
+ if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+ SupportsSExt = true;
+ if (DAG.MaskedValueIsZero(Op,
+ APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+ SupportsZExt = true;
+
+ // Splat values are not required to disappear during the combine.
+ EnforceOneUse = false;
+ CheckMask = Opc == ISD::SPLAT_VECTOR;
+ if (Opc == ISD::SPLAT_VECTOR)
+ std::tie(Mask, VL) =
+ getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+ else
+ VL = OrigOperand.getOperand(2);
+ }
+
/// Helper method to set the various fields of this struct based on the
/// type of \p Root.
void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -13826,45 +13881,10 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
- case RISCVISD::VMV_V_X_VL: {
- // Historically, we didn't care about splat values not disappearing during
- // combines.
- EnforceOneUse = false;
- CheckMask = false;
- VL = OrigOperand.getOperand(2);
-
- // The operand is a splat of a scalar.
-
- // The pasthru must be undef for tail agnostic.
- if (!OrigOperand.getOperand(0).isUndef())
- break;
-
- // Get the scalar value.
- SDValue Op = OrigOperand.getOperand(1);
-
- // See if we have enough sign bits or zero bits in the scalar to use a
- // widening opcode by splatting to smaller element size.
- MVT VT = Root->getSimpleValueType(0);
- unsigned EltBits = VT.getScalarSizeInBits();
- unsigned ScalarBits = Op.getValueSizeInBits();
- // Make sure we're getting all element bits from the scalar register.
- // FIXME: Support implicit sign extension of vmv.v.x?
- if (ScalarBits < EltBits)
- break;
-
- unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- // If the narrow type cannot be expressed with a legal VMV,
- // this is not a valid candidate.
- if (NarrowSize < 8)
- break;
-
- if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
- SupportsSExt = true;
- if (DAG.MaskedValueIsZero(Op,
- APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
- SupportsZExt = true;
+ case RISCVISD::VMV_V_X_VL:
+ case ISD::SPLAT_VECTOR:
+ fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
break;
- }
default:
break;
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 66e6883dd1d3e3..985424e3557b98 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1466,3 +1466,158 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
%or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
ret <vscale x 2 x i32> %or
}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_zext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw zero, 12(sp)
+; RV32-NEXT: sw a0, 8(sp)
+; RV32-NEXT: addi a0, sp, 8
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
+; RV32-NEXT: vwaddu.wv v16, v16, v8
+; RV32-NEXT: vmv8r.v v8, v16
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwaddu.vx v16, v8, a0
+; RV64-NEXT: vmv8r.v v8, v16
+; RV64-NEXT: ret
+ %b.ext = zext i32 %b to i64
+ %ins = insertelement <vscale x 8 x i64> poison, i64 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i64> %ins, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %va.ext = zext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+ %sum = add <vscale x 8 x i64> %va.ext, %b.splat
+ ret <vscale x 8 x i64> %sum
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_zext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext_i1:
+; RV32: # %bb.0:
+; RV32-NEXT: slli a0, a0, 16
+; RV32-NEXT: srli a0, a0, 16
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT: vmv.v.x v8, a0
+; RV32-NEXT: vadd.vi v8, v8, 1, v0.t
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext_i1:
+; RV64: # %bb.0:
+; RV64-NEXT: slli a0, a0, 48
+; RV64-NEXT: srli a0, a0, 48
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV64-NEXT: vmv.v.x v8, a0
+; RV64-NEXT: vadd.vi v8, v8, 1, v0.t
+; RV64-NEXT: ret
+ %b.ext = zext i16 %b to i32
+ %ins = insertelement <vscale x 8 x i32> poison, i32 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i32> %ins, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+ %va.ext = zext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+ %sum = add <vscale x 8 x i32> %va.ext, %b.splat
+ ret <vscale x 8 x i32> %sum
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_zext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_zext:
+; RV32: # %bb.0:
+; RV32-NEXT: addi sp, sp, -16
+; RV32-NEXT: .cfi_def_cfa_offset 16
+; RV32-NEXT: sw zero, 12(sp)
+; RV32-NEXT: sw a0, 8(sp)
+; RV32-NEXT: addi a0, sp, 8
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
+; RV32-NEXT: vadd.vv v8, v8, v16
+; RV32-NEXT: addi sp, sp, 16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_wx_splat_zext:
+; RV64: # %bb.0:
+; RV64-NEXT: slli a0, a0, 32
+; RV64-NEXT: srli a0, a0, 32
+; RV64-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV64-NEXT: vadd.vx v8, v8, a0
+; RV64-NEXT: ret
+ %b.ext = zext i32 %b to i64
+ %ins = insertelement <vscale x 8 x i64> poison, i64 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i64> %ins, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %sum = add <vscale x 8 x i64> %va, %b.splat
+ ret <vscale x 8 x i64> %sum
+}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vmv.v.x v16, a0
+; RV32-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; RV32-NEXT: vwadd.wv v16, v16, v8
+; RV32-NEXT: vmv8r.v v8, v16
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT: vwadd.vx v16, v8, a0
+; RV64-NEXT: vmv8r.v v8, v16
+; RV64-NEXT: ret
+ %b.ext = sext i32 %b to i64
+ %ins = insertelement <vscale x 8 x i64> poison, i64 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i64> %ins, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %va.ext = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+ %sum = add <vscale x 8 x i64> %va.ext, %b.splat
+ ret <vscale x 8 x i64> %sum
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_sext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext_i1:
+; RV32: # %bb.0:
+; RV32-NEXT: slli a0, a0, 16
+; RV32-NEXT: srai a0, a0, 16
+; RV32-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT: vmv.v.x v8, a0
+; RV32-NEXT: li a0, 1
+; RV32-NEXT: vsub.vx v8, v8, a0, v0.t
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext_i1:
+; RV64: # %bb.0:
+; RV64-NEXT: slli a0, a0, 48
+; RV64-NEXT: srai a0, a0, 48
+; RV64-NEXT: vsetvli a1, zero, e32, m4, ta, mu
+; RV64-NEXT: vmv.v.x v8, a0
+; RV64-NEXT: li a0, 1
+; RV64-NEXT: vsub.vx v8, v8, a0, v0.t
+; RV64-NEXT: ret
+ %b.ext = sext i16 %b to i32
+ %ins = insertelement <vscale x 8 x i32> poison, i32 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i32> %ins, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+ %va.ext = sext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+ %sum = add <vscale x 8 x i32> %va.ext, %b.splat
+ ret <vscale x 8 x i32> %sum
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_sext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_sext:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT: vadd.vx v8, v8, a0
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vwadd_wx_splat_sext:
+; RV64: # %bb.0:
+; RV64-NEXT: sext.w a0, a0
+; RV64-NEXT: vsetvli a1, zero, e64, m8, ta, ma
+; RV64-NEXT: vadd.vx v8, v8, a0
+; RV64-NEXT: ret
+ %b.ext = sext i32 %b to i64
+ %ins = insertelement <vscale x 8 x i64> poison, i64 %b.ext, i32 0
+ %b.splat = shufflevector <vscale x 8 x i64> %ins, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+ %sum = add <vscale x 8 x i64> %va, %b.splat
+ ret <vscale x 8 x i64> %sum
+}
More information about the llvm-commits mailing list