[llvm] 7b33b84 - [SelectionDAG] Support scalable splats in U(ADD|SUB)SAT combines
Fraser Cormack via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 27 03:01:46 PDT 2021
Author: Fraser Cormack
Date: 2021-07-27T10:52:34+01:00
New Revision: 7b33b849bd337fa68accf6a61e36cb4f0a947af1
URL: https://github.com/llvm/llvm-project/commit/7b33b849bd337fa68accf6a61e36cb4f0a947af1
DIFF: https://github.com/llvm/llvm-project/commit/7b33b849bd337fa68accf6a61e36cb4f0a947af1.diff
LOG: [SelectionDAG] Support scalable splats in U(ADD|SUB)SAT combines
This patch builds on D106575, which added support for scalable-vector
splats in `ISD::matchBinaryPredicate`. It teaches the DAGCombiner to
perform a variety of the pre-existing saturating add/sub combines on
scalable-vector types.
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D106652
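
For reference, the shape of pattern this unlocks on scalable types is the
usubsat idiom below, a sketch mirroring the vselect_sub_nxv2i64 test in the
updated RISC-V test file (the function name is illustrative):

    define <vscale x 2 x i64> @usubsat_sketch(<vscale x 2 x i64> %x,
                                              <vscale x 2 x i64> %y) {
      %cmp = icmp uge <vscale x 2 x i64> %x, %y
      %sub = sub <vscale x 2 x i64> %x, %y
      ; the select is now recognized as ISD::USUBSAT even though the
      ; zero operand is a splat rather than a BUILD_VECTOR
      %sel = select <vscale x 2 x i1> %cmp, <vscale x 2 x i64> %sub,
                    <vscale x 2 x i64> zeroinitializer
      ret <vscale x 2 x i64> %sel
    }

On RISC-V this now lowers to a single vssubu.vv instead of a
compare/subtract/merge sequence, as the test diff below shows.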
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index db4685944b7f4..182c29eea7c42 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10025,10 +10025,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
// If it's on the left side invert the predicate to simplify logic below.
SDValue Other;
ISD::CondCode SatCC = CC;
- if (ISD::isBuildVectorAllOnes(N1.getNode())) {
+ if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
Other = N2;
SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
- } else if (ISD::isBuildVectorAllOnes(N2.getNode())) {
+ } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
Other = N1;
}
@@ -10049,7 +10049,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
(OpLHS == CondLHS || OpRHS == CondLHS))
return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
- if (isa<BuildVectorSDNode>(OpRHS) && isa<BuildVectorSDNode>(CondRHS) &&
+ if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
+ (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+ OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
CondLHS == OpLHS) {
// If the RHS is a constant we have to reverse the const
// canonicalization.
@@ -10070,10 +10072,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
// the left side invert the predicate to simplify logic below.
SDValue Other;
ISD::CondCode SatCC = CC;
- if (ISD::isBuildVectorAllZeros(N1.getNode())) {
+ if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
Other = N2;
SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
- } else if (ISD::isBuildVectorAllZeros(N2.getNode())) {
+ } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
Other = N1;
}
@@ -10102,8 +10104,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
- if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
- if (isa<BuildVectorSDNode>(CondRHS)) {
+ if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+ OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
+ if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
+ CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
// If the RHS is a constant we have to reverse the const
// canonicalization.
// x > C-1 ? x+-C : 0 --> usubsat x, C
@@ -10125,15 +10129,15 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
// FIXME: Would it be better to use computeKnownBits to determine
// whether it's safe to decanonicalize the xor?
// x s< 0 ? x^C : 0 --> usubsat x, C
- if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
- if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
- ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
- OpRHSConst->getAPIntValue().isSignMask()) {
- // Note that we have to rebuild the RHS constant here to
- // ensure we don't rely on particular values of undef lanes.
- OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
- return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
- }
+ APInt SplatValue;
+ if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
+ ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
+ ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
+ SplatValue.isSignMask()) {
+ // Note that we have to rebuild the RHS constant here to
+ // ensure we don't rely on particular values of undef lanes.
+ OpRHS = DAG.getConstant(SplatValue, DL, VT);
+ return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
}
}
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll b/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
index 197e00bb947a7..ada166fa672a8 100644
--- a/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
@@ -101,10 +101,7 @@ define <vscale x 2 x i64> @vselect_sub_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
; CHECK-LABEL: vselect_sub_nxv2i64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT: vmsleu.vv v0, v10, v8
-; CHECK-NEXT: vsub.vv v26, v8, v10
-; CHECK-NEXT: vmv.v.i v28, 0
-; CHECK-NEXT: vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT: vssubu.vv v8, v8, v10
; CHECK-NEXT: ret
%cmp = icmp uge <vscale x 2 x i64> %a0, %a1
%v1 = sub <vscale x 2 x i64> %a0, %a1
@@ -131,9 +128,7 @@ define <vscale x 8 x i16> @vselect_sub_2_nxv8i16(<vscale x 8 x i16> %x, i16 zero
; CHECK-LABEL: vselect_sub_2_nxv8i16:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT: vmsltu.vx v0, v8, a0
-; CHECK-NEXT: vsub.vx v26, v8, a0
-; CHECK-NEXT: vmerge.vim v8, v26, 0, v0
+; CHECK-NEXT: vssubu.vx v8, v8, a0
; CHECK-NEXT: ret
entry:
%0 = insertelement <vscale x 8 x i16> undef, i16 %w, i32 0
@@ -163,11 +158,9 @@ define <2 x i64> @vselect_add_const_v2i64(<2 x i64> %a0) {
define <vscale x 2 x i64> @vselect_add_const_nxv2i64(<vscale x 2 x i64> %a0) {
; CHECK-LABEL: vselect_add_const_nxv2i64:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT: vadd.vi v26, v8, -6
-; CHECK-NEXT: vmsgtu.vi v0, v8, 5
-; CHECK-NEXT: vmv.v.i v28, 0
-; CHECK-NEXT: vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT: addi a0, zero, 6
+; CHECK-NEXT: vsetvli a1, zero, e64, m2, ta, mu
+; CHECK-NEXT: vssubu.vx v8, v8, a0
; CHECK-NEXT: ret
%cm1 = insertelement <vscale x 2 x i64> poison, i64 -6, i32 0
%splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
@@ -205,27 +198,17 @@ define <vscale x 2 x i16> @vselect_add_const_signbit_nxv2i16(<vscale x 2 x i16>
; RV32-LABEL: vselect_add_const_signbit_nxv2i16:
; RV32: # %bb.0:
; RV32-NEXT: lui a0, 8
-; RV32-NEXT: addi a0, a0, -2
+; RV32-NEXT: addi a0, a0, -1
; RV32-NEXT: vsetvli a1, zero, e16, mf2, ta, mu
-; RV32-NEXT: vmsgtu.vx v0, v8, a0
-; RV32-NEXT: lui a0, 1048568
-; RV32-NEXT: addi a0, a0, 1
-; RV32-NEXT: vadd.vx v25, v8, a0
-; RV32-NEXT: vmv.v.i v26, 0
-; RV32-NEXT: vmerge.vvm v8, v26, v25, v0
+; RV32-NEXT: vssubu.vx v8, v8, a0
; RV32-NEXT: ret
;
; RV64-LABEL: vselect_add_const_signbit_nxv2i16:
; RV64: # %bb.0:
; RV64-NEXT: lui a0, 8
-; RV64-NEXT: addiw a0, a0, -2
+; RV64-NEXT: addiw a0, a0, -1
; RV64-NEXT: vsetvli a1, zero, e16, mf2, ta, mu
-; RV64-NEXT: vmsgtu.vx v0, v8, a0
-; RV64-NEXT: lui a0, 1048568
-; RV64-NEXT: addiw a0, a0, 1
-; RV64-NEXT: vadd.vx v25, v8, a0
-; RV64-NEXT: vmv.v.i v26, 0
-; RV64-NEXT: vmerge.vvm v8, v26, v25, v0
+; RV64-NEXT: vssubu.vx v8, v8, a0
; RV64-NEXT: ret
%cm1 = insertelement <vscale x 2 x i16> poison, i16 32766, i32 0
%splatcm1 = shufflevector <vscale x 2 x i16> %cm1, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -255,12 +238,9 @@ define <2 x i16> @vselect_xor_const_signbit_v2i16(<2 x i16> %a0) {
define <vscale x 2 x i16> @vselect_xor_const_signbit_nxv2i16(<vscale x 2 x i16> %a0) {
; CHECK-LABEL: vselect_xor_const_signbit_nxv2i16:
; CHECK: # %bb.0:
-; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT: vmsle.vi v0, v8, -1
-; CHECK-NEXT: vmv.v.i v25, 0
-; CHECK-NEXT: lui a0, 1048568
-; CHECK-NEXT: vxor.vx v26, v8, a0
-; CHECK-NEXT: vmerge.vvm v8, v25, v26, v0
+; CHECK-NEXT: lui a0, 8
+; CHECK-NEXT: vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT: vssubu.vx v8, v8, a0
; CHECK-NEXT: ret
%cmp = icmp slt <vscale x 2 x i16> %a0, zeroinitializer
%ins = insertelement <vscale x 2 x i16> poison, i16 -32768, i32 0
@@ -291,10 +271,7 @@ define <vscale x 2 x i64> @vselect_add_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
; CHECK-LABEL: vselect_add_nxv2i64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT: vadd.vv v26, v8, v10
-; CHECK-NEXT: vmsleu.vv v0, v8, v26
-; CHECK-NEXT: vmv.v.i v28, -1
-; CHECK-NEXT: vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT: vsaddu.vv v8, v8, v10
; CHECK-NEXT: ret
%v1 = add <vscale x 2 x i64> %a0, %a1
%cmp = icmp ule <vscale x 2 x i64> %a0, %v1
@@ -323,10 +300,7 @@ define <vscale x 2 x i64> @vselect_add_const_2_nxv2i64(<vscale x 2 x i64> %a0) {
; CHECK-LABEL: vselect_add_const_2_nxv2i64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT: vadd.vi v26, v8, 6
-; CHECK-NEXT: vmsleu.vi v0, v8, -7
-; CHECK-NEXT: vmv.v.i v28, -1
-; CHECK-NEXT: vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT: vsaddu.vi v8, v8, 6
; CHECK-NEXT: ret
%cm1 = insertelement <vscale x 2 x i64> poison, i64 6, i32 0
%splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
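
As a footnote, the final DAGCombiner hunk covers the signbit-xor form of the
same fold (x s< 0 ? x^C : 0 --> usubsat x, C, with C the sign mask). In IR
that corresponds to a pattern like this sketch, mirroring the
vselect_xor_const_signbit_nxv2i16 test above (the function name is
illustrative):

    define <vscale x 2 x i16> @xor_signbit_sketch(<vscale x 2 x i16> %x) {
      %cmp = icmp slt <vscale x 2 x i16> %x, zeroinitializer
      ; splat of the i16 sign mask, 0x8000
      %ins = insertelement <vscale x 2 x i16> poison, i16 -32768, i32 0
      %splat = shufflevector <vscale x 2 x i16> %ins,
                             <vscale x 2 x i16> poison,
                             <vscale x 2 x i32> zeroinitializer
      %xor = xor <vscale x 2 x i16> %x, %splat
      %sel = select <vscale x 2 x i1> %cmp, <vscale x 2 x i16> %xor,
                    <vscale x 2 x i16> zeroinitializer
      ret <vscale x 2 x i16> %sel
    }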