[llvm] 2a315d8 - [RISCV] Combine (or disjoint ext, ext) -> vwadd (#86929)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 29 04:45:28 PDT 2024


Author: Luke Lau
Date: 2024-03-29T19:45:24+08:00
New Revision: 2a315d800bb352fe459a012006a42ac7cd63834e

URL: https://github.com/llvm/llvm-project/commit/2a315d800bb352fe459a012006a42ac7cd63834e
DIFF: https://github.com/llvm/llvm-project/commit/2a315d800bb352fe459a012006a42ac7cd63834e.diff

LOG: [RISCV] Combine (or disjoint ext, ext) -> vwadd (#86929)

DAGCombiner (or InstCombine) will convert an `add` into an `or disjoint`
when the operands have no set bits in common, which can prevent what was
originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.

This teaches combineBinOp_VLToVWBinOp_VL to recover the vwadd by treating
an `or` that carries the disjoint flag as an add.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3cd9ecb9dd6817..e48ca4a905ce9e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13530,7 +13530,7 @@ struct CombineResult;
 enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd -> vfwadd | vfwadd_w
@@ -13678,6 +13678,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADD_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13700,6 +13701,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADDU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13745,6 +13747,7 @@ struct NodeExtensionHelper {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
+    case ISD::OR:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
                                           : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
@@ -13865,6 +13868,10 @@ struct NodeExtensionHelper {
     case ISD::MUL: {
       return Root->getValueType(0).isScalableVector();
     }
+    case ISD::OR: {
+      return Root->getValueType(0).isScalableVector() &&
+             Root->getFlags().hasDisjoint();
+    }
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13945,7 +13952,8 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::OR: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13968,6 +13976,7 @@ struct NodeExtensionHelper {
     switch (N->getOpcode()) {
     case ISD::ADD:
     case ISD::MUL:
+    case ISD::OR:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -14034,6 +14043,7 @@ struct CombineResult {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL:
+    case ISD::OR:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14184,6 +14194,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   switch (Root->getOpcode()) {
   case ISD::ADD:
   case ISD::SUB:
+  case ISD::OR:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14227,9 +14238,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 
 /// Combine a binary operation to its equivalent VW or VW_W form.
 /// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd_vl ->  vfwadd | vfwadd_w
 /// fsub_vl ->  vfwsub | vfwsub_w
 /// fmul_vl ->  vfwmul
@@ -15889,8 +15900,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
-  case ISD::OR:
+  case ISD::OR: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
     return performORCombine(N, DCI, Subtarget);
+  }
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index ed12afdd95956e..66e6883dd1d3e3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1401,11 +1401,9 @@ define <vscale x 2 x i32> @vwaddu_vv_disjoint_or_add(<vscale x 2 x i8> %x.i8, <v
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vsll.vi v8, v10, 8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf4 v8, v9
-; CHECK-NEXT:    vor.vv v8, v10, v8
+; CHECK-NEXT:    vsll.vi v10, v10, 8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
   %x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)
@@ -1450,9 +1448,8 @@ define <vscale x 2 x i32> @vwadd_vv_disjoint_or(<vscale x 2 x i16> %x.i16, <vsca
 define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
 ; CHECK-LABEL: vwaddu_wv_disjoint_or:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v9
-; CHECK-NEXT:    vor.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwaddu.wv v8, v8, v9
 ; CHECK-NEXT:    ret
   %y.i32 = zext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
@@ -1462,9 +1459,8 @@ define <vscale x 2 x i32> @vwaddu_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsc
 define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vscale x 2 x i16> %y.i16) {
 ; CHECK-LABEL: vwadd_wv_disjoint_or:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v9
-; CHECK-NEXT:    vor.vv v8, v8, v10
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.wv v8, v8, v9
 ; CHECK-NEXT:    ret
   %y.i32 = sext <vscale x 2 x i16> %y.i16 to <vscale x 2 x i32>
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32


        


More information about the llvm-commits mailing list