[llvm] [RISCV] Add fixed-length patterns for disjoint or patterns for vwadd[u].v{v,x} (PR #136824)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 00:47:31 PDT 2025
https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/136824
This is the fixed-length equivalent of #136716.
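The pattern we need to match is ({s,z}ext_vl (or_vl disjoint a, b)). Only or_vl nodes with an undef passthru are matched: their masked-off and tail elements are then undefined, so the or's own mask and VL can be ignored and those of the {s,z}ext_vl used instead.
A riscv_or_vl_is_add_oneuse PatFrag is added to mirror or_is_add in RISCVInstrInfo.td. It accepts an or_vl that either carries the disjoint flag or whose operands computeKnownBits proves have no common set bits.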
From 631ca8615877b8e1ad23609a7adb3c9071efc70d Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 23 Apr 2025 15:43:24 +0800
Subject: [PATCH] [RISCV] Add fixed-length patterns for disjoint or patterns
for vwadd[u].v{v,x}
This is the fixed-length equivalent of #136716.
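The pattern we need to match is ({s,z}ext_vl (or_vl disjoint a, b)). Only or_vl nodes with an undef passthru are matched: their masked-off and tail elements are then undefined, so the or's own mask and VL can be ignored and those of the {s,z}ext_vl used instead.
A riscv_or_vl_is_add_oneuse PatFrag is added to mirror or_is_add in RISCVInstrInfo.td. It accepts an or_vl that either carries the disjoint flag or whose operands computeKnownBits proves have no common set bits.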
---
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 41 +++++++++++++++++++
.../CodeGen/RISCV/rvv/fixed-vectors-vwadd.ll | 20 ++++-----
2 files changed, 49 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index f80cbc9e2fb5e..068402c59d6e2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -497,6 +497,16 @@ let HasOneUse = 1 in {
node:$E),
(riscv_add_vl node:$A, node:$B, node:$C,
node:$D, node:$E)>;
+ def riscv_or_vl_is_add_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D, // or_vl usable as an add_vl; mirrors or_is_add.
+ node:$E),
+ (riscv_or_vl node:$A, node:$B, node:$C,
+ node:$D, node:$E), [{
+ if (N->getFlags().hasDisjoint()) // Explicit disjoint flag: or == add.
+ return true;
+ KnownBits Known0 = CurDAG->computeKnownBits(N->getOperand(0), 0);
+ KnownBits Known1 = CurDAG->computeKnownBits(N->getOperand(1), 0);
+ return KnownBits::haveNoCommonBitsSet(Known0, Known1); // No flag: prove disjointness from known bits.
+ }]>;
def riscv_sub_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
node:$E),
(riscv_sub_vl node:$A, node:$B, node:$C,
@@ -2016,6 +2026,37 @@ foreach vtiToWti = AllWidenableIntVectors in {
}
}
+// DAGCombiner::hoistLogicOpWithSameOpcodeHands may turn (or disjoint (ext a), (ext b))
+// into (ext (or disjoint a, b)); select that back to a widening add (vwadd[u]).
+multiclass VPatWidenOrDisjointVL_VV_VX<SDNode extop, string instruction_name> {
+ foreach vtiToWti = AllWidenableIntVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
+ def : Pat<(wti.Vector (extop (vti.Vector
+ (riscv_or_vl_is_add_oneuse
+ vti.RegClass:$rs2, vti.RegClass:$rs1,
+ undef, srcvalue, srcvalue)), // Undef passthru only, so the or's own mask/VL can be ignored...
+ VMV0:$vm, VLOpFrag)), // ...and the ext's mask and VL are used instead.
+ (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX#"_MASK")
+ (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+ vti.RegClass:$rs1, VMV0:$vm, GPR:$vl, vti.Log2SEW, TA_MA)>;
+ def : Pat<(wti.Vector (extop (vti.Vector
+ (riscv_or_vl_is_add_oneuse
+ vti.RegClass:$rs2, (SplatPat (XLenVT GPR:$rs1)), // Scalar splat operand selects the .vx form.
+ undef, srcvalue, srcvalue)), // Same undef-passthru restriction as the .vv pattern.
+ VMV0:$vm, VLOpFrag)),
+ (!cast<Instruction>(instruction_name#"_VX_"#vti.LMul.MX#"_MASK")
+ (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+ GPR:$rs1, VMV0:$vm, GPR:$vl, vti.Log2SEW, TA_MA)>;
+ }
+ }
+}
+
+defm : VPatWidenOrDisjointVL_VV_VX<riscv_sext_vl, "PseudoVWADD">;
+defm : VPatWidenOrDisjointVL_VV_VX<riscv_zext_vl, "PseudoVWADDU">;
+
// 11.3. Vector Integer Extension
defm : VPatExtendVL_V<riscv_zext_vl, "PseudoVZEXT", "VF2",
AllFractionableVF2IntVectors>;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd.ll
index 5e7d1b91d7892..7f8c8258803fc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd.ll
@@ -901,9 +901,8 @@ define <4 x i32> @vwaddu_vv_disjoint_or(<4 x i16> %x.i16, <4 x i16> %y.i16) {
; CHECK-LABEL: vwaddu_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v8, v9
+; CHECK-NEXT: vwaddu.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = zext <4 x i16> %x.i16 to <4 x i32>
%y.i32 = zext <4 x i16> %y.i16 to <4 x i32>
@@ -915,9 +914,8 @@ define <4 x i32> @vwadd_vv_disjoint_or(<4 x i16> %x.i16, <4 x i16> %y.i16) {
; CHECK-LABEL: vwadd_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v8, v9
+; CHECK-NEXT: vwadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = sext <4 x i16> %x.i16 to <4 x i32>
%y.i32 = sext <4 x i16> %y.i16 to <4 x i32>
@@ -929,9 +927,8 @@ define <4 x i32> @vwaddu_vx_disjoint_or(<4 x i16> %x.i16, i16 %y.i16) {
; CHECK-LABEL: vwaddu_vx_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vx v9, v8, a0
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v8, v9
+; CHECK-NEXT: vwaddu.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%x.i32 = zext <4 x i16> %x.i16 to <4 x i32>
%y.head = insertelement <4 x i16> poison, i16 %y.i16, i32 0
@@ -945,9 +942,8 @@ define <4 x i32> @vwadd_vx_disjoint_or(<4 x i16> %x.i16, i16 %y.i16) {
; CHECK-LABEL: vwadd_vx_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vx v9, v8, a0
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v8, v9
+; CHECK-NEXT: vwadd.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%x.i32 = sext <4 x i16> %x.i16 to <4 x i32>
%y.head = insertelement <4 x i16> poison, i16 %y.i16, i32 0
More information about the llvm-commits mailing list