[llvm] 8204931 - [RISCV] Add disjoint or patterns for vwadd[u].v{v,x} (#136716)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 00:17:08 PDT 2025
Author: Luke Lau
Date: 2025-04-23T15:17:04+08:00
New Revision: 82049310385d5222527cf7d12984bd8d4f955dd1
URL: https://github.com/llvm/llvm-project/commit/82049310385d5222527cf7d12984bd8d4f955dd1
DIFF: https://github.com/llvm/llvm-project/commit/82049310385d5222527cf7d12984bd8d4f955dd1.diff
LOG: [RISCV] Add disjoint or patterns for vwadd[u].v{v,x} (#136716)
DAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist
(or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b)),
so this adds patterns to match vwadd[u].v{v,x} in this case.
For the patterns to match, we have to teach the combine to preserve the disjoint flag on the hoisted or (only when the hand op is an extension).
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b571f635c744f..6255922979399 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6037,7 +6037,10 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
return SDValue();
// logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
- SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+ SDNodeFlags LogicFlags;
+ LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
+ ISD::isExtOpcode(HandOpcode));
+ SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
if (HandOpcode == ISD::SIGN_EXTEND_INREG)
return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
return DAG.getNode(HandOpcode, DL, VT, Logic);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index b2c5261ae6c2d..aea125c5348dd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -912,6 +912,29 @@ defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, sext_oneuse, "PseudoVWADD">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, zext_oneuse, "PseudoVWADDU">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, anyext_oneuse, "PseudoVWADDU">;
+// DAGCombiner::hoistLogicOpWithSameOpcodeHands may hoist disjoint ors
+// to (ext (or disjoint (a, b)))
+multiclass VPatWidenOrDisjoint_VV_VX<SDNode extop, string instruction_name> {
+ foreach vtiToWti = AllWidenableIntVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
+ def : Pat<(wti.Vector (extop (vti.Vector (or_is_add vti.RegClass:$rs2, vti.RegClass:$rs1)))),
+ (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
+ (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+ vti.RegClass:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
+ def : Pat<(wti.Vector (extop (vti.Vector (or_is_add vti.RegClass:$rs2, (SplatPat (XLenVT GPR:$rs1)))))),
+ (!cast<Instruction>(instruction_name#"_VX_"#vti.LMul.MX)
+ (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+ GPR:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
+ }
+ }
+}
+defm : VPatWidenOrDisjoint_VV_VX<sext, "PseudoVWADD">;
+defm : VPatWidenOrDisjoint_VV_VX<zext, "PseudoVWADDU">;
+defm : VPatWidenOrDisjoint_VV_VX<anyext, "PseudoVWADDU">;
+
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, sext_oneuse, "PseudoVWSUB">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, zext_oneuse, "PseudoVWSUBU">;
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, anyext_oneuse, "PseudoVWSUBU">;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 3f5d42f89337b..f94e46771f49c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1417,15 +1417,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
ret <vscale x 2 x i32> %add
}
-; TODO: We could select vwaddu.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwaddu_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v8, v9
+; CHECK-NEXT: vwaddu.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = zext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
%y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
@@ -1433,15 +1430,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsc
ret <vscale x 2 x i32> %or
}
-; TODO: We could select vwadd.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwadd_vv_disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vor.vv v9, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v8, v9
+; CHECK-NEXT: vwadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
; CHECK-NEXT: ret
%x.i32 = sext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
%y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
@@ -1449,6 +1443,36 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca
ret <vscale x 2 x i32> %or
}
+define <vscale x 2 x i32> @vwaddu_vx_disjoint_or(<vscale x 2 x i16> %x.i16, i16 %y.i16) {
+; CHECK-LABEL: vwaddu_vx_disjoint_or:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwaddu.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
+; CHECK-NEXT: ret
+ %x.i32 = zext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
+ %y.head = insertelement <vscale x 2 x i16> poison, i16 %y.i16, i32 0
+ %y.splat = shufflevector <vscale x 2 x i16> %y.head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
+ %y.i32 = zext <vscale x 2 x i16> %y.splat to <vscale x 2 x i32>
+ %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
+ ret <vscale x 2 x i32> %or
+}
+
+define <vscale x 2 x i32> @vwadd_vx_disjoint_or(<vscale x 2 x i16> %x.i16, i16 %y.i16) {
+; CHECK-LABEL: vwadd_vx_disjoint_or:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vx v9, v8, a0
+; CHECK-NEXT: vmv1r.v v8, v9
+; CHECK-NEXT: ret
+ %x.i32 = sext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
+ %y.head = insertelement <vscale x 2 x i16> poison, i16 %y.i16, i32 0
+ %y.splat = shufflevector <vscale x 2 x i16> %y.head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
+ %y.i32 = sext <vscale x 2 x i16> %y.splat to <vscale x 2 x i32>
+ %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
+ ret <vscale x 2 x i32> %or
+}
+
define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
; CHECK-LABEL: vwaddu_wv_disjoint_or:
; CHECK: # %bb.0:
More information about the llvm-commits
mailing list