[llvm] [RISCV][ISel] Remove redundant vmerge for the vwadd. (PR #78403)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Jan 20 01:27:08 PST 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/78403
>From c74c6c973eb08d0f7082a552217aa467125d9179 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 17 Jan 2024 16:14:58 +0900
Subject: [PATCH 1/5] [RISCV][Isel] fold (vwadd y, (select cond, x, 0)) ->
select cond (vwadd y, x), y.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 54 ++++++++++++++++++-
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 27 ++++++++++
2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cb9ffabc41236e..a030538e5e8ba9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13457,6 +13457,56 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return InputRootReplacement;
}
+// (vwadd y, (select cond, x, 0)) -> select cond (vwadd y, x), y
+static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
+ unsigned Opc = N->getOpcode();
+ assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
+ Opc == RISCVISD::VWADDU_W_VL);
+
+ SDValue VL = N->getOperand(4);
+ SDValue Y = N->getOperand(0);
+ SDValue Merge = N->getOperand(1);
+
+ if (Merge.getOpcode() != RISCVISD::VMERGE_VL)
+ return SDValue();
+
+ SDValue Cond = Merge->getOperand(0);
+ SDValue X = Merge->getOperand(1);
+ SDValue Z = Merge->getOperand(2);
+
+ if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
+ !isNullConstant(Z.getOperand(2)))
+ return SDValue();
+
+ if (!Merge.hasOneUse())
+ return SDValue();
+
+ SmallVector<SDValue, 6> Ops(N->op_values());
+ Ops[0] = Y;
+ Ops[1] = X;
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+
+ SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
+ return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
+ VL);
+}
+
+static SDValue performVWADD_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ unsigned Opc = N->getOpcode();
+ assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
+ Opc == RISCVISD::VWADDU_W_VL);
+
+ if (Opc != RISCVISD::VWADD_VL) {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+ return V;
+ }
+
+ return combineVWADDSelect(N, DCI.DAG);
+}
+
// Helper function for performMemPairCombine.
// Try to combine the memory loads/stores LSNode1 and LSNode2
// into a single memory pair operation.
@@ -15500,9 +15550,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
return V;
return combineToVWMACC(N, DAG, Subtarget);
- case RISCVISD::SUB_VL:
+ case RISCVISD::VWADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ return performVWADD_VLCombine(N, DCI);
+ case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::MUL_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 1deb9a709463e8..6744a38d036b00 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,30 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op2_reg_class>
+ : Pat<(riscv_vmerge_vl (mask_type V0),
+ (result_type (vop
+ result_reg_class:$rs1,
+ (op2_type op2_reg_class:$rs2),
+ srcvalue,
+ true_mask,
+ VLOpFrag)),
+ result_reg_class:$rs1, result_reg_class:$merge, VLOpFrag),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ result_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
+
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
@@ -819,6 +843,9 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
+ def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+ vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
>From 50291c86fbbdb64acbc60b13834e0bd77abf8a3d Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 17 Jan 2024 16:15:51 +0900
Subject: [PATCH 2/5] [RISCV][Isel] add fixed-vectors-vwadd-mask.ll
---
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 35 +++++++++++++++++++
1 file changed, 35 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
new file mode 100644
index 00000000000000..afc59b875d79df
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+
+define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %sa, %y
+ ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32_commutative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %y, %sa
+ ret <8 x i64> %ret
+}
>From af9a532988d1ae284051daed7ade214c8ff19d7c Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 18 Jan 2024 00:21:39 +0900
Subject: [PATCH 3/5] [RISCV][Isel] use tied mask and use
isBuildVectorAllZeros.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 5 ++--
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 7 +++--
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 27 ++++++++++++++++---
3 files changed, 28 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a030538e5e8ba9..37f24a17cdb65d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13475,7 +13475,7 @@ static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
SDValue Z = Merge->getOperand(2);
if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
- !isNullConstant(Z.getOperand(2)))
+ !ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
return SDValue();
if (!Merge.hasOneUse())
@@ -13489,8 +13489,7 @@ static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
- return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
- VL);
+ return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, Y, VL);
}
static SDValue performVWADD_VLCombine(SDNode *N,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 6744a38d036b00..59234d7d3a7d49 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -708,12 +708,11 @@ class VPatTiedBinaryMaskVL_V<SDNode vop,
srcvalue,
true_mask,
VLOpFrag)),
- result_reg_class:$rs1, result_reg_class:$merge, VLOpFrag),
- (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
- result_reg_class:$merge,
+ result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
result_reg_class:$rs1,
op2_reg_class:$rs2,
- (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
+ (mask_type V0), GPR:$vl, sew, TU_MU)>;
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
index afc59b875d79df..a47b06f24acf90 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -8,8 +8,9 @@ define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -24,8 +25,9 @@ define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -33,3 +35,20 @@ define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
%ret = add <8 x i64> %y, %sa
ret <8 x i64> %ret
}
+
+define <8 x i64> @vwadd_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32_nonzero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vmv.v.i v10, 1
+; CHECK-NEXT: vmerge.vvm v16, v10, v8, v0
+; CHECK-NEXT: vwadd.wv v8, v12, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %y, %sa
+ ret <8 x i64> %ret
+}
>From ab363a0860aded8ecd25bda6aeb019fafef19479 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 18 Jan 2024 12:44:02 +0900
Subject: [PATCH 4/5] [RISCV][Isel] use mask as operand.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 26 +++----
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 26 -------
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 68 ++++++++++++++++---
3 files changed, 72 insertions(+), 48 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 37f24a17cdb65d..c7adb5a3fda338 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13460,12 +13460,16 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
// (vwadd y, (select cond, x, 0)) -> select cond (vwadd y, x), y
static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
unsigned Opc = N->getOpcode();
- assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
- Opc == RISCVISD::VWADDU_W_VL);
+ assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADDU_VL ||
+ Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
- SDValue VL = N->getOperand(4);
- SDValue Y = N->getOperand(0);
SDValue Merge = N->getOperand(1);
+ unsigned MergeID = 1;
+
+ if (Merge.getOpcode() != RISCVISD::VMERGE_VL) {
+ Merge = N->getOperand(0);
+ MergeID = 0;
+ }
if (Merge.getOpcode() != RISCVISD::VMERGE_VL)
return SDValue();
@@ -13482,23 +13486,20 @@ static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
return SDValue();
SmallVector<SDValue, 6> Ops(N->op_values());
- Ops[0] = Y;
- Ops[1] = X;
+ Ops[MergeID] = X;
+ Ops[3] = Cond;
SDLoc DL(N);
- EVT VT = N->getValueType(0);
-
- SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
- return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, Y, VL);
+ return DAG.getNode(Opc, DL, N->getValueType(0), Ops, N->getFlags());
}
static SDValue performVWADD_VLCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
unsigned Opc = N->getOpcode();
assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
- Opc == RISCVISD::VWADDU_W_VL);
+ Opc == RISCVISD::VWADDU_VL || Opc == RISCVISD::VWADDU_W_VL);
- if (Opc != RISCVISD::VWADD_VL) {
+ if (Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWADD_W_VL) {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
return V;
}
@@ -15550,6 +15551,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return V;
return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::VWADD_VL:
+ case RISCVISD::VWADDU_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
return performVWADD_VLCombine(N, DCI);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 59234d7d3a7d49..1deb9a709463e8 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,29 +691,6 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}
-class VPatTiedBinaryMaskVL_V<SDNode vop,
- string instruction_name,
- string suffix,
- ValueType result_type,
- ValueType op2_type,
- ValueType mask_type,
- int sew,
- LMULInfo vlmul,
- VReg result_reg_class,
- VReg op2_reg_class>
- : Pat<(riscv_vmerge_vl (mask_type V0),
- (result_type (vop
- result_reg_class:$rs1,
- (op2_type op2_reg_class:$rs2),
- srcvalue,
- true_mask,
- VLOpFrag)),
- result_reg_class:$rs1, result_reg_class:$rs1, VLOpFrag),
- (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
- result_reg_class:$rs1,
- op2_reg_class:$rs2,
- (mask_type V0), GPR:$vl, sew, TU_MU)>;
-
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
@@ -842,9 +819,6 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
- def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
- wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
- vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
index a47b06f24acf90..486edea7e42513 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -2,32 +2,80 @@
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
-define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
-; CHECK-LABEL: vwadd_mask_v8i32:
+define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
-; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %sa, %y
+ ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
+; CHECK-LABEL: vwadd_vv_mask_v8i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.vv v12, v8, v10, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = sext <8 x i32> %a to <8 x i64>
+ %sy = sext <8 x i32> %y to <8 x i64>
+ %ret = add <8 x i64> %sa, %sy
+ ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwaddu_wv_mask_v8i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwaddu.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = zext <8 x i32> %a to <8 x i64>
%ret = add <8 x i64> %sa, %y
ret <8 x i64> %ret
}
-define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
-; CHECK-LABEL: vwadd_mask_v8i32_commutative:
+define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
+; CHECK-LABEL: vwaddu_vv_mask_v8i32:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vsetvli zero, zero, e32, m2, tu, mu
-; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vwaddu.vv v12, v8, v10, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = zext <8 x i32> %a to <8 x i64>
+ %sy = zext <8 x i32> %y to <8 x i64>
+ %ret = add <8 x i64> %sa, %sy
+ ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -36,8 +84,8 @@ define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
ret <8 x i64> %ret
}
-define <8 x i64> @vwadd_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
-; CHECK-LABEL: vwadd_mask_v8i32_nonzero:
+define <8 x i64> @vwadd_wv_mask_v8i32_nonzero(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_wv_mask_v8i32_nonzero:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
>From 4b260e4da64f1477e96516aa41ea9d0777c111e3 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sat, 20 Jan 2024 18:26:52 +0900
Subject: [PATCH 5/5] [RISCV][Isel] only combine vwadd.wv and vwaddu.wv.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 70 +++++++++----------
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 25 +++++++
.../RISCV/rvv/fixed-vectors-vwadd-mask.ll | 45 ++++--------
3 files changed, 71 insertions(+), 69 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c7adb5a3fda338..c8cde543373660 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13457,54 +13457,54 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return InputRootReplacement;
}
-// (vwadd y, (select cond, x, 0)) -> select cond (vwadd y, x), y
-static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
+// Fold (vwadd.wv y, (vmerge cond, x, 0)) -> vwadd.wv y, x, y, cond
+// y becomes the Passthru and cond (operand 0 of the vmerge) becomes the Mask.
+static SDValue combineVWADDWSelect(SDNode *N, SelectionDAG &DAG) {
unsigned Opc = N->getOpcode();
- assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADDU_VL ||
- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
+ assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
- SDValue Merge = N->getOperand(1);
- unsigned MergeID = 1;
-
- if (Merge.getOpcode() != RISCVISD::VMERGE_VL) {
- Merge = N->getOperand(0);
- MergeID = 0;
- }
-
- if (Merge.getOpcode() != RISCVISD::VMERGE_VL)
+ SDValue Y = N->getOperand(0);
+ SDValue MergeOp = N->getOperand(1);
+ if (MergeOp.getOpcode() != RISCVISD::VMERGE_VL)
return SDValue();
+ SDValue X = MergeOp->getOperand(1);
- SDValue Cond = Merge->getOperand(0);
- SDValue X = Merge->getOperand(1);
- SDValue Z = Merge->getOperand(2);
+ if (!MergeOp.hasOneUse())
+ return SDValue();
- if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
- !ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
+ // Passthru should be undef
+ SDValue Passthru = N->getOperand(2);
+ if (!Passthru.isUndef())
return SDValue();
- if (!Merge.hasOneUse())
+ // Mask should be all ones
+ SDValue Mask = N->getOperand(3);
+ if (Mask.getOpcode() != RISCVISD::VMSET_VL)
return SDValue();
- SmallVector<SDValue, 6> Ops(N->op_values());
- Ops[MergeID] = X;
- Ops[3] = Cond;
+ // False value of MergeOp should be all zeros
+ SDValue Z = MergeOp->getOperand(2);
+ if (Z.getOpcode() != ISD::INSERT_SUBVECTOR)
+ return SDValue();
+ if (!ISD::isBuildVectorAllZeros(Z.getOperand(1).getNode()))
+ return SDValue();
+ if (!isNullOrNullSplat(Z.getOperand(0)) && !Z.getOperand(0).isUndef())
+ return SDValue();
- SDLoc DL(N);
- return DAG.getNode(Opc, DL, N->getValueType(0), Ops, N->getFlags());
+ return DAG.getNode(Opc, SDLoc(N), N->getValueType(0),
+ {Y, X, Y, MergeOp->getOperand(0), N->getOperand(4)},
+ N->getFlags());
}
-static SDValue performVWADD_VLCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue performVWADDW_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
unsigned Opc = N->getOpcode();
- assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
- Opc == RISCVISD::VWADDU_VL || Opc == RISCVISD::VWADDU_W_VL);
+ assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL);
- if (Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWADD_W_VL) {
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
- return V;
- }
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+ return V;
- return combineVWADDSelect(N, DCI.DAG);
+ return combineVWADDWSelect(N, DCI.DAG);
}
// Helper function for performMemPairCombine.
@@ -15550,11 +15550,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
return V;
return combineToVWMACC(N, DAG, Subtarget);
- case RISCVISD::VWADD_VL:
- case RISCVISD::VWADDU_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
- return performVWADD_VLCombine(N, DCI);
+ return performVWADDW_VLCombine(N, DCI);
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 1deb9a709463e8..4821b59b301abc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,27 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op2_reg_class> :
+ Pat<(result_type (vop
+ (result_type result_reg_class:$rs1),
+ (op2_type op2_reg_class:$rs2),
+ (result_type result_reg_class:$rs1),
+ (mask_type V0),
+ VLOpFrag)),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK_TIED")
+ result_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0), GPR:$vl, sew, TU_MU)>;
+
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
@@ -819,6 +840,10 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
+ def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
+ wti.Vector, vti.Vector, wti.Mask,
+ vti.Log2SEW, vti.LMul, wti.RegClass,
+ vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
index 486edea7e42513..2ce542a243ee77 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -5,44 +5,23 @@
define <8 x i64> @vwadd_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwadd_wv_mask_v8i32:
; CHECK: # %bb.0:
-; CHECK-NEXT: li a0, 42
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
-; CHECK-NEXT: ret
- %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
- %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
- %sa = sext <8 x i32> %a to <8 x i64>
- %ret = add <8 x i64> %sa, %y
- ret <8 x i64> %ret
-}
-
-define <8 x i64> @vwadd_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
-; CHECK-LABEL: vwadd_vv_mask_v8i32:
-; CHECK: # %bb.0:
-; CHECK-NEXT: li a0, 42
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.vv v12, v8, v10, v0.t
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
%sa = sext <8 x i32> %a to <8 x i64>
- %sy = sext <8 x i32> %y to <8 x i64>
- %ret = add <8 x i64> %sa, %sy
+ %ret = add <8 x i64> %sa, %y
ret <8 x i64> %ret
}
define <8 x i64> @vwaddu_wv_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwaddu_wv_mask_v8i32:
; CHECK: # %bb.0:
-; CHECK-NEXT: li a0, 42
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwaddu.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, tu, mu
+; CHECK-NEXT: vwaddu.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
@@ -57,7 +36,9 @@ define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
; CHECK-NEXT: li a0, 42
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwaddu.vv v12, v8, v10, v0.t
+; CHECK-NEXT: vmv.v.i v12, 0
+; CHECK-NEXT: vmerge.vvm v8, v12, v8, v0
+; CHECK-NEXT: vwaddu.vv v12, v8, v10
; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
@@ -71,11 +52,9 @@ define <8 x i64> @vwaddu_vv_mask_v8i32(<8 x i32> %x, <8 x i32> %y) {
define <8 x i64> @vwadd_wv_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
; CHECK-LABEL: vwadd_wv_mask_v8i32_commutative:
; CHECK: # %bb.0:
-; CHECK-NEXT: li a0, 42
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vmslt.vx v0, v8, a0
-; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
-; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, tu, mu
+; CHECK-NEXT: vwadd.wv v12, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v12
; CHECK-NEXT: ret
%mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
%a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
More information about the llvm-commits
mailing list