[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 1 21:08:34 PDT 2024
lukel97 wrote:
Nice, this would be useful to have. I had done something similar last week, but just by copying and pasting the code for RISCVISD::VMV_V_X_VL:
```diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f693cbd3bea5..24ecd5d57b6c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13596,7 +13596,8 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
- return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+ return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+ OrigOperand.getOpcode() == ISD::SPLAT_VECTOR;
}
/// Get the extended opcode.
@@ -13643,6 +13644,9 @@ struct NodeExtensionHelper {
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+ case ISD::SPLAT_VECTOR:
+ // Operand is implicitly truncated
+ return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
default:
// Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
// and that operand should already have the right NarrowVT so no
@@ -13817,6 +13821,37 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case ISD::SPLAT_VECTOR: {
+ MVT VT = OrigOperand.getSimpleValueType();
+ if (!VT.isVector())
+ break;
+ EnforceOneUse = false;
+ std::tie(Mask, VL) =
+ getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+
+ // Get the scalar value.
+ SDValue Op = OrigOperand.getOperand(0);
+
+ unsigned EltBits = VT.getScalarSizeInBits();
+ unsigned ScalarBits = Op.getValueSizeInBits();
+ // Make sure we're getting all element bits from the scalar register.
+ // FIXME: Support implicit sign extension of vmv.v.x?
+ if (ScalarBits < EltBits)
+ break;
+
+ unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+ // If the narrow type cannot be expressed with a legal VMV,
+ // this is not a valid candidate.
+ if (NarrowSize < 8)
+ break;
+
+ if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+ SupportsSExt = true;
+ if (DAG.MaskedValueIsZero(Op,
+ APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+ SupportsZExt = true;
+ break;
+ }
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
```
The trick was that I needed to make sure to add ISD::SPLAT_VECTOR to `isSplat`; otherwise it ended up unnecessarily turning operations on splats into widening `.w` forms (here `vwadd.wx`):
```diff
define <vscale x 1 x i16> @vadd_vx_nxv1i16_0(<vscale x 1 x i16> %va) {
; CHECK-LABEL: vadd_vx_nxv1i16_0:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e16, mf4, ta, ma
-; CHECK-NEXT: vadd.vi v8, v8, -1
+; CHECK-NEXT: li a0, -1
+; CHECK-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vwadd.wx v8, v8, a0
; CHECK-NEXT: ret
```
https://github.com/llvm/llvm-project/pull/87249
More information about the llvm-commits
mailing list