[llvm] [RISCV] combineBinOp_VLToVWBinOp_VL: remove mask/VL compatibility check (PR #87997)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 8 07:18:42 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
- Add tests where mismatching VL/masks prevents vwadd from being combined
- [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL
---
Full diff: https://github.com/llvm/llvm-project/pull/87997.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+3-49)
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll (+58)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..e13b3f3ca109fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
- /// Records if this operand's mask needs to match the mask of the operation
- /// that it will fold into.
- bool CheckMask;
- /// Value of the Mask for this operand.
- /// It may be SDValue().
- SDValue Mask;
- /// Value of the vector length operand.
- /// It may be SDValue().
- SDValue VL;
/// Original value that this NodeExtensionHelper represents.
SDValue OrigOperand;
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
SupportsSExt = false;
SupportsFPExt = false;
EnforceOneUse = true;
- CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
+ // For the nodes we handle below, we end up using their inputs directly: see
+ // getSource(). However since they either don't have a passthru or we check
+ // that their passthru is undef, we can safely ignore their mask and VL.
switch (Opc) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
- SDLoc DL(Root);
- std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
break;
}
case RISCVISD::VZEXT_VL:
SupportsZExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::FP_EXTEND_VL:
SupportsFPExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
EnforceOneUse = false;
- CheckMask = false;
- VL = OrigOperand.getOperand(2);
// The operand is a splat of a scalar.
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
SupportsFPExt =
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
- std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
- CheckMask = true;
// There's no existing extension here, so we don't have to worry about
// making sure it gets removed.
EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
}
}
- /// Check if this operand is compatible with the given vector length \p VL.
- bool isVLCompatible(SDValue VL) const {
- return this->VL != SDValue() && this->VL == VL;
- }
-
- /// Check if this operand is compatible with the given \p Mask.
- bool isMaskCompatible(SDValue Mask) const {
- return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
- }
-
/// Helper function to get the Mask and VL from \p Root.
static std::pair<SDValue, SDValue>
getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
}
}
- /// Check if the Mask and VL of this operand are compatible with \p Root.
- bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) const {
- auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
- return isMaskCompatible(Mask) && isVLCompatible(VL);
- }
-
/// Helper function to check if \p N is commutative with respect to the
/// foldings that are supported by this class.
static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS,
uint8_t AllowExtMask, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
-
if (RHS.SupportsFPExt)
return CombineResult(
NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
if (!LHS.SupportsSExt || !RHS.SupportsZExt)
return std::nullopt;
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
/*RHSExt=*/{ExtKind::ZExt});
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+ %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+ %add = add <vscale x 2 x i32> %x.sext, %y.sext
+ ret <vscale x 2 x i32> %add
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/87997
More information about the llvm-commits
mailing list