[llvm] [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL (PR #87997)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 07:18:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

- Add tests where mismatching VL/masks prevent vwadd from being combined
- [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL


---
Full diff: https://github.com/llvm/llvm-project/pull/87997.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+3-49) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll (+58) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..e13b3f3ca109fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
-  /// Records if this operand's mask needs to match the mask of the operation
-  /// that it will fold into.
-  bool CheckMask;
-  /// Value of the Mask for this operand.
-  /// It may be SDValue().
-  SDValue Mask;
-  /// Value of the vector length operand.
-  /// It may be SDValue().
-  SDValue VL;
   /// Original value that this NodeExtensionHelper represents.
   SDValue OrigOperand;
 
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     SupportsFPExt = false;
     EnforceOneUse = true;
-    CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
+    // For the nodes we handle below, we end up using their inputs directly: see
+    // getSource(). However since they either don't have a passthru or we check
+    // that their passthru is undef, we can safely ignore their mask and VL.
     switch (Opc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
-      SDLoc DL(Root);
-      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
       break;
     }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
       EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
 
       // The operand is a splat of a scalar.
 
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
             Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
         SupportsFPExt =
             Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
-        CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
         EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if this operand is compatible with the given vector length \p VL.
-  bool isVLCompatible(SDValue VL) const {
-    return this->VL != SDValue() && this->VL == VL;
-  }
-
-  /// Check if this operand is compatible with the given \p Mask.
-  bool isMaskCompatible(SDValue Mask) const {
-    return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
-  }
-
   /// Helper function to get the Mask and VL from \p Root.
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    return isMaskCompatible(Mask) && isVLCompatible(VL);
-  }
-
   /// Helper function to check if \p N is commutative with respect to the
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS,
                                  uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
               const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
-
   if (RHS.SupportsFPExt)
     return CombineResult(
         NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
 
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                        /*RHSExt=*/{ExtKind::ZExt});
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
 declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = add <vscale x 2 x i32> %x.sext, %y.sext
+  ret <vscale x 2 x i32> %add
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87997


More information about the llvm-commits mailing list