[llvm] [RISCV] Add disjoint or patterns for vwadd[u].vv (PR #136716)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 22 10:19:42 PDT 2025


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/136716

>From 6cd17502db5eb87c34672a0e0de88120e044ab30 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Tue, 22 Apr 2025 23:35:03 +0800
Subject: [PATCH 1/2] [RISCV] Add disjoint or patterns for vwadd[u].vv

DAGCombiner::hoistLogicOpWithSameOpcodeHands will hoist

(or disjoint (ext a), (ext b)) -> (ext (or disjoint a, b))

So this adds a pattern to match vwadd[u].vv in this case.

We have to teach the combine to preserve the disjoint flag, and add a generic PatFrag for a disjoint or.

This is meant to be a follow-up to #136677, and would allow us to remove the target hook added there.
---
 .../include/llvm/Target/TargetSelectionDAG.td |  4 ++++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  4 +++-
 .../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 19 +++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll   | 14 ++++----------
 4 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 9c241b6c4df0f..20ef517426cf8 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -1113,6 +1113,10 @@ def not  : PatFrag<(ops node:$in), (xor node:$in, -1)>;
 def vnot : PatFrag<(ops node:$in), (xor node:$in, immAllOnesV)>;
 def ineg : PatFrag<(ops node:$in), (sub 0, node:$in)>;
 
+def or_disjoint : PatFrag<(ops node:$x, node:$y), (or node:$x, node:$y), [{
+  return N->getFlags().hasDisjoint();
+}]>;
+
 def zanyext : PatFrags<(ops node:$op),
                        [(zext node:$op),
                         (anyext node:$op)]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b175e35385ec6..8cfcd2be8c61c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5982,7 +5982,9 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
         LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
       return SDValue();
     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
-    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+    SDNodeFlags LogicFlags;
+    LogicFlags.setDisjoint(N->getFlags().hasDisjoint());
+    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
     if (HandOpcode == ISD::SIGN_EXTEND_INREG)
       return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
     return DAG.getNode(HandOpcode, DL, VT, Logic);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index b2c5261ae6c2d..71893e85bcb91 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -912,6 +912,25 @@ defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, sext_oneuse, "PseudoVWADD">;
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, zext_oneuse, "PseudoVWADDU">;
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX<add, anyext_oneuse, "PseudoVWADDU">;
 
+// DAGCombiner::hoistLogicOpWithSameOpcodeHands may hoist disjoint ors
+// to (ext (or disjoint (a, b)))
+multiclass VPatWidenOrDisjoint_VV<SDNode extop, string instruction_name> {
+  foreach vtiToWti = AllWidenableIntVectors in {
+    defvar vti = vtiToWti.Vti;
+    defvar wti = vtiToWti.Wti;
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in {
+      def : Pat<(wti.Vector (extop (vti.Vector (or_disjoint vti.RegClass:$rs2, vti.RegClass:$rs1)))),
+                (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
+                   (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+                   vti.RegClass:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
+    }
+  }
+}
+defm : VPatWidenOrDisjoint_VV<sext, "PseudoVWADD">;
+defm : VPatWidenOrDisjoint_VV<zext, "PseudoVWADDU">;
+defm : VPatWidenOrDisjoint_VV<anyext, "PseudoVWADDU">;
+
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, sext_oneuse, "PseudoVWSUB">;
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, zext_oneuse, "PseudoVWSUBU">;
 defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, anyext_oneuse, "PseudoVWSUBU">;
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 3f5d42f89337b..149950484c477 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1417,15 +1417,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
   ret <vscale x 2 x i32> %add
 }
 
-; TODO: We could select vwaddu.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
 define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
 ; CHECK-LABEL: vwaddu_vv_disjoint_or:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT:    vor.vv v9, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v8, v9
+; CHECK-NEXT:    vwaddu.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
   %x.i32 = zext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
   %y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
@@ -1433,15 +1430,12 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsc
   ret <vscale x 2 x i32> %or
 }
 
-; TODO: We could select vwadd.vv, but when both arms of the or are the same
-; DAGCombiner::hoistLogicOpWithSameOpcodeHands moves the zext above the or.
 define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vscale x 2 x i16> %y.i16) {
 ; CHECK-LABEL: vwadd_vv_disjoint_or:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
-; CHECK-NEXT:    vor.vv v9, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v8, v9
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
 ; CHECK-NEXT:    ret
   %x.i32 = sext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
   %y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>

>From 644fe31ee2b2dbe164fee51188faa36e3e21e3bc Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 23 Apr 2025 01:19:00 +0800
Subject: [PATCH 2/2] Don't propagate disjoint for inreg exts, use or_is_add,
 match SplatPat

---
 .../include/llvm/Target/TargetSelectionDAG.td |  4 ---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  3 +-
 .../Target/RISCV/RISCVInstrInfoVSDPatterns.td |  6 +++-
 llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll   | 31 +++++++++++++++++++
 4 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 20ef517426cf8..9c241b6c4df0f 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -1113,10 +1113,6 @@ def not  : PatFrag<(ops node:$in), (xor node:$in, -1)>;
 def vnot : PatFrag<(ops node:$in), (xor node:$in, immAllOnesV)>;
 def ineg : PatFrag<(ops node:$in), (sub 0, node:$in)>;
 
-def or_disjoint : PatFrag<(ops node:$x, node:$y), (or node:$x, node:$y), [{
-  return N->getFlags().hasDisjoint();
-}]>;
-
 def zanyext : PatFrags<(ops node:$op),
                        [(zext node:$op),
                         (anyext node:$op)]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8cfcd2be8c61c..f39951b1865ce 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5983,7 +5983,8 @@ SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
       return SDValue();
     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
     SDNodeFlags LogicFlags;
-    LogicFlags.setDisjoint(N->getFlags().hasDisjoint());
+    LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
+                           ISD::isExtOpcode(HandOpcode));
     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
     if (HandOpcode == ISD::SIGN_EXTEND_INREG)
       return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 71893e85bcb91..25c1d169502ea 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -920,10 +920,14 @@ multiclass VPatWidenOrDisjoint_VV<SDNode extop, string instruction_name> {
     defvar wti = vtiToWti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(wti.Vector (extop (vti.Vector (or_disjoint vti.RegClass:$rs2, vti.RegClass:$rs1)))),
+      def : Pat<(wti.Vector (extop (vti.Vector (or_is_add vti.RegClass:$rs2, vti.RegClass:$rs1)))),
                 (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
                    (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
                    vti.RegClass:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
+      def : Pat<(wti.Vector (extop (vti.Vector (or_is_add vti.RegClass:$rs2, (SplatPat (XLenVT GPR:$rs1)))))),
+                (!cast<Instruction>(instruction_name#"_VX_"#vti.LMul.MX)
+                   (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
+                   GPR:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
     }
   }
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 149950484c477..ebccb18540bbd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1443,6 +1443,37 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca
   ret <vscale x 2 x i32> %or
 }
 
+define <vscale x 2 x i32> @vwaddu_vx_disjoint_or(<vscale x 2 x i16> %x.i16, i16 %y.i16) {
+; CHECK-LABEL: vwaddu_vx_disjoint_or:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %x.i32 = zext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
+  %y.head = insertelement <vscale x 2 x i16> poison, i16 %y.i16, i32 0
+  %y.splat = shufflevector <vscale x 2 x i16> %y.head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
+  %y.i32 = zext <vscale x 2 x i16> %y.splat to <vscale x 2 x i32>
+  %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
+  ret <vscale x 2 x i32> %or
+}
+
+
+define <vscale x 2 x i32> @vwadd_vx_disjoint_or(<vscale x 2 x i16> %x.i16, i16 %y.i16) {
+; CHECK-LABEL: vwadd_vx_disjoint_or:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %x.i32 = sext <vscale x 2 x i16> %x.i16 to <vscale x 2 x i32>
+  %y.head = insertelement <vscale x 2 x i16> poison, i16 %y.i16, i32 0
+  %y.splat = shufflevector <vscale x 2 x i16> %y.head, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
+  %y.i32 = sext <vscale x 2 x i16> %y.splat to <vscale x 2 x i32>
+  %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
+  ret <vscale x 2 x i32> %or
+}
+
 define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
 ; CHECK-LABEL: vwaddu_wv_disjoint_or:
 ; CHECK:       # %bb.0:



More information about the llvm-commits mailing list