[llvm] [RISCV] Combine (or disjoint ext, ext) -> vwadd (PR #86929)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 28 03:11:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

DAGCombiner (or InstCombine) will convert an add to an or if the bits are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.

This teaches combineBinOp_VLToVWBinOp_VL to recover the widening add by treating an `or` that carries the disjoint flag as an add.


---
Full diff: https://github.com/llvm/llvm-project/pull/86929.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+20-6) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+4-7) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 564fda674317f4..e068e9e72a26b3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13527,7 +13527,7 @@ struct CombineResult;
 enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd -> vfwadd | vfwadd_w
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADD_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13697,6 +13698,7 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case ISD::OR:
       return RISCVISD::VWADDU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -13742,6 +13744,7 @@ struct NodeExtensionHelper {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
+    case ISD::OR:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
                                           : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
@@ -13862,6 +13865,10 @@ struct NodeExtensionHelper {
     case ISD::MUL: {
       return Root->getValueType(0).isScalableVector();
     }
+    case ISD::OR: {
+      return Root->getValueType(0).isScalableVector() &&
+             Root->getFlags().hasDisjoint();
+    }
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13942,7 +13949,8 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::OR: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13965,6 +13973,7 @@ struct NodeExtensionHelper {
     switch (N->getOpcode()) {
     case ISD::ADD:
     case ISD::MUL:
+    case ISD::OR:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -14031,6 +14040,7 @@ struct CombineResult {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL:
+    case ISD::OR:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14181,6 +14191,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   switch (Root->getOpcode()) {
   case ISD::ADD:
   case ISD::SUB:
+  case ISD::OR:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14224,9 +14235,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 
 /// Combine a binary operation to its equivalent VW or VW_W form.
 /// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 /// fadd_vl ->  vfwadd | vfwadd_w
 /// fsub_vl ->  vfwsub | vfwsub_w
 /// fmul_vl ->  vfwmul
@@ -15886,8 +15897,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
-  case ISD::OR:
+  case ISD::OR: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
     return performORCombine(N, DCI, Subtarget);
+  }
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 36bc10f055b84b..569d1bbbfa5f2d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
 }
 
 ; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or.
-; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
-; flag is set.
+; Check that we combine disjoint ors into vwaddu.
 define <vscale x 2 x i32> @disjoint_or(<vscale x 2 x i8> %x.i8, <vscale x 2 x i8> %y.i8) {
 ; CHECK-LABEL: disjoint_or:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
 ; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vsll.vi v8, v10, 8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf4 v8, v9
-; CHECK-NEXT:    vor.vv v8, v10, v8
+; CHECK-NEXT:    vsll.vi v10, v10, 8
+; CHECK-NEXT:    vzext.vf2 v11, v9
+; CHECK-NEXT:    vwaddu.vv v8, v10, v11
 ; CHECK-NEXT:    ret
   %x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
   %x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)

``````````

</details>


https://github.com/llvm/llvm-project/pull/86929


More information about the llvm-commits mailing list