[llvm] 0f20b9b - [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL (#87997)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 01:04:14 PDT 2024


Author: Luke Lau
Date: 2024-04-09T16:04:10+08:00
New Revision: 0f20b9b92f5333a90cf7cd19d7ec2e27ee3eac06

URL: https://github.com/llvm/llvm-project/commit/0f20b9b92f5333a90cf7cd19d7ec2e27ee3eac06
DIFF: https://github.com/llvm/llvm-project/commit/0f20b9b92f5333a90cf7cd19d7ec2e27ee3eac06.diff

LOG: [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL (#87997)

In NodeExtensionHelper we keep track of the VL and mask of the operand
being extended and check that they are the same as the root node's.
However, for the nodes that we support, none of them have a passthru
operand with the exception of RISCV::VMV_V_X_VL, but we check that its
passthru is undef anyway.

So it's safe to just discard the extend node's VL and mask and just use
the root's instead. (This is the same type of reasoning we use to treat
any vmset_vl as an all ones mask)

This allows us to match some more cases where we mix VP/non-VP/VL nodes,
but these don't seem to appear in practice. The main benefit from this
would be to simplify the code.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..c9727a3e5a8db3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13552,7 +13552,7 @@ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// NodeExtensionHelper for `a` and one for `b`.
 ///
 /// This class abstracts away how the extension is materialized and
-/// how its Mask, VL, number of users affect the combines.
+/// how its number of users affect the combines.
 ///
 /// In particular:
 /// - VWADD_W is conceptually == add(op0, sext(op1))
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
-  /// Records if this operand's mask needs to match the mask of the operation
-  /// that it will fold into.
-  bool CheckMask;
-  /// Value of the Mask for this operand.
-  /// It may be SDValue().
-  SDValue Mask;
-  /// Value of the vector length operand.
-  /// It may be SDValue().
-  SDValue VL;
   /// Original value that this NodeExtensionHelper represents.
   SDValue OrigOperand;
 
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     SupportsFPExt = false;
     EnforceOneUse = true;
-    CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
+    // For the nodes we handle below, we end up using their inputs directly: see
+    // getSource(). However since they either don't have a passthru or we check
+    // that their passthru is undef, we can safely ignore their mask and VL.
     switch (Opc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
-      SDLoc DL(Root);
-      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
       break;
     }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
       EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
 
       // The operand is a splat of a scalar.
 
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
             Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
         SupportsFPExt =
             Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
-        CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
         EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if this operand is compatible with the given vector length \p VL.
-  bool isVLCompatible(SDValue VL) const {
-    return this->VL != SDValue() && this->VL == VL;
-  }
-
-  /// Check if this operand is compatible with the given \p Mask.
-  bool isMaskCompatible(SDValue Mask) const {
-    return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
-  }
-
   /// Helper function to get the Mask and VL from \p Root.
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    return isMaskCompatible(Mask) && isVLCompatible(VL);
-  }
-
   /// Helper function to check if \p N is commutative with respect to the
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS,
                                  uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
               const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
-
   if (RHS.SupportsFPExt)
     return CombineResult(
         NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
 
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                        /*RHSExt=*/{ExtKind::ZExt});

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
 declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = add <vscale x 2 x i32> %x.sext, %y.sext
+  ret <vscale x 2 x i32> %add
+}


        


More information about the llvm-commits mailing list