[llvm] [RISCV][ISel] Remove redundant vmerge for the vwadd. (PR #78403)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 17 07:22:08 PST 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/78403
>From c74c6c973eb08d0f7082a552217aa467125d9179 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 17 Jan 2024 16:14:58 +0900
Subject: [PATCH 1/3] [RISCV][ISel] fold (vwadd y, (select cond, x, 0)) ->
select cond (vwadd y, x), y.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 54 ++++++++++++++++++-
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 27 ++++++++++
2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cb9ffabc41236e..a030538e5e8ba9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13457,6 +13457,56 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return InputRootReplacement;
}
+// (vwadd_w y, (select cond, x, 0)) -> select cond (vwadd_w y, x), y
+//
+// Only the .w forms (VWADD_W_VL / VWADDU_W_VL) are folded: their first
+// operand Y already has the wide result type, so it can be reused directly as
+// the false operand of the rebuilt select. Sign- and zero-extension both map
+// 0 to 0, so in lanes where Cond is false the original computes y + ext(0),
+// i.e. y, which is exactly what the rebuilt select yields.
+static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
+  unsigned Opc = N->getOpcode();
+  assert((Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL) &&
+         "Unexpected opcode");
+
+  // N's operands are (Y, RHS, Passthru, Mask, VL).
+  SDValue VL = N->getOperand(4);
+  SDValue Y = N->getOperand(0);
+  SDValue Merge = N->getOperand(1);
+
+  if (Merge.getOpcode() != RISCVISD::VMERGE_VL)
+    return SDValue();
+
+  // The rebuilt select sends lanes where Cond is false to Y regardless of N's
+  // mask. That agrees with N only when the lanes disabled by N's mask carry no
+  // defined value, i.e. when N's passthru is undef.
+  if (!N->getOperand(2).isUndef())
+    return SDValue();
+
+  // VMERGE_VL operands: (Cond, TrueVal, FalseVal, Passthru, VL).
+  SDValue Cond = Merge->getOperand(0);
+  SDValue X = Merge->getOperand(1);
+  // NOTE(review): the guard below checks that Z is an INSERT_SUBVECTOR at
+  // index 0, not that the inserted value (operand 1) is all zeros. PATCH 3/3
+  // of this series replaces it with ISD::isBuildVectorAllZeros on operand 1.
+  SDValue Z = Merge->getOperand(2);
+
+  if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
+      !isNullConstant(Z.getOperand(2)))
+    return SDValue();
+
+  if (!Merge.hasOneUse())
+    return SDValue();
+
+  // Ops already holds Y in slot 0; only the select operand is replaced by X.
+  SmallVector<SDValue, 6> Ops(N->op_values());
+  Ops[1] = X;
+
+  // NOTE(review): with the undef VMERGE_VL passthru below, the "_MASK"
+  // selection pattern added in this patch leaves lanes where Cond is false
+  // unset rather than equal to Y. PATCH 3/3 passes Y as the passthru and
+  // selects the "_MASK_TIED" instruction instead.
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
+  return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
+                     VL);
+}
+
+static SDValue performVWADD_VLCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // DAG combine entry point for the vwadd family: try the generic
+  // VL-binop-to-widening combine first, then the select fold above.
+  unsigned Opc = N->getOpcode();
+  assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
+         Opc == RISCVISD::VWADDU_W_VL);
+
+  // Both of VWADD_VL's source operands are narrow, so the select fold would
+  // use a narrow Y as the false operand of a wide VMERGE_VL (an ill-typed
+  // node). VWADD_VL is not a candidate for combineBinOp_VLToVWBinOp_VL here
+  // either, so there is nothing to do for it.
+  if (Opc == RISCVISD::VWADD_VL)
+    return SDValue();
+
+  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+    return V;
+
+  return combineVWADDSelect(N, DCI.DAG);
+}
+
// Helper function for performMemPairCombine.
// Try to combine the memory loads/stores LSNode1 and LSNode2
// into a single memory pair operation.
@@ -15500,9 +15550,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
return V;
return combineToVWMACC(N, DAG, Subtarget);
- case RISCVISD::SUB_VL:
+ case RISCVISD::VWADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ return performVWADD_VLCombine(N, DCI);
+ case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::MUL_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 1deb9a709463e8..6744a38d036b00 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,30 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}
+// Folds a riscv_vmerge_vl into a single masked "_MASK" instruction (mask in
+// V0). The vmerge's true operand must be an unmasked tied binary op
+// `vop $rs1, $rs2` (srcvalue passthru, true_mask), and its false operand
+// must be that op's first source $rs1. No separate vmerge is then emitted.
+// NOTE(review): the emitted instruction uses $merge (the vmerge passthru) as
+// its passthru with a TAIL_AGNOSTIC policy. Lanes where V0 is clear therefore
+// take their value from $merge rather than from $rs1, as riscv_vmerge_vl
+// specifies. This is only correct when $merge and $rs1 are the same value.
+// PATCH 3/3 of this series requires that and switches to the "_MASK_TIED"
+// form.
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op2_reg_class>
+ : Pat<(riscv_vmerge_vl (mask_type V0),
+ (result_type (vop
+ result_reg_class:$rs1,
+ (op2_type op2_reg_class:$rs2),
+ srcvalue,
+ true_mask,
+ VLOpFrag)),
+ result_reg_class:$rs1, result_reg_class:$merge, VLOpFrag),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ result_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
+
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
@@ -819,6 +843,9 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
+ def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+ vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
>From 50291c86fbbdb64acbc60b13834e0bd77abf8a3d Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 17 Jan 2024 16:15:51 +0900
Subject: [PATCH 2/3] [RISCV][ISel] add fixed-vectors-vwadd-mask.ll test
---
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 35 +++++++++++++++++++
1 file changed, 35 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
new file mode 100644
index 00000000000000..afc59b875d79df
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+
+; The select's false value is zero, so the vmerge feeding the widening add
+; is expected to fold into a masked vwadd.wv (v0.t) instead of a separate
+; vmerge.vvm.
+define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %sa, %y
+ ret <8 x i64> %ret
+}
+
+; Same as @vwadd_mask_v8i32 but with the add operands swapped (%y first),
+; checking that the fold does not depend on operand order.
+define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32_commutative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %y, %sa
+ ret <8 x i64> %ret
+}
>From af9a532988d1ae284051daed7ade214c8ff19d7c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 18 Jan 2024 00:21:39 +0900
Subject: [PATCH 3/3] [RISCV][ISel] use tied masked instruction and
isBuildVectorAllZeros.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 5 ++--
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 7 +++--
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 27 ++++++++++++++++---
3 files changed, 28 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a030538e5e8ba9..37f24a17cdb65d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13475,7 +13475,7 @@ static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
SDValue Z = Merge->getOperand(2);
if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
- !isNullConstant(Z.getOperand(2)))
+ !ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
return SDValue();
if (!Merge.hasOneUse())
@@ -13489,8 +13489,7 @@ static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
- return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
- VL);
+ return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, Y, VL);
}
static SDValue performVWADD_VLCombine(SDNode *N,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 6744a38d036b00..59234d7d3a7d49 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -708,12 +708,11 @@ class VPatTiedBinaryMaskVL_V<SDNode vop,
srcvalue,
true_mask,
VLOpFrag)),
- result_reg_class:$rs1, result_reg_class:$merge, VLOpFrag),
- (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
- result_reg_class:$merge,
+ result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
result_reg_class:$rs1,
op2_reg_class:$rs2,
- (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
+ (mask_type V0), GPR:$vl, sew, TU_MU)>;
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
index afc59b875d79df..a47b06f24acf90 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -8,8 +8,9 @@ define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -24,8 +25,9 @@ define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -33,3 +35,20 @@ define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
%ret = add <8 x i64> %y, %sa
ret <8 x i64> %ret
}
+
+; Negative test: the select's false value is 1, not 0, so the fold must not
+; fire. The vmerge.vvm is kept and the vwadd.wv stays unmasked.
+define <8 x i64> @vwadd_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32_nonzero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vmv.v.i v10, 1
+; CHECK-NEXT: vmerge.vvm v16, v10, v8, v0
+; CHECK-NEXT: vwadd.wv v8, v12, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %y, %sa
+ ret <8 x i64> %ret
+}
More information about the llvm-commits
mailing list