[llvm] [RISCV] Use vwadd.vx for splat vector with extension (PR #87249)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 04:28:13 PDT 2024


================
@@ -13781,6 +13784,57 @@ struct NodeExtensionHelper {
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
+  void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+    unsigned Opc = OrigOperand.getOpcode();
+    MVT VT = OrigOperand.getSimpleValueType();
+
+    assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
+           "Unexpected Opcode");
+
+    if (Opc == ISD::SPLAT_VECTOR && !VT.isVector())
+      return;
+
+    // The passthru must be undef for tail agnostic.
+    if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
+      return;
+
+    // Get the scalar value.
+    SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
+                                          : OrigOperand.getOperand(1);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening opcode by splatting to smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return;
+
+    unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+    // If the narrow type cannot be expressed with a legal VMV,
+    // this is not a valid candidate.
+    if (NarrowSize < 8)
+      return;
+
+    if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+      SupportsSExt = true;
+
+    if (DAG.MaskedValueIsZero(Op,
+                              APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+      SupportsZExt = true;
+
+    EnforceOneUse = false;
+    CheckMask = Opc == ISD::SPLAT_VECTOR;
+
+    if (Opc == ISD::SPLAT_VECTOR)
+      std::tie(Mask, VL) =
+          getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
----------------
sun-jacobi wrote:

> the mask and VL are the same

Could you please explain what this means?

https://github.com/llvm/llvm-project/pull/87249


More information about the llvm-commits mailing list