[llvm] [ValueTracking] Let ComputeKnownSignBits handle (shl (zext X), C) (PR #97693)

Björn Pettersson via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 03:11:56 PDT 2024


https://github.com/bjope updated https://github.com/llvm/llvm-project/pull/97693

>From 5e932a3b069e0dbd7723b7d8d479af7cf9b6fc77 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Thu, 4 Jul 2024 17:00:53 +0200
Subject: [PATCH 1/2] [ValueTracking] Pre-commit ComputeNumSignBits test for
 (shl (zext X), C)

Add test cases for potential simplifications of
  (shl (zext X), C)
based on the number of known sign bits in X.
---
 .../Analysis/ValueTracking/numsignbits-shl.ll | 168 ++++++++++++++++++
 1 file changed, 168 insertions(+)
 create mode 100644 llvm/test/Analysis/ValueTracking/numsignbits-shl.ll

diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
new file mode 100644
index 0000000000000..c32b6ff12b426
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -0,0 +1,168 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+declare void @escape(i16 %add)
+declare void @escape2(<2 x i16> %add)
+
+define void @numsignbits_shl_zext(i8 %x) {
+; CHECK-LABEL: define void @numsignbits_shl_zext(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X]], 5
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB4:%.*]] = shl i16 [[ZEXT]], 10
+; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB4]], 16384
+; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB4]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
+; CHECK-NEXT:    [[AND13:%.*]] = and i16 [[NSB4]], 8192
+; CHECK-NEXT:    [[ADD13:%.*]] = add i16 [[AND13]], [[NSB4]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD13]])
+; CHECK-NEXT:    [[AND12:%.*]] = and i16 [[NSB4]], 4096
+; CHECK-NEXT:    [[ADD12:%.*]] = add i16 [[AND12]], [[NSB4]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD12]])
+; CHECK-NEXT:    [[AND11:%.*]] = and i16 [[NSB4]], 2048
+; CHECK-NEXT:    [[ADD11:%.*]] = add i16 [[AND11]], [[NSB4]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD11]])
+; CHECK-NEXT:    ret void
+;
+  %ashr = ashr i8 %x, 5
+  %zext = zext i8 %ashr to i16
+  %nsb4 = shl i16 %zext, 10
+  ; Validate ComputeNumSignBits using this simplification:
+  ;   (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
+  ; 4 sign bits (X has 6; treating the zext as a sext gives 14, minus the shift of 10): goal is to fold away the add for bits 12-14, but not for bit 11.
+  %and14 = and i16 %nsb4, 16384
+  %add14 = add i16 %and14, %nsb4
+  call void @escape(i16 %add14)
+  %and13 = and i16 %nsb4, 8192
+  %add13 = add i16 %and13, %nsb4
+  call void @escape(i16 %add13)
+  %and12 = and i16 %nsb4, 4096
+  %add12 = add i16 %and12, %nsb4
+  call void @escape(i16 %add12)
+  %and11 = and i16 %nsb4, 2048
+  %add11 = add i16 %and11, %nsb4
+  call void @escape(i16 %add11)
+  ret void
+}
+
+define void @numsignbits_shl_zext_shift_amount_matches_extend(i8 %x) {
+; CHECK-LABEL: define void @numsignbits_shl_zext_shift_amount_matches_extend(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X]], 2
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB3:%.*]] = shl nuw i16 [[ZEXT]], 8
+; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB3]], 16384
+; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB3]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
+; CHECK-NEXT:    [[AND13:%.*]] = and i16 [[NSB3]], 8192
+; CHECK-NEXT:    [[ADD13:%.*]] = add i16 [[AND13]], [[NSB3]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD13]])
+; CHECK-NEXT:    [[AND12:%.*]] = and i16 [[NSB3]], 4096
+; CHECK-NEXT:    [[ADD12:%.*]] = add i16 [[AND12]], [[NSB3]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD12]])
+; CHECK-NEXT:    ret void
+;
+  %ashr = ashr i8 %x, 2
+  %zext = zext i8 %ashr to i16
+  %nsb3 = shl i16 %zext, 8
+  ; Validate ComputeNumSignBits using this simplification:
+  ;   (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
+  ; 3 sign bits (X has 3; the shift amount equals the 8 zext bits, so only X's own sign bits remain): goal is to fold away the add for bits 13-14, but not for bit 12.
+  %and14 = and i16 %nsb3, 16384
+  %add14 = add i16 %and14, %nsb3
+  call void @escape(i16 %add14)
+  %and13 = and i16 %nsb3, 8192
+  %add13 = add i16 %and13, %nsb3
+  call void @escape(i16 %add13)
+  %and12 = and i16 %nsb3, 4096
+  %add12 = add i16 %and12, %nsb3
+  call void @escape(i16 %add12)
+  ret void
+}
+
+define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
+; CHECK-LABEL: define void @numsignbits_shl_zext_extended_bits_remains(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X]], 5
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB1:%.*]] = shl nuw nsw i16 [[ZEXT]], 7
+; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB1]], 16384
+; CHECK-NEXT:    [[ADD14:%.*]] = add nuw i16 [[AND14]], [[NSB1]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
+; CHECK-NEXT:    ret void
+;
+  %ashr = ashr i8 %x, 5
+  %zext = zext i8 %ashr to i16
+  %nsb1 = shl i16 %zext, 7
+  ; Validate ComputeNumSignBits using this simplification:
+  ;   (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
+  ; 1 sign bit: the shift (7) is smaller than the 8 zext bits, so bit 15 is a known zero sitting above X's sign bit in bit 14. The add can't be folded away here.
+  %and14 = and i16 %nsb1, 16384
+  %add14 = add i16 %and14, %nsb1
+  call void @escape(i16 %add14)
+  ret void
+}
+
+define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
+; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[ASHR:%.*]] = lshr i8 [[X]], 5
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB1]], 16384
+; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
+; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
+; CHECK-NEXT:    ret void
+;
+  %ashr = ashr i8 %x, 5
+  %zext = zext i8 %ashr to i16
+  %nsb1 = shl i16 %zext, 14
+  ; Validate ComputeNumSignBits using this simplification:
+  ;   (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
+  ; 1 sign bit: X's 6 sign bits are shifted out along with the 8 zext bits, leaving X's two low bits in bits 15-14. The add can't be folded away here.
+  %and14 = and i16 %nsb1, 16384
+  %add14 = add i16 %and14, %nsb1
+  call void @escape(i16 %add14)
+  ret void
+}
+
+define void @numsignbits_shl_zext_vector(<2 x i8> %x) {
+; CHECK-LABEL: define void @numsignbits_shl_zext_vector(
+; CHECK-SAME: <2 x i8> [[X:%.*]]) {
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i8> [[X]], <i8 5, i8 5>
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i8> [[ASHR]] to <2 x i16>
+; CHECK-NEXT:    [[NSB4:%.*]] = shl <2 x i16> [[ZEXT]], <i16 10, i16 10>
+; CHECK-NEXT:    [[AND14:%.*]] = and <2 x i16> [[NSB4]], <i16 16384, i16 16384>
+; CHECK-NEXT:    [[ADD14:%.*]] = add <2 x i16> [[AND14]], [[NSB4]]
+; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD14]])
+; CHECK-NEXT:    [[AND13:%.*]] = and <2 x i16> [[NSB4]], <i16 8192, i16 8192>
+; CHECK-NEXT:    [[ADD13:%.*]] = add <2 x i16> [[AND13]], [[NSB4]]
+; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD13]])
+; CHECK-NEXT:    [[AND12:%.*]] = and <2 x i16> [[NSB4]], <i16 4096, i16 4096>
+; CHECK-NEXT:    [[ADD12:%.*]] = add <2 x i16> [[AND12]], [[NSB4]]
+; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD12]])
+; CHECK-NEXT:    [[AND11:%.*]] = and <2 x i16> [[NSB4]], <i16 2048, i16 2048>
+; CHECK-NEXT:    [[ADD11:%.*]] = add <2 x i16> [[AND11]], [[NSB4]]
+; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD11]])
+; CHECK-NEXT:    ret void
+;
+  %ashr = ashr <2 x i8> %x, <i8 5, i8 5>
+  %zext = zext <2 x i8> %ashr to <2 x i16>
+  %nsb4 = shl <2 x i16> %zext, <i16 10, i16 10>
+  ; Validate ComputeNumSignBits using this simplification:
+  ;   (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
+  ; 4 sign bits per element (same as the scalar numsignbits_shl_zext case): goal is to fold away the add for bits 12-14, but not for bit 11.
+  %and14 = and <2 x i16> %nsb4, <i16 16384, i16 16384>
+  %add14 = add <2 x i16> %and14, %nsb4
+  call void @escape2(<2 x i16> %add14)
+  %and13 = and <2 x i16> %nsb4, <i16 8192, i16 8192>
+  %add13 = add <2 x i16> %and13, %nsb4
+  call void @escape2(<2 x i16> %add13)
+  %and12 = and <2 x i16> %nsb4, <i16 4096, i16 4096>
+  %add12 = add <2 x i16> %and12, %nsb4
+  call void @escape2(<2 x i16> %add12)
+  %and11 = and <2 x i16> %nsb4, <i16 2048, i16 2048>
+  %add11 = add <2 x i16> %and11, %nsb4
+  call void @escape2(<2 x i16> %add11)
+  ret void
+}

>From 0deccda3eee27b6c2b7a81d400b8ed3d8efab460 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Thu, 18 Jul 2024 11:54:22 +0200
Subject: [PATCH 2/2] [ValueTracking] Let ComputeKnownSignBits handle (shl
 (zext X), C) (#97693)

Add simple support for looking through a zext when ComputeNumSignBits
analyses a shl. This is valid when all extended bits are shifted out,
because then the number of sign bits can be found by analysing the
zext operand.

The solution here is simple, as it only handles a single zext (it does
not pass the remaining left-shift amount down during recursion). It
could be generalized in the future, for example by passing an
'OffsetFromMSB' parameter to ComputeNumSignBitsImpl, telling it to
calculate the number of sign bits starting at some offset from the
most significant bit.
---
 llvm/lib/Analysis/ValueTracking.cpp           | 17 +++++++++--
 .../Analysis/ValueTracking/numsignbits-shl.ll | 30 +++++++------------
 2 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 535a248a5f1a2..03eb6ef42b0ff 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3866,11 +3866,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     }
     case Instruction::Shl: {
       const APInt *ShAmt;
+      Value *X = nullptr;
       if (match(U->getOperand(1), m_APInt(ShAmt))) {
         // shl destroys sign bits.
-        Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
-        if (ShAmt->uge(TyBits) ||   // Bad shift.
-            ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
+        if (ShAmt->uge(TyBits))
+          break; // Bad shift.
+        // We can look through a zext (more or less treating it as a sext) if
+        // all extended bits are shifted out.
+        if (match(U->getOperand(0), m_ZExt(m_Value(X))) &&
+            ShAmt->uge(TyBits - X->getType()->getScalarSizeInBits())) {
+          Tmp = ComputeNumSignBits(X, DemandedElts, Depth + 1, Q);
+          Tmp += TyBits - X->getType()->getScalarSizeInBits();
+        } else
+          Tmp =
+              ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
+        if (ShAmt->uge(Tmp))
+          break; // Shifted all sign bits out.
         Tmp2 = ShAmt->getZExtValue();
         return Tmp - Tmp2;
       }
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index c32b6ff12b426..e86d28ebbc1d2 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -10,17 +10,14 @@ define void @numsignbits_shl_zext(i8 %x) {
 ; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X]], 5
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[ASHR]] to i16
 ; CHECK-NEXT:    [[NSB4:%.*]] = shl i16 [[ZEXT]], 10
-; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB4]], 16384
-; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB4]]
+; CHECK-NEXT:    [[ADD14:%.*]] = and i16 [[NSB4]], 15360
 ; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
-; CHECK-NEXT:    [[AND13:%.*]] = and i16 [[NSB4]], 8192
-; CHECK-NEXT:    [[ADD13:%.*]] = add i16 [[AND13]], [[NSB4]]
+; CHECK-NEXT:    [[ADD13:%.*]] = and i16 [[NSB4]], 7168
 ; CHECK-NEXT:    call void @escape(i16 [[ADD13]])
-; CHECK-NEXT:    [[AND12:%.*]] = and i16 [[NSB4]], 4096
-; CHECK-NEXT:    [[ADD12:%.*]] = add i16 [[AND12]], [[NSB4]]
+; CHECK-NEXT:    [[ADD12:%.*]] = and i16 [[NSB4]], 3072
 ; CHECK-NEXT:    call void @escape(i16 [[ADD12]])
 ; CHECK-NEXT:    [[AND11:%.*]] = and i16 [[NSB4]], 2048
-; CHECK-NEXT:    [[ADD11:%.*]] = add i16 [[AND11]], [[NSB4]]
+; CHECK-NEXT:    [[ADD11:%.*]] = add nsw i16 [[AND11]], [[NSB4]]
 ; CHECK-NEXT:    call void @escape(i16 [[ADD11]])
 ; CHECK-NEXT:    ret void
 ;
@@ -51,14 +48,12 @@ define void @numsignbits_shl_zext_shift_amounr_matches_extend(i8 %x) {
 ; CHECK-NEXT:    [[ASHR:%.*]] = ashr i8 [[X]], 2
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[ASHR]] to i16
 ; CHECK-NEXT:    [[NSB3:%.*]] = shl nuw i16 [[ZEXT]], 8
-; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB3]], 16384
-; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB3]]
+; CHECK-NEXT:    [[ADD14:%.*]] = and i16 [[NSB3]], 16128
 ; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
-; CHECK-NEXT:    [[AND13:%.*]] = and i16 [[NSB3]], 8192
-; CHECK-NEXT:    [[ADD13:%.*]] = add i16 [[AND13]], [[NSB3]]
+; CHECK-NEXT:    [[ADD13:%.*]] = and i16 [[NSB3]], 7936
 ; CHECK-NEXT:    call void @escape(i16 [[ADD13]])
 ; CHECK-NEXT:    [[AND12:%.*]] = and i16 [[NSB3]], 4096
-; CHECK-NEXT:    [[ADD12:%.*]] = add i16 [[AND12]], [[NSB3]]
+; CHECK-NEXT:    [[ADD12:%.*]] = add nsw i16 [[AND12]], [[NSB3]]
 ; CHECK-NEXT:    call void @escape(i16 [[ADD12]])
 ; CHECK-NEXT:    ret void
 ;
@@ -132,17 +127,14 @@ define void @numsignbits_shl_zext_vector(<2 x i8> %x) {
 ; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i8> [[X]], <i8 5, i8 5>
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i8> [[ASHR]] to <2 x i16>
 ; CHECK-NEXT:    [[NSB4:%.*]] = shl <2 x i16> [[ZEXT]], <i16 10, i16 10>
-; CHECK-NEXT:    [[AND14:%.*]] = and <2 x i16> [[NSB4]], <i16 16384, i16 16384>
-; CHECK-NEXT:    [[ADD14:%.*]] = add <2 x i16> [[AND14]], [[NSB4]]
+; CHECK-NEXT:    [[ADD14:%.*]] = and <2 x i16> [[NSB4]], <i16 15360, i16 15360>
 ; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD14]])
-; CHECK-NEXT:    [[AND13:%.*]] = and <2 x i16> [[NSB4]], <i16 8192, i16 8192>
-; CHECK-NEXT:    [[ADD13:%.*]] = add <2 x i16> [[AND13]], [[NSB4]]
+; CHECK-NEXT:    [[ADD13:%.*]] = and <2 x i16> [[NSB4]], <i16 7168, i16 7168>
 ; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD13]])
-; CHECK-NEXT:    [[AND12:%.*]] = and <2 x i16> [[NSB4]], <i16 4096, i16 4096>
-; CHECK-NEXT:    [[ADD12:%.*]] = add <2 x i16> [[AND12]], [[NSB4]]
+; CHECK-NEXT:    [[ADD12:%.*]] = and <2 x i16> [[NSB4]], <i16 3072, i16 3072>
 ; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD12]])
 ; CHECK-NEXT:    [[AND11:%.*]] = and <2 x i16> [[NSB4]], <i16 2048, i16 2048>
-; CHECK-NEXT:    [[ADD11:%.*]] = add <2 x i16> [[AND11]], [[NSB4]]
+; CHECK-NEXT:    [[ADD11:%.*]] = add nsw <2 x i16> [[AND11]], [[NSB4]]
 ; CHECK-NEXT:    call void @escape2(<2 x i16> [[ADD11]])
 ; CHECK-NEXT:    ret void
 ;



More information about the llvm-commits mailing list