[llvm] [RISCV] Combine (or disjoint ext, ext) -> vwadd (PR #86929)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 28 03:11:56 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Luke Lau (lukel97)
<details>
<summary>Changes</summary>
DAGCombiner (or InstCombine) will convert an add to an or when its operands have no set bits in common (i.e. are disjoint), which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd.
This teaches combineBinOp_VLToVWBinOp_VL to recover the vwadd by treating an `or disjoint` node as an add.
---
Full diff: https://github.com/llvm/llvm-project/pull/86929.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+20-6)
- (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll (+4-7)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 564fda674317f4..e068e9e72a26b3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13527,7 +13527,7 @@ struct CombineResult;
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd -> vfwadd | vfwadd_w
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADD_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13697,6 +13698,7 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case ISD::OR:
return RISCVISD::VWADDU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -13742,6 +13744,7 @@ struct NodeExtensionHelper {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
+ case ISD::OR:
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
: RISCVISD::VWADDU_W_VL;
case ISD::SUB:
@@ -13862,6 +13865,10 @@ struct NodeExtensionHelper {
case ISD::MUL: {
return Root->getValueType(0).isScalableVector();
}
+ case ISD::OR: {
+ return Root->getValueType(0).isScalableVector() &&
+ Root->getFlags().hasDisjoint();
+ }
// Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
@@ -13942,7 +13949,8 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::OR: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13965,6 +13973,7 @@ struct NodeExtensionHelper {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::MUL:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -14031,6 +14040,7 @@ struct CombineResult {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::OR:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -14181,6 +14191,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
+ case ISD::OR:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
case RISCVISD::FADD_VL:
@@ -14224,9 +14235,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// Combine a binary operation to its equivalent VW or VW_W form.
/// The supported combines are:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
/// fadd_vl -> vfwadd | vfwadd_w
/// fsub_vl -> vfwsub | vfwsub_w
/// fmul_vl -> vfwmul
@@ -15886,8 +15897,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
- case ISD::OR:
+ case ISD::OR: {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
return performORCombine(N, DCI, Subtarget);
+ }
case ISD::XOR:
return performXORCombine(N, DAG, Subtarget);
case ISD::MUL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
index 36bc10f055b84b..569d1bbbfa5f2d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
}
; %x.i32 and %y.i32 are disjoint, so DAGCombiner will combine it into an or.
-; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
-; flag is set.
+; Check that we combine disjoint ors into vwaddu.
define <vscale x 2 x i32> @disjoint_or(<vscale x 2 x i8> %x.i8, <vscale x 2 x i8> %y.i8) {
; CHECK-LABEL: disjoint_or:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vsll.vi v8, v10, 8
-; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vzext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf4 v8, v9
-; CHECK-NEXT: vor.vv v8, v10, v8
+; CHECK-NEXT: vsll.vi v10, v10, 8
+; CHECK-NEXT: vzext.vf2 v11, v9
+; CHECK-NEXT: vwaddu.vv v8, v10, v11
; CHECK-NEXT: ret
%x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
%x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)
``````````
</details>
https://github.com/llvm/llvm-project/pull/86929
More information about the llvm-commits
mailing list