[llvm] [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL (PR #87997)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 9 00:56:26 PDT 2024


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/87997

>From 8f8e4edd4b5463df0979bd7c0d44ba46f912e8ca Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 8 Apr 2024 21:35:54 +0800
Subject: [PATCH 1/3] Add tests where mismatching VL/masks prevents vwadd from
 being combined

---
 llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll | 66 +++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..80d61080d60ef7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,69 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
 declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v9
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.wv v9, v10, v8, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8
+; CHECK-NEXT:    vsext.vf2 v8, v9
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v8, v0.t
+; CHECK-NEXT:    ret
+  %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v10, v8, v0.t
+; CHECK-NEXT:    vsext.vf2 v8, v9, v0.t
+; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = add <vscale x 2 x i32> %x.sext, %y.sext
+  ret <vscale x 2 x i32> %add
+}

>From 6c83065a46a8d2cba8fc7cc9926937dc016d8b12 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Mon, 8 Apr 2024 21:47:03 +0800
Subject: [PATCH 2/3] [RISCV] Don't require mask or VL to be the same in
 combineBinOp_VLToVWBinOp_VL

In NodeExtensionHelper we keep track of the VL and mask of the operand being extended and check that they are the same as the root node's. However, none of the nodes that we support have a passthru operand, with the exception of RISCVISD::VMV_V_X_VL, and for that node we check that its passthru is undef anyway.

So it's safe to discard the extend node's VL and mask and just use the root's instead. (This is the same type of reasoning we use to treat any vmset_vl as an all-ones mask.)

This allows us to match some more cases where we mix VP/non-VP/VL nodes, but these don't seem to appear in practice. The main benefit from this would be to simplify the code.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 52 ++-------------------
 llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll     | 24 ++++------
 2 files changed, 11 insertions(+), 65 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..e13b3f3ca109fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
-  /// Records if this operand's mask needs to match the mask of the operation
-  /// that it will fold into.
-  bool CheckMask;
-  /// Value of the Mask for this operand.
-  /// It may be SDValue().
-  SDValue Mask;
-  /// Value of the vector length operand.
-  /// It may be SDValue().
-  SDValue VL;
   /// Original value that this NodeExtensionHelper represents.
   SDValue OrigOperand;
 
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     SupportsFPExt = false;
     EnforceOneUse = true;
-    CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
+    // For the nodes we handle below, we end up using their inputs directly: see
+    // getSource(). However since they either don't have a passthru or we check
+    // that their passthru is undef, we can safely ignore their mask and VL.
     switch (Opc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
-      SDLoc DL(Root);
-      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
       break;
     }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
       EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
 
       // The operand is a splat of a scalar.
 
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
             Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
         SupportsFPExt =
             Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
-        CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
         EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if this operand is compatible with the given vector length \p VL.
-  bool isVLCompatible(SDValue VL) const {
-    return this->VL != SDValue() && this->VL == VL;
-  }
-
-  /// Check if this operand is compatible with the given \p Mask.
-  bool isMaskCompatible(SDValue Mask) const {
-    return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
-  }
-
   /// Helper function to get the Mask and VL from \p Root.
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    return isMaskCompatible(Mask) && isVLCompatible(VL);
-  }
-
   /// Helper function to check if \p N is commutative with respect to the
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS,
                                  uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
               const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
-
   if (RHS.SupportsFPExt)
     return CombineResult(
         NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
 
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                        /*RHSExt=*/{ExtKind::ZExt});
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index 80d61080d60ef7..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -62,11 +62,9 @@ define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i1
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    slli a0, a0, 32
 ; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v9
 ; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
-; CHECK-NEXT:    vwadd.wv v9, v10, v8, v0.t
-; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
   %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
   %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
@@ -77,13 +75,11 @@ define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i1
 define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
 ; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vsext.vf2 v8, v9
 ; CHECK-NEXT:    slli a0, a0, 32
 ; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8, v0.t
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
   %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
   %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
@@ -94,13 +90,9 @@ define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16>
 define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
 ; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    slli a0, a0, 32
-; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8, v0.t
-; CHECK-NEXT:    vsext.vf2 v8, v9, v0.t
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v10, v8
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
   %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
   %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)

>From cfbd6b0054ac82c16422108d4c1da297218f2bdd Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Tue, 9 Apr 2024 15:55:56 +0800
Subject: [PATCH 3/3] Update comment

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e13b3f3ca109fb..c9727a3e5a8db3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13552,7 +13552,7 @@ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// NodeExtensionHelper for `a` and one for `b`.
 ///
 /// This class abstracts away how the extension is materialized and
-/// how its Mask, VL, number of users affect the combines.
+/// how its number of users affect the combines.
 ///
 /// In particular:
 /// - VWADD_W is conceptually == add(op0, sext(op1))



More information about the llvm-commits mailing list