[llvm] f924a3d - [SelectionDAG] Support scalable-vector splats in yet more cases

Fraser Cormack via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 26 02:24:21 PDT 2021


Author: Fraser Cormack
Date: 2021-07-26T10:15:08+01:00
New Revision: f924a3d47492b7b586ccfd1333ca086a7e2d88b2

URL: https://github.com/llvm/llvm-project/commit/f924a3d47492b7b586ccfd1333ca086a7e2d88b2
DIFF: https://github.com/llvm/llvm-project/commit/f924a3d47492b7b586ccfd1333ca086a7e2d88b2.diff

LOG: [SelectionDAG] Support scalable-vector splats in yet more cases

This patch extends support for (scalable-vector) splats in the
DAGCombiner via the `ISD::matchBinaryPredicate` function, which enables
a variety of simple combines of constants.
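
For example, a combine that folds two constant (splat) operands
element-wise drives `ISD::matchBinaryPredicate` with a predicate that
both validates and records the folded constants. The sketch below is
illustrative only, not verbatim in-tree code: it assumes the usual
DAGCombiner context (`DAG`, `DL`, `N0`, `N1`), and `ShiftSVT` stands in
for the scalar shift type:

    SmallVector<SDValue, 16> ShiftValues;
    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt C1 = LHS->getAPIntValue();
      APInt C2 = RHS->getAPIntValue();
      // (Overflow/legality checks elided for brevity.)
      ShiftValues.push_back(DAG.getConstant(C1 + C2, DL, ShiftSVT));
      return true;
    };
    // Previously this returned false whenever either operand was a
    // SPLAT_VECTOR; now a pair of constant SPLAT_VECTORs matches, with
    // the predicate invoked once on the two splatted constants.
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
      // Materialize ShiftValues as a scalar, a BUILD_VECTOR, or a
      // SPLAT_VECTOR depending on the operand kind (see below).
    }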

Users of this function may now have to distinguish between
`BUILD_VECTOR` and `SPLAT_VECTOR` vector operands. The in-tree callers
follow the approach added for `ISD::matchUnaryPredicate` in D94501.
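
Concretely, code that used to branch only on `VT.isVector()` when
rebuilding the matched constants now keys off the operand's opcode. A
minimal sketch of the pattern, mirroring the `visitSRA` hunk below
(with `ShiftValues` populated by the predicate as sketched above):

    SDValue ShiftValue;
    if (N1.getOpcode() == ISD::BUILD_VECTOR) {
      // Fixed-length vectors: one matched constant per lane.
      ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
    } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
      // Scalable vectors: the predicate ran once on the splatted
      // constants, so exactly one element comes back; re-splat it.
      assert(ShiftValues.size() == 1 && "expected a single splat element");
      ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
    } else {
      // Scalar operands.
      ShiftValue = ShiftValues[0];
    }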

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D106575

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/combine-splats.ll
    llvm/test/CodeGen/RISCV/rvv/urem-seteq-vec.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e1b581f65a6eb..db4685944b7f4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8615,9 +8615,14 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
     };
     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
       SDValue ShiftValue;
-      if (VT.isVector())
+      if (N1.getOpcode() == ISD::BUILD_VECTOR)
         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
-      else
+      else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+        assert(ShiftValues.size() == 1 &&
+               "Expected matchBinaryPredicate to return one element for "
+               "SPLAT_VECTORs");
+        ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
+      } else
         ShiftValue = ShiftValues[0];
       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
     }

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 4aa54a74fc07a..2a98464425c40 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -342,8 +342,9 @@ bool ISD::matchBinaryPredicate(
       return Match(LHSCst, RHSCst);
 
   // TODO: Add support for vector UNDEF cases?
-  if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
-      ISD::BUILD_VECTOR != RHS.getOpcode())
+  if (LHS.getOpcode() != RHS.getOpcode() ||
+      (LHS.getOpcode() != ISD::BUILD_VECTOR &&
+       LHS.getOpcode() != ISD::SPLAT_VECTOR))
     return false;
 
   EVT SVT = LHS.getValueType().getScalarType();

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fbdfa3ffd0195..1c1dae8f953fc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5605,7 +5605,7 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
     return SDValue();
 
   SDValue PVal, KVal, QVal;
-  if (VT.isVector()) {
+  if (D.getOpcode() == ISD::BUILD_VECTOR) {
     if (HadTautologicalLanes) {
       // Try to turn PAmts into a splat, since we don't care about the values
       // that are currently '0'. If we can't, just keep '0'`s.
@@ -5619,6 +5619,13 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
     PVal = DAG.getBuildVector(VT, DL, PAmts);
     KVal = DAG.getBuildVector(ShVT, DL, KAmts);
     QVal = DAG.getBuildVector(VT, DL, QAmts);
+  } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
+    assert(PAmts.size() == 1 && KAmts.size() == 1 && QAmts.size() == 1 &&
+           "Expected matchBinaryPredicate to return one element for "
+           "SPLAT_VECTORs");
+    PVal = DAG.getSplatVector(VT, DL, PAmts[0]);
+    KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
+    QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
   } else {
     PVal = PAmts[0];
     KVal = KAmts[0];

diff --git a/llvm/test/CodeGen/RISCV/rvv/combine-splats.ll b/llvm/test/CodeGen/RISCV/rvv/combine-splats.ll
index 0fe2dd4a2ba80..468f1a0c2b320 100644
--- a/llvm/test/CodeGen/RISCV/rvv/combine-splats.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/combine-splats.ll
@@ -7,10 +7,8 @@
 define <vscale x 4 x i32> @and_or_nxv4i32(<vscale x 4 x i32> %A) {
 ; CHECK-LABEL: and_or_nxv4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a0, zero, 255
-; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vor.vx v26, v8, a0
-; CHECK-NEXT:    vand.vi v8, v26, 8
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vmv.v.i v8, 8
 ; CHECK-NEXT:    ret
   %ins1 = insertelement <vscale x 4 x i32> poison, i32 255, i32 0
   %splat1 = shufflevector <vscale x 4 x i32> %ins1, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
@@ -27,8 +25,8 @@ define <vscale x 2 x i64> @or_and_nxv2i64(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: or_and_nxv2i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vand.vi v26, v8, 7
-; CHECK-NEXT:    vor.vi v8, v26, 3
+; CHECK-NEXT:    vor.vi v26, v8, 3
+; CHECK-NEXT:    vand.vi v8, v26, 7
 ; CHECK-NEXT:    ret
   %ins1 = insertelement <vscale x 2 x i64> poison, i64 7, i32 0
   %splat1 = shufflevector <vscale x 2 x i64> %ins1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
@@ -45,8 +43,7 @@ define <vscale x 2 x i64> @or_and_nxv2i64_fold(<vscale x 2 x i64> %a0) {
 ; CHECK-LABEL: or_and_nxv2i64_fold:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
-; CHECK-NEXT:    vand.vi v26, v8, 1
-; CHECK-NEXT:    vor.vi v8, v26, 3
+; CHECK-NEXT:    vmv.v.i v8, 3
 ; CHECK-NEXT:    ret
   %ins1 = insertelement <vscale x 2 x i64> poison, i64 1, i32 0
   %splat1 = shufflevector <vscale x 2 x i64> %ins1, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
@@ -85,8 +82,7 @@ define <vscale x 2 x i32> @combine_vec_ashr_ashr(<vscale x 2 x i32> %x) {
 ; CHECK-LABEL: combine_vec_ashr_ashr:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, mu
-; CHECK-NEXT:    vsra.vi v25, v8, 2
-; CHECK-NEXT:    vsra.vi v8, v25, 4
+; CHECK-NEXT:    vsra.vi v8, v8, 6
 ; CHECK-NEXT:    ret
   %ins1 = insertelement <vscale x 2 x i32> poison, i32 2, i32 0
   %splat1 = shufflevector <vscale x 2 x i32> %ins1, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
@@ -103,8 +99,7 @@ define <vscale x 8 x i16> @combine_vec_lshr_lshr(<vscale x 8 x i16> %x) {
 ; CHECK-LABEL: combine_vec_lshr_lshr:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vsrl.vi v26, v8, 4
-; CHECK-NEXT:    vsrl.vi v8, v26, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 8
 ; CHECK-NEXT:    ret
   %ins1 = insertelement <vscale x 8 x i16> poison, i16 2, i32 0
   %splat1 = shufflevector <vscale x 8 x i16> %ins1, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer

diff --git a/llvm/test/CodeGen/RISCV/rvv/urem-seteq-vec.ll b/llvm/test/CodeGen/RISCV/rvv/urem-seteq-vec.ll
index da7334b0bae90..6908e6280fbd2 100644
--- a/llvm/test/CodeGen/RISCV/rvv/urem-seteq-vec.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/urem-seteq-vec.ll
@@ -8,13 +8,15 @@ define <vscale x 1 x i16> @test_urem_vec_even_divisor_eq0(<vscale x 1 x i16> %x)
 ; RV32-NEXT:    lui a0, 1048571
 ; RV32-NEXT:    addi a0, a0, -1365
 ; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV32-NEXT:    vmulhu.vx v25, v8, a0
-; RV32-NEXT:    vsrl.vi v25, v25, 2
-; RV32-NEXT:    addi a0, zero, 6
-; RV32-NEXT:    vnmsub.vx v25, a0, v8
-; RV32-NEXT:    vmv.v.i v26, 0
-; RV32-NEXT:    vmsne.vi v0, v25, 0
-; RV32-NEXT:    vmerge.vim v8, v26, -1, v0
+; RV32-NEXT:    vmul.vx v25, v8, a0
+; RV32-NEXT:    vsll.vi v26, v25, 15
+; RV32-NEXT:    vsrl.vi v25, v25, 1
+; RV32-NEXT:    vor.vv v25, v25, v26
+; RV32-NEXT:    lui a0, 3
+; RV32-NEXT:    addi a0, a0, -1366
+; RV32-NEXT:    vmsgtu.vx v0, v25, a0
+; RV32-NEXT:    vmv.v.i v25, 0
+; RV32-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: test_urem_vec_even_divisor_eq0:
@@ -22,13 +24,15 @@ define <vscale x 1 x i16> @test_urem_vec_even_divisor_eq0(<vscale x 1 x i16> %x)
 ; RV64-NEXT:    lui a0, 1048571
 ; RV64-NEXT:    addiw a0, a0, -1365
 ; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV64-NEXT:    vmulhu.vx v25, v8, a0
-; RV64-NEXT:    vsrl.vi v25, v25, 2
-; RV64-NEXT:    addi a0, zero, 6
-; RV64-NEXT:    vnmsub.vx v25, a0, v8
-; RV64-NEXT:    vmv.v.i v26, 0
-; RV64-NEXT:    vmsne.vi v0, v25, 0
-; RV64-NEXT:    vmerge.vim v8, v26, -1, v0
+; RV64-NEXT:    vmul.vx v25, v8, a0
+; RV64-NEXT:    vsll.vi v26, v25, 15
+; RV64-NEXT:    vsrl.vi v25, v25, 1
+; RV64-NEXT:    vor.vv v25, v25, v26
+; RV64-NEXT:    lui a0, 3
+; RV64-NEXT:    addiw a0, a0, -1366
+; RV64-NEXT:    vmsgtu.vx v0, v25, a0
+; RV64-NEXT:    vmv.v.i v25, 0
+; RV64-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV64-NEXT:    ret
   %ins1 = insertelement <vscale x 1 x i16> poison, i16 6, i32 0
   %splat1 = shufflevector <vscale x 1 x i16> %ins1, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -46,13 +50,12 @@ define <vscale x 1 x i16> @test_urem_vec_odd_divisor_eq0(<vscale x 1 x i16> %x)
 ; RV32-NEXT:    lui a0, 1048573
 ; RV32-NEXT:    addi a0, a0, -819
 ; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV32-NEXT:    vmulhu.vx v25, v8, a0
-; RV32-NEXT:    vsrl.vi v25, v25, 2
-; RV32-NEXT:    addi a0, zero, 5
-; RV32-NEXT:    vnmsub.vx v25, a0, v8
-; RV32-NEXT:    vmv.v.i v26, 0
-; RV32-NEXT:    vmsne.vi v0, v25, 0
-; RV32-NEXT:    vmerge.vim v8, v26, -1, v0
+; RV32-NEXT:    vmul.vx v25, v8, a0
+; RV32-NEXT:    lui a0, 3
+; RV32-NEXT:    addi a0, a0, 819
+; RV32-NEXT:    vmsgtu.vx v0, v25, a0
+; RV32-NEXT:    vmv.v.i v25, 0
+; RV32-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: test_urem_vec_odd_divisor_eq0:
@@ -60,13 +63,12 @@ define <vscale x 1 x i16> @test_urem_vec_odd_divisor_eq0(<vscale x 1 x i16> %x)
 ; RV64-NEXT:    lui a0, 1048573
 ; RV64-NEXT:    addiw a0, a0, -819
 ; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV64-NEXT:    vmulhu.vx v25, v8, a0
-; RV64-NEXT:    vsrl.vi v25, v25, 2
-; RV64-NEXT:    addi a0, zero, 5
-; RV64-NEXT:    vnmsub.vx v25, a0, v8
-; RV64-NEXT:    vmv.v.i v26, 0
-; RV64-NEXT:    vmsne.vi v0, v25, 0
-; RV64-NEXT:    vmerge.vim v8, v26, -1, v0
+; RV64-NEXT:    vmul.vx v25, v8, a0
+; RV64-NEXT:    lui a0, 3
+; RV64-NEXT:    addiw a0, a0, 819
+; RV64-NEXT:    vmsgtu.vx v0, v25, a0
+; RV64-NEXT:    vmv.v.i v25, 0
+; RV64-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV64-NEXT:    ret
   %ins1 = insertelement <vscale x 1 x i16> poison, i16 5, i32 0
   %splat1 = shufflevector <vscale x 1 x i16> %ins1, <vscale x 1 x i16> poison, <vscale x 1 x i32> zeroinitializer
@@ -81,28 +83,36 @@ define <vscale x 1 x i16> @test_urem_vec_odd_divisor_eq0(<vscale x 1 x i16> %x)
 define <vscale x 1 x i16> @test_urem_vec_even_divisor_eq1(<vscale x 1 x i16> %x) nounwind {
 ; RV32-LABEL: test_urem_vec_even_divisor_eq1:
 ; RV32:       # %bb.0:
+; RV32-NEXT:    addi a0, zero, 1
+; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
+; RV32-NEXT:    vsub.vx v25, v8, a0
 ; RV32-NEXT:    lui a0, 1048571
 ; RV32-NEXT:    addi a0, a0, -1365
-; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV32-NEXT:    vmulhu.vx v25, v8, a0
-; RV32-NEXT:    vsrl.vi v25, v25, 2
-; RV32-NEXT:    addi a0, zero, 6
-; RV32-NEXT:    vnmsub.vx v25, a0, v8
-; RV32-NEXT:    vmsne.vi v0, v25, 1
+; RV32-NEXT:    vmul.vx v25, v25, a0
+; RV32-NEXT:    vsll.vi v26, v25, 15
+; RV32-NEXT:    vsrl.vi v25, v25, 1
+; RV32-NEXT:    vor.vv v25, v25, v26
+; RV32-NEXT:    lui a0, 3
+; RV32-NEXT:    addi a0, a0, -1366
+; RV32-NEXT:    vmsgtu.vx v0, v25, a0
 ; RV32-NEXT:    vmv.v.i v25, 0
 ; RV32-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: test_urem_vec_even_divisor_eq1:
 ; RV64:       # %bb.0:
+; RV64-NEXT:    addi a0, zero, 1
+; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
+; RV64-NEXT:    vsub.vx v25, v8, a0
 ; RV64-NEXT:    lui a0, 1048571
 ; RV64-NEXT:    addiw a0, a0, -1365
-; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV64-NEXT:    vmulhu.vx v25, v8, a0
-; RV64-NEXT:    vsrl.vi v25, v25, 2
-; RV64-NEXT:    addi a0, zero, 6
-; RV64-NEXT:    vnmsub.vx v25, a0, v8
-; RV64-NEXT:    vmsne.vi v0, v25, 1
+; RV64-NEXT:    vmul.vx v25, v25, a0
+; RV64-NEXT:    vsll.vi v26, v25, 15
+; RV64-NEXT:    vsrl.vi v25, v25, 1
+; RV64-NEXT:    vor.vv v25, v25, v26
+; RV64-NEXT:    lui a0, 3
+; RV64-NEXT:    addiw a0, a0, -1366
+; RV64-NEXT:    vmsgtu.vx v0, v25, a0
 ; RV64-NEXT:    vmv.v.i v25, 0
 ; RV64-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV64-NEXT:    ret
@@ -119,28 +129,30 @@ define <vscale x 1 x i16> @test_urem_vec_even_divisor_eq1(<vscale x 1 x i16> %x)
 define <vscale x 1 x i16> @test_urem_vec_odd_divisor_eq1(<vscale x 1 x i16> %x) nounwind {
 ; RV32-LABEL: test_urem_vec_odd_divisor_eq1:
 ; RV32:       # %bb.0:
+; RV32-NEXT:    addi a0, zero, 1
+; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
+; RV32-NEXT:    vsub.vx v25, v8, a0
 ; RV32-NEXT:    lui a0, 1048573
 ; RV32-NEXT:    addi a0, a0, -819
-; RV32-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV32-NEXT:    vmulhu.vx v25, v8, a0
-; RV32-NEXT:    vsrl.vi v25, v25, 2
-; RV32-NEXT:    addi a0, zero, 5
-; RV32-NEXT:    vnmsub.vx v25, a0, v8
-; RV32-NEXT:    vmsne.vi v0, v25, 1
+; RV32-NEXT:    vmul.vx v25, v25, a0
+; RV32-NEXT:    lui a0, 3
+; RV32-NEXT:    addi a0, a0, 818
+; RV32-NEXT:    vmsgtu.vx v0, v25, a0
 ; RV32-NEXT:    vmv.v.i v25, 0
 ; RV32-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: test_urem_vec_odd_divisor_eq1:
 ; RV64:       # %bb.0:
+; RV64-NEXT:    addi a0, zero, 1
+; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
+; RV64-NEXT:    vsub.vx v25, v8, a0
 ; RV64-NEXT:    lui a0, 1048573
 ; RV64-NEXT:    addiw a0, a0, -819
-; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, mu
-; RV64-NEXT:    vmulhu.vx v25, v8, a0
-; RV64-NEXT:    vsrl.vi v25, v25, 2
-; RV64-NEXT:    addi a0, zero, 5
-; RV64-NEXT:    vnmsub.vx v25, a0, v8
-; RV64-NEXT:    vmsne.vi v0, v25, 1
+; RV64-NEXT:    vmul.vx v25, v25, a0
+; RV64-NEXT:    lui a0, 3
+; RV64-NEXT:    addiw a0, a0, 818
+; RV64-NEXT:    vmsgtu.vx v0, v25, a0
 ; RV64-NEXT:    vmv.v.i v25, 0
 ; RV64-NEXT:    vmerge.vim v8, v25, -1, v0
 ; RV64-NEXT:    ret
