[llvm] [ValueTracking] Let ComputeKnownSignBits handle (shl (zext X), C) (PR #97693)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 4 01:30:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Björn Pettersson (bjope)

<details>
<summary>Changes</summary>

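Add simple support for looking through a zext when computing ComputeNumSignBits for shl. This is valid when all extended bits are shifted out, because then the number of sign bits can be found by analysing the zext operand: for `shl (zext X), C` with `C >= BitWidth - SrcBitWidth`, the result has `ComputeNumSignBits(X) + (BitWidth - SrcBitWidth) - C` sign bits (or just 1 if that value is not positive).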

The solution here is simple, as it only handles a single zext (it does not pass the remaining left-shift amount down during recursion). It could be generalized in the future, for example by passing an 'OffsetFromMSB' parameter to ComputeNumSignBitsImpl, telling it to calculate the number of sign bits starting at some offset from the most significant bit.

---
Full diff: https://github.com/llvm/llvm-project/pull/97693.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/ValueTracking.cpp (+13-3) 
- (modified) llvm/unittests/Analysis/ValueTrackingTest.cpp (+30) 


``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 5476dc5d851829..7d229f58d3f6a9 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3757,11 +3757,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
     }
     case Instruction::Shl: {
       const APInt *ShAmt;
+      Value *X = nullptr;
       if (match(U->getOperand(1), m_APInt(ShAmt))) {
         // shl destroys sign bits.
-        Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
-        if (ShAmt->uge(TyBits) ||   // Bad shift.
-            ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
+        if (ShAmt->uge(TyBits))
+          break; // Bad shift.
+        // We can look through a zext (more or less treating it as a sext) if
+        // all extended bits are shifted out.
+        if (match(U->getOperand(0), m_ZExt(m_Value(X))) &&
+            ShAmt->uge(TyBits - X->getType()->getScalarSizeInBits())) {
+          Tmp = ComputeNumSignBits(X, Depth + 1, Q);
+          Tmp += TyBits - X->getType()->getScalarSizeInBits();
+        } else
+          Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+        if (ShAmt->uge(Tmp))
+          break; // Shifted all sign bits out.
         Tmp2 = ShAmt->getZExtValue();
         return Tmp - Tmp2;
       }
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index a30db468c77291..e850338ff8ff1e 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -697,6 +697,36 @@ TEST_F(ValueTrackingTest, ComputeNumSignBits_PR32045) {
   EXPECT_EQ(ComputeNumSignBits(A, M->getDataLayout()), 32u);
 }
 
+TEST_F(ValueTrackingTest, ComputeNumSignBits_shl_ext1) {
+  parseAssembly("define i32 @test(i8 %a) {\n"
+                "  %b = ashr i8 %a, 4\n"
+                "  %c = zext i8 %b to i32\n"
+                "  %A = shl i32 %c, 24\n"
+                "  ret i32 %A\n"
+                "}\n");
+  EXPECT_EQ(ComputeNumSignBits(A, M->getDataLayout()), 5u);
+}
+
+TEST_F(ValueTrackingTest, ComputeNumSignBits_shl_ext2) {
+  parseAssembly("define i32 @test(i8 %a) {\n"
+                "  %b = ashr i8 %a, 4\n"
+                "  %c = zext i8 %b to i32\n"
+                "  %A = shl i32 %c, 26\n"
+                "  ret i32 %A\n"
+                "}\n");
+  EXPECT_EQ(ComputeNumSignBits(A, M->getDataLayout()), 3u);
+}
+
+TEST_F(ValueTrackingTest, ComputeNumSignBits_shl_ext3) {
+  parseAssembly("define i32 @test(i8 %a) {\n"
+                "  %b = ashr i8 %a, 4\n"
+                "  %c = zext i8 %b to i32\n"
+                "  %A = shl i32 %c, 30\n"
+                "  ret i32 %A\n"
+                "}\n");
+  EXPECT_EQ(ComputeNumSignBits(A, M->getDataLayout()), 1u);
+}
+
 // No guarantees for canonical IR in this analysis, so this just bails out.
 TEST_F(ValueTrackingTest, ComputeNumSignBits_Shuffle) {
   parseAssembly(

``````````

</details>


https://github.com/llvm/llvm-project/pull/97693


More information about the llvm-commits mailing list