[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 21:08:34 PDT 2024


lukel97 wrote:

Nice, this would be useful to have. I did something similar last week, but just by copying and pasting the code from the `RISCVISD::VMV_V_X_VL` case:

```diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f693cbd3bea5..24ecd5d57b6c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13596,7 +13596,8 @@ struct NodeExtensionHelper {
 
   /// Check if this instance represents a splat.
   bool isSplat() const {
-    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+           OrigOperand.getOpcode() == ISD::SPLAT_VECTOR;
   }
 
   /// Get the extended opcode.
@@ -13643,6 +13644,9 @@ struct NodeExtensionHelper {
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+    case ISD::SPLAT_VECTOR:
+      // Operand is implicitly truncated
+      return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
     default:
       // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
       // and that operand should already have the right NarrowVT so no
@@ -13817,6 +13821,37 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case ISD::SPLAT_VECTOR: {
+      MVT VT = OrigOperand.getSimpleValueType();
+      if (!VT.isVector())
+        break;
+      EnforceOneUse = false;
+      std::tie(Mask, VL) =
+          getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+
+      // Get the scalar value.
+      SDValue Op = OrigOperand.getOperand(0);
+
+      unsigned EltBits = VT.getScalarSizeInBits();
+      unsigned ScalarBits = Op.getValueSizeInBits();
+      // Make sure we're getting all element bits from the scalar register.
+      // FIXME: Support implicit sign extension of vmv.v.x?
+      if (ScalarBits < EltBits)
+        break;
+
+      unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+      // If the narrow type cannot be expressed with a legal VMV,
+      // this is not a valid candidate.
+      if (NarrowSize < 8)
+        break;
+
+      if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+        SupportsSExt = true;
+      if (DAG.MaskedValueIsZero(Op,
+                                APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+        SupportsZExt = true;
+      break;
+    }
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.

```

The trick was that I needed to make sure to add `ISD::SPLAT_VECTOR` to `isSplat`; otherwise it ended up unnecessarily widening splats to the `.wx` form (e.g. `vadd.vi` became `vwadd.wx`):

```diff
 define <vscale x 1 x i16> @vadd_vx_nxv1i16_0(<vscale x 1 x i16> %va) {
 ; CHECK-LABEL: vadd_vx_nxv1i16_0:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e16, mf4, ta, ma
-; CHECK-NEXT:    vadd.vi v8, v8, -1
+; CHECK-NEXT:    li a0, -1
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vwadd.wx v8, v8, a0
 ; CHECK-NEXT:    ret

```

https://github.com/llvm/llvm-project/pull/87249


More information about the llvm-commits mailing list