[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 01:07:42 PDT 2024


================
@@ -13781,6 +13784,54 @@ struct NodeExtensionHelper {
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
+  void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+    unsigned Opc = OrigOperand.getOpcode();
+    MVT VT = OrigOperand.getSimpleValueType();
+
+    assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
+           "Unexpected Opcode");
+
+    // The passthru must be undef for this to be tail agnostic.
+    if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
+      return;
+
+    // Get the scalar value.
+    SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
+                                          : OrigOperand.getOperand(1);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening opcode by splatting to smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return;
+
+    unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+    // If the narrow type cannot be expressed with a legal VMV,
+    // this is not a valid candidate.
+    if (NarrowSize < 8)
+      return;
+
+    if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+      SupportsSExt = true;
+
+    if (DAG.MaskedValueIsZero(Op,
+                              APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+      SupportsZExt = true;
+
+    EnforceOneUse = false;
+    CheckMask = Opc == ISD::SPLAT_VECTOR;
+
+    if (Opc == ISD::SPLAT_VECTOR)
+      std::tie(Mask, VL) =
+          getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+    else
+      VL = OrigOperand.getOperand(2);
----------------
lukel97 wrote:

This needs to be rebased after #87997; sorry for the conflict. The good news is that we don't need to worry about the mask/vl anymore.

https://github.com/llvm/llvm-project/pull/87249


More information about the llvm-commits mailing list