[llvm] 7b33b84 - [SelectionDAG] Support scalable splats in U(ADD|SUB)SAT combines

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 03:01:46 PDT 2021


Author: Fraser Cormack
Date: 2021-07-27T10:52:34+01:00
New Revision: 7b33b849bd337fa68accf6a61e36cb4f0a947af1

URL: https://github.com/llvm/llvm-project/commit/7b33b849bd337fa68accf6a61e36cb4f0a947af1
DIFF: https://github.com/llvm/llvm-project/commit/7b33b849bd337fa68accf6a61e36cb4f0a947af1.diff

LOG: [SelectionDAG] Support scalable splats in U(ADD|SUB)SAT combines

This patch builds on top of D106575, which added support for scalable-vector
splats to `ISD::matchBinaryPredicate`. It teaches the DAGCombiner to perform
several of the pre-existing saturating add/sub combines (the VSELECT patterns
that fold to UADDSAT/USUBSAT) on scalable-vector types.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D106652

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/RISCV/rvv/combine-sats.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index db4685944b7f4..182c29eea7c42 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10025,10 +10025,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       // If it's on the left side invert the predicate to simplify logic below.
       SDValue Other;
       ISD::CondCode SatCC = CC;
-      if (ISD::isBuildVectorAllOnes(N1.getNode())) {
+      if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
         Other = N2;
         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
-      } else if (ISD::isBuildVectorAllOnes(N2.getNode())) {
+      } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
         Other = N1;
       }
 
@@ -10049,7 +10049,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
             (OpLHS == CondLHS || OpRHS == CondLHS))
           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
 
-        if (isa<BuildVectorSDNode>(OpRHS) && isa<BuildVectorSDNode>(CondRHS) &&
+        if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
+            (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+             OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
             CondLHS == OpLHS) {
           // If the RHS is a constant we have to reverse the const
           // canonicalization.
@@ -10070,10 +10072,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       // the left side invert the predicate to simplify logic below.
       SDValue Other;
       ISD::CondCode SatCC = CC;
-      if (ISD::isBuildVectorAllZeros(N1.getNode())) {
+      if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
         Other = N2;
         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
-      } else if (ISD::isBuildVectorAllZeros(N2.getNode())) {
+      } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
         Other = N1;
       }
 
@@ -10102,8 +10104,10 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
               Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
             return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
 
-          if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS)) {
-            if (isa<BuildVectorSDNode>(CondRHS)) {
+          if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
+              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
+            if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
+                CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
               // If the RHS is a constant we have to reverse the const
               // canonicalization.
               // x > C-1 ? x+-C : 0 --> usubsat x, C
@@ -10125,15 +10129,15 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
               // FIXME: Would it be better to use computeKnownBits to determine
               //        whether it's safe to decanonicalize the xor?
               // x s< 0 ? x^C : 0 --> usubsat x, C
-              if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
-                if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
-                    ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-                    OpRHSConst->getAPIntValue().isSignMask()) {
-                  // Note that we have to rebuild the RHS constant here to
-                  // ensure we don't rely on particular values of undef lanes.
-                  OpRHS = DAG.getConstant(OpRHSConst->getAPIntValue(), DL, VT);
-                  return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
-                }
+              APInt SplatValue;
+              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
+                  ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
+                  ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
+                  SplatValue.isSignMask()) {
+                // Note that we have to rebuild the RHS constant here to
+                // ensure we don't rely on particular values of undef lanes.
+                OpRHS = DAG.getConstant(SplatValue, DL, VT);
+                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
               }
             }
           }

diff  --git a/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll b/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
index 197e00bb947a7..ada166fa672a8 100644
--- a/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/combine-sats.ll
@@ -101,10 +101,7 @@ define <vscale x 2 x i64> @vselect_sub_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
 ; CHECK-LABEL: vselect_sub_nxv2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vmsleu.vv v0, v10, v8
-; CHECK-NEXT:    vsub.vv v26, v8, v10
-; CHECK-NEXT:    vmv.v.i v28, 0
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    vssubu.vv v8, v8, v10
 ; CHECK-NEXT:    ret
   %cmp = icmp uge <vscale x 2 x i64> %a0, %a1
   %v1 = sub <vscale x 2 x i64> %a0, %a1
@@ -131,9 +128,7 @@ define <vscale x 8 x i16> @vselect_sub_2_nxv8i16(<vscale x 8 x i16> %x, i16 zero
 ; CHECK-LABEL: vselect_sub_2_nxv8i16:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmsltu.vx v0, v8, a0
-; CHECK-NEXT:    vsub.vx v26, v8, a0
-; CHECK-NEXT:    vmerge.vim v8, v26, 0, v0
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
 entry:
   %0 = insertelement <vscale x 8 x i16> undef, i16 %w, i32 0
@@ -163,11 +158,9 @@ define <2 x i64> @vselect_add_const_v2i64(<2 x i64> %a0) {
 define <vscale x 2 x i64> @vselect_add_const_nxv2i64(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: vselect_add_const_nxv2i64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vadd.vi v26, v8, -6
-; CHECK-NEXT:    vmsgtu.vi v0, v8, 5
-; CHECK-NEXT:    vmv.v.i v28, 0
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    addi a0, zero, 6
+; CHECK-NEXT:    vsetvli a1, zero, e64, m2, ta, mu
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i64> poison, i64 -6, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
@@ -205,27 +198,17 @@ define <vscale x 2 x i16> @vselect_add_const_signbit_nxv2i16(<vscale x 2 x i16>
 ; RV32-LABEL: vselect_add_const_signbit_nxv2i16:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    lui a0, 8
-; RV32-NEXT:    addi a0, a0, -2
+; RV32-NEXT:    addi a0, a0, -1
 ; RV32-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
-; RV32-NEXT:    vmsgtu.vx v0, v8, a0
-; RV32-NEXT:    lui a0, 1048568
-; RV32-NEXT:    addi a0, a0, 1
-; RV32-NEXT:    vadd.vx v25, v8, a0
-; RV32-NEXT:    vmv.v.i v26, 0
-; RV32-NEXT:    vmerge.vvm v8, v26, v25, v0
+; RV32-NEXT:    vssubu.vx v8, v8, a0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: vselect_add_const_signbit_nxv2i16:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    lui a0, 8
-; RV64-NEXT:    addiw a0, a0, -2
+; RV64-NEXT:    addiw a0, a0, -1
 ; RV64-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
-; RV64-NEXT:    vmsgtu.vx v0, v8, a0
-; RV64-NEXT:    lui a0, 1048568
-; RV64-NEXT:    addiw a0, a0, 1
-; RV64-NEXT:    vadd.vx v25, v8, a0
-; RV64-NEXT:    vmv.v.i v26, 0
-; RV64-NEXT:    vmerge.vvm v8, v26, v25, v0
+; RV64-NEXT:    vssubu.vx v8, v8, a0
 ; RV64-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i16> poison, i16 32766, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i16> %cm1, <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer
@@ -255,12 +238,9 @@ define <2 x i16> @vselect_xor_const_signbit_v2i16(<2 x i16> %a0) {
 define <vscale x 2 x i16> @vselect_xor_const_signbit_nxv2i16(<vscale x 2 x i16> %a0) {
 ; CHECK-LABEL: vselect_xor_const_signbit_nxv2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vmsle.vi v0, v8, -1
-; CHECK-NEXT:    vmv.v.i v25, 0
-; CHECK-NEXT:    lui a0, 1048568
-; CHECK-NEXT:    vxor.vx v26, v8, a0
-; CHECK-NEXT:    vmerge.vvm v8, v25, v26, v0
+; CHECK-NEXT:    lui a0, 8
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT:    vssubu.vx v8, v8, a0
 ; CHECK-NEXT:    ret
   %cmp = icmp slt <vscale x 2 x i16> %a0, zeroinitializer
   %ins = insertelement <vscale x 2 x i16> poison, i16 -32768, i32 0
@@ -291,10 +271,7 @@ define <vscale x 2 x i64> @vselect_add_nxv2i64(<vscale x 2 x i64> %a0, <vscale x
 ; CHECK-LABEL: vselect_add_nxv2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vadd.vv v26, v8, v10
-; CHECK-NEXT:    vmsleu.vv v0, v8, v26
-; CHECK-NEXT:    vmv.v.i v28, -1
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    vsaddu.vv v8, v8, v10
 ; CHECK-NEXT:    ret
   %v1 = add <vscale x 2 x i64> %a0, %a1
   %cmp = icmp ule <vscale x 2 x i64> %a0, %v1
@@ -323,10 +300,7 @@ define <vscale x 2 x i64> @vselect_add_const_2_nxv2i64(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: vselect_add_const_2_nxv2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vadd.vi v26, v8, 6
-; CHECK-NEXT:    vmsleu.vi v0, v8, -7
-; CHECK-NEXT:    vmv.v.i v28, -1
-; CHECK-NEXT:    vmerge.vvm v8, v28, v26, v0
+; CHECK-NEXT:    vsaddu.vi v8, v8, 6
 ; CHECK-NEXT:    ret
   %cm1 = insertelement <vscale x 2 x i64> poison, i64 6, i32 0
   %splatcm1 = shufflevector <vscale x 2 x i64> %cm1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer


        


More information about the llvm-commits mailing list