[llvm] 117d755 - [DAG] SimplifyDemandedBits - use ComputeKnownBits instead of getValidShiftAmountConstant to check for constant shift amounts. (#92412)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 16 09:04:34 PDT 2024


Author: Simon Pilgrim
Date: 2024-05-16T17:04:30+01:00
New Revision: 117d755b1b84c7d379ea5c3d93f8c2ab9bfcde82

URL: https://github.com/llvm/llvm-project/commit/117d755b1b84c7d379ea5c3d93f8c2ab9bfcde82
DIFF: https://github.com/llvm/llvm-project/commit/117d755b1b84c7d379ea5c3d93f8c2ab9bfcde82.diff

LOG: [DAG] SimplifyDemandedBits - use ComputeKnownBits instead of getValidShiftAmountConstant to check for constant shift amounts. (#92412)

This allows us to handle cases where the constant has already been type-legalized behind a bitcast.

Despite calling ComputeKnownBits, I'm not seeing any notable change in compile time.
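
For reference, here is a minimal standalone sketch (not part of the commit; it only assumes the LLVM support headers and links against LLVMSupport) of the property the new guard relies on: when computeKnownBits proves every bit of the shift amount, KnownBits reports a constant that can be range-checked against the bit width, even when no ConstantSDNode is directly visible. The sketch models that result with KnownBits::makeConstant rather than building a SelectionDAG:

// sketch.cpp - illustrative only; mirrors the guard added in
// SimplifyDemandedBits, not the commit's exact code.
#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  unsigned BitWidth = 16; // pretend the shifted value is an i16

  // Model a shift amount that computeKnownBits can prove is exactly 8,
  // e.g. a constant that is only visible through a bitcast after type
  // legalization and so is missed by getValidShiftAmountConstant.
  KnownBits KnownSA = KnownBits::makeConstant(APInt(BitWidth, 8));

  // The new guard: the amount must be a known constant and strictly less
  // than the bit width before it is treated as a usable shift amount.
  if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
    unsigned ShAmt = KnownSA.getConstant().getZExtValue();
    outs() << "constant shift amount: " << ShAmt << "\n";
  }
  return 0;
}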

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
    llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
    llvm/test/CodeGen/PowerPC/pr44183.ll
    llvm/test/CodeGen/RISCV/rvv/bitreverse-sdnode.ll
    llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 38583de03d9c8..3ec6b9b795079 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1732,9 +1732,9 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op1 = Op.getOperand(1);
     EVT ShiftVT = Op1.getValueType();
 
-    if (const APInt *SA =
-            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = SA->getZExtValue();
+    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
+    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
+      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1744,9 +1744,10 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
         if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
-            unsigned C1 = SA2->getZExtValue();
+          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
+                                                       DemandedElts, Depth + 1);
+          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
+            unsigned C1 = InnerSA.getConstant().getZExtValue();
             unsigned Opc = ISD::SHL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -1912,9 +1913,9 @@ bool TargetLowering::SimplifyDemandedBits(
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);
 
-    if (const APInt *SA =
-            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = SA->getZExtValue();
+    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
+    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
+      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
@@ -1924,9 +1925,10 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
         if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
-          if (const APInt *SA2 =
-                  TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
-            unsigned C1 = SA2->getZExtValue();
+          KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
+                                                       DemandedElts, Depth + 1);
+          if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
+            unsigned C1 = InnerSA.getConstant().getZExtValue();
             unsigned Opc = ISD::SRL;
             int Diff = ShAmt - C1;
             if (Diff < 0) {
@@ -2018,24 +2020,25 @@ bool TargetLowering::SimplifyDemandedBits(
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);
 
-    if (const APInt *SA =
-            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
-      unsigned ShAmt = SA->getZExtValue();
+    KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
+    if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
+      unsigned ShAmt = KnownSA.getConstant().getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
       // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
       // supports sext_inreg.
       if (Op0.getOpcode() == ISD::SHL) {
-        if (const APInt *InnerSA =
-                TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
+        KnownBits InnerSA = TLO.DAG.computeKnownBits(Op0.getOperand(1),
+                                                     DemandedElts, Depth + 1);
+        if (InnerSA.isConstant() && InnerSA.getConstant().ult(BitWidth)) {
           unsigned LowBits = BitWidth - ShAmt;
           EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
           if (VT.isVector())
             ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
                                      VT.getVectorElementCount());
 
-          if (*InnerSA == ShAmt) {
+          if (InnerSA.getConstant() == ShAmt) {
             if (!TLO.LegalOperations() ||
                 getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal)
               return TLO.CombineTo(

diff --git a/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll b/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
index 8a0e7661883f2..49ed989766746 100644
--- a/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
@@ -149,12 +149,12 @@ define void @v4i8(ptr %px, ptr %py, ptr %pz) nounwind {
 ; CHECK-SD:       // %bb.0:
 ; CHECK-SD-NEXT:    ldr s0, [x0]
 ; CHECK-SD-NEXT:    ldr s1, [x1]
-; CHECK-SD-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-SD-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-SD-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-NEXT:    ushll v1.8h, v1.8b, #0
 ; CHECK-SD-NEXT:    shl v1.4h, v1.4h, #8
 ; CHECK-SD-NEXT:    shl v0.4h, v0.4h, #8
 ; CHECK-SD-NEXT:    sqadd v0.4h, v0.4h, v1.4h
-; CHECK-SD-NEXT:    sshr v0.4h, v0.4h, #8
+; CHECK-SD-NEXT:    ushr v0.4h, v0.4h, #8
 ; CHECK-SD-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
 ; CHECK-SD-NEXT:    str s0, [x2]
 ; CHECK-SD-NEXT:    ret
@@ -364,10 +364,6 @@ define void @v1i16(ptr %px, ptr %py, ptr %pz) nounwind {
 define <16 x i4> @v16i4(<16 x i4> %x, <16 x i4> %y) nounwind {
 ; CHECK-LABEL: v16i4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    shl v0.16b, v0.16b, #4
-; CHECK-NEXT:    shl v1.16b, v1.16b, #4
-; CHECK-NEXT:    sshr v0.16b, v0.16b, #4
-; CHECK-NEXT:    sshr v1.16b, v1.16b, #4
 ; CHECK-NEXT:    shl v1.16b, v1.16b, #4
 ; CHECK-NEXT:    shl v0.16b, v0.16b, #4
 ; CHECK-NEXT:    sqadd v0.16b, v0.16b, v1.16b

diff --git a/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll b/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
index a8c1276eadc4f..f6b58023da1f3 100644
--- a/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
@@ -150,12 +150,12 @@ define void @v4i8(ptr %px, ptr %py, ptr %pz) nounwind {
 ; CHECK-SD:       // %bb.0:
 ; CHECK-SD-NEXT:    ldr s0, [x0]
 ; CHECK-SD-NEXT:    ldr s1, [x1]
-; CHECK-SD-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-SD-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-SD-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-NEXT:    ushll v1.8h, v1.8b, #0
 ; CHECK-SD-NEXT:    shl v1.4h, v1.4h, #8
 ; CHECK-SD-NEXT:    shl v0.4h, v0.4h, #8
 ; CHECK-SD-NEXT:    sqsub v0.4h, v0.4h, v1.4h
-; CHECK-SD-NEXT:    sshr v0.4h, v0.4h, #8
+; CHECK-SD-NEXT:    ushr v0.4h, v0.4h, #8
 ; CHECK-SD-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
 ; CHECK-SD-NEXT:    str s0, [x2]
 ; CHECK-SD-NEXT:    ret
@@ -365,10 +365,6 @@ define void @v1i16(ptr %px, ptr %py, ptr %pz) nounwind {
 define <16 x i4> @v16i4(<16 x i4> %x, <16 x i4> %y) nounwind {
 ; CHECK-LABEL: v16i4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    shl v0.16b, v0.16b, #4
-; CHECK-NEXT:    shl v1.16b, v1.16b, #4
-; CHECK-NEXT:    sshr v0.16b, v0.16b, #4
-; CHECK-NEXT:    sshr v1.16b, v1.16b, #4
 ; CHECK-NEXT:    shl v1.16b, v1.16b, #4
 ; CHECK-NEXT:    shl v0.16b, v0.16b, #4
 ; CHECK-NEXT:    sqsub v0.16b, v0.16b, v1.16b

diff --git a/llvm/test/CodeGen/PowerPC/pr44183.ll b/llvm/test/CodeGen/PowerPC/pr44183.ll
index e3dca13809035..4d2c81c35b7fe 100644
--- a/llvm/test/CodeGen/PowerPC/pr44183.ll
+++ b/llvm/test/CodeGen/PowerPC/pr44183.ll
@@ -22,13 +22,8 @@ define void @_ZN1m1nEv(ptr %this) local_unnamed_addr nounwind align 2 {
 ; CHECK-NEXT:    rlwimi r4, r3, 0, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop
-; CHECK-NEXT:    ld r3, 16(r30)
-; CHECK-NEXT:    ld r4, 8(r30)
-; CHECK-NEXT:    rldicl r4, r4, 60, 4
-; CHECK-NEXT:    sldi r3, r3, 60
-; CHECK-NEXT:    or r3, r3, r4
-; CHECK-NEXT:    sldi r3, r3, 31
-; CHECK-NEXT:    rlwinm r4, r3, 0, 0, 0
+; CHECK-NEXT:    ld r3, 8(r30)
+; CHECK-NEXT:    rlwinm r4, r3, 27, 0, 0
 ; CHECK-NEXT:    bl _ZN1llsE1d
 ; CHECK-NEXT:    nop
 ; CHECK-NEXT:    addi r1, r1, 48

diff --git a/llvm/test/CodeGen/RISCV/rvv/bitreverse-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/bitreverse-sdnode.ll
index 19ae26d242426..94e945f803205 100644
--- a/llvm/test/CodeGen/RISCV/rvv/bitreverse-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/bitreverse-sdnode.ll
@@ -8,10 +8,9 @@ define <vscale x 1 x i8> @bitreverse_nxv1i8(<vscale x 1 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv1i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, mf8, ta, ma
-; CHECK-NEXT:    vsrl.vi v9, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v9, v8
+; CHECK-NEXT:    vsll.vi v9, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v9
 ; CHECK-NEXT:    vsrl.vi v9, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v9, v9, a0
@@ -40,10 +39,9 @@ define <vscale x 2 x i8> @bitreverse_nxv2i8(<vscale x 2 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv2i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, mf4, ta, ma
-; CHECK-NEXT:    vsrl.vi v9, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v9, v8
+; CHECK-NEXT:    vsll.vi v9, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v9
 ; CHECK-NEXT:    vsrl.vi v9, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v9, v9, a0
@@ -72,10 +70,9 @@ define <vscale x 4 x i8> @bitreverse_nxv4i8(<vscale x 4 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv4i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, mf2, ta, ma
-; CHECK-NEXT:    vsrl.vi v9, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v9, v8
+; CHECK-NEXT:    vsll.vi v9, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v9
 ; CHECK-NEXT:    vsrl.vi v9, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v9, v9, a0
@@ -104,10 +101,9 @@ define <vscale x 8 x i8> @bitreverse_nxv8i8(<vscale x 8 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv8i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m1, ta, ma
-; CHECK-NEXT:    vsrl.vi v9, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v9, v8
+; CHECK-NEXT:    vsll.vi v9, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v9
 ; CHECK-NEXT:    vsrl.vi v9, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v9, v9, a0
@@ -136,10 +132,9 @@ define <vscale x 16 x i8> @bitreverse_nxv16i8(<vscale x 16 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv16i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
-; CHECK-NEXT:    vsrl.vi v10, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v10, v8
+; CHECK-NEXT:    vsll.vi v10, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v10
 ; CHECK-NEXT:    vsrl.vi v10, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v10, v10, a0
@@ -168,10 +163,9 @@ define <vscale x 32 x i8> @bitreverse_nxv32i8(<vscale x 32 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv32i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m4, ta, ma
-; CHECK-NEXT:    vsrl.vi v12, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v12, v8
+; CHECK-NEXT:    vsll.vi v12, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v12
 ; CHECK-NEXT:    vsrl.vi v12, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v12, v12, a0
@@ -200,10 +194,9 @@ define <vscale x 64 x i8> @bitreverse_nxv64i8(<vscale x 64 x i8> %va) {
 ; CHECK-LABEL: bitreverse_nxv64i8:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8, m8, ta, ma
-; CHECK-NEXT:    vsrl.vi v16, v8, 4
-; CHECK-NEXT:    vand.vi v8, v8, 15
-; CHECK-NEXT:    vsll.vi v8, v8, 4
-; CHECK-NEXT:    vor.vv v8, v16, v8
+; CHECK-NEXT:    vsll.vi v16, v8, 4
+; CHECK-NEXT:    vsrl.vi v8, v8, 4
+; CHECK-NEXT:    vor.vv v8, v8, v16
 ; CHECK-NEXT:    vsrl.vi v16, v8, 2
 ; CHECK-NEXT:    li a0, 51
 ; CHECK-NEXT:    vand.vx v16, v16, a0

diff --git a/llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll b/llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll
index c3e9a2b6841ae..b39e2bb620d6f 100644
--- a/llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll
+++ b/llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll
@@ -3176,7 +3176,8 @@ define <2 x i64> @test_128_i64_x_2_18446744065119617024_mask_ashr_32(<2 x i64> %
 ;
 ; X86-AVX2-LABEL: test_128_i64_x_2_18446744065119617024_mask_ashr_32:
 ; X86-AVX2:       # %bb.0:
-; X86-AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
+; X86-AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [4294967294,4294967294,4294967294,4294967294]
+; X86-AVX2-NEXT:    vpand %xmm1, %xmm0, %xmm0
 ; X86-AVX2-NEXT:    vpsrad $31, %xmm0, %xmm1
 ; X86-AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,1,3,3]
 ; X86-AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]


        

