[clang] [llvm] [ConstantRange] Estimate tighter lower (upper) bounds for masked binary and (or) (PR #120352)

Stephen Senran Zhang via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 19 01:52:31 PST 2024


https://github.com/zsrkmyn updated https://github.com/llvm/llvm-project/pull/120352

>From 3351cf82f3fef3bf22cd274e16c3e23133cd7754 Mon Sep 17 00:00:00 2001
From: Senran Zhang <zsrkmyn at gmail.com>
Date: Tue, 17 Dec 2024 16:15:25 +0800
Subject: [PATCH] [ConstantRange] Estimate tighter lower (upper) bounds for
 masked binary and (or)

---
 clang/test/CodeGen/AArch64/fpm-helpers.c      |  18 +--
 llvm/lib/IR/ConstantRange.cpp                 | 105 +++++++++++++++++-
 .../SCCP/range-and-or-bit-masked.ll           |  88 +++++++++++++++
 3 files changed, 196 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/Transforms/SCCP/range-and-or-bit-masked.ll

diff --git a/clang/test/CodeGen/AArch64/fpm-helpers.c b/clang/test/CodeGen/AArch64/fpm-helpers.c
index 4bced01d5c71fa..3b356c0d1136d1 100644
--- a/clang/test/CodeGen/AArch64/fpm-helpers.c
+++ b/clang/test/CodeGen/AArch64/fpm-helpers.c
@@ -35,7 +35,7 @@ extern "C" {
 //
 fpm_t test_init() { return __arm_fpm_init(); }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_src1_1(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -4) i64 @test_src1_1(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 -8
@@ -44,7 +44,7 @@ fpm_t test_src1_1() {
   return __arm_set_fpm_src1_format(INIT_ONES, __ARM_FPM_E5M2);
 }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_src1_2(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -4) i64 @test_src1_2(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 1
@@ -53,7 +53,7 @@ fpm_t test_src1_2() {
   return __arm_set_fpm_src1_format(INIT_ZERO, __ARM_FPM_E4M3);
 }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_src2_1(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -32) i64 @test_src2_1(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 -57
@@ -62,7 +62,7 @@ fpm_t test_src2_1() {
   return __arm_set_fpm_src2_format(INIT_ONES, __ARM_FPM_E5M2);
 }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_src2_2(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -32) i64 @test_src2_2(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 8
@@ -71,7 +71,7 @@ fpm_t test_src2_2() {
   return __arm_set_fpm_src2_format(INIT_ZERO, __ARM_FPM_E4M3);
 }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_dst1_1(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -256) i64 @test_dst1_1(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 -449
@@ -80,7 +80,7 @@ fpm_t test_dst1_1() {
   return __arm_set_fpm_dst_format(INIT_ONES, __ARM_FPM_E5M2);
 }
 
-// CHECK-LABEL: define dso_local noundef i64 @test_dst2_2(
+// CHECK-LABEL: define dso_local noundef range(i64 0, -256) i64 @test_dst2_2(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 64
@@ -139,21 +139,21 @@ fpm_t test_lscale() { return __arm_set_fpm_lscale(INIT_ZERO, 127); }
 //
 fpm_t test_lscale2() { return __arm_set_fpm_lscale2(INIT_ZERO, 63); }
 
-// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_1(
+// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_1(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 2147483648
 //
 fpm_t test_nscale_1() { return __arm_set_fpm_nscale(INIT_ZERO, -128); }
 
-// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_2(
+// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_2(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 2130706432
 //
 fpm_t test_nscale_2() { return __arm_set_fpm_nscale(INIT_ZERO, 127); }
 
-// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_3(
+// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_3(
 // CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    ret i64 4278190080
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index d81a292916fdea..14e35514ca0ff2 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1520,15 +1520,101 @@ ConstantRange ConstantRange::binaryNot() const {
   return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
 }
 
+/// Estimate a lower bound for the result of a bitwise AND of two ranges.
+///
+/// E.g., given the two ranges below (the single quotes are only visual
+/// separators and carry no meaning),
+///
+///   LHS = [10'001'010,  ; LLo
+///          10'100'000]  ; LHi
+///   RHS = [10'111'010,  ; RLo
+///          10'111'100]  ; RHi
+///
+/// the two high bits of the result are always '10'. Moreover, every value in
+/// LHS has at least one bit set in bits [3, 6) (the range is contiguous), and
+/// every value in RHS has all of bits [3, 6) set, so the result is at least
+/// 10'001'000.
+///
+/// The algorithm is as follows:
+/// 1. Compute a mask that clears the high bits shared by all four bounds:
+///       Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+///       Mask = Mask with every bit below its leading one also set;
+/// 2. Find the bit field [StartBit, EndBit) in which every masked value of
+///    LHS has at least one set bit (bits [3, 6) in the example):
+///       StartBit = BitWidth - (LLo & Mask).clz() - 1;
+///       EndBit = BitWidth - (LHi & Mask).clz();
+/// 3. If RLo and RHi agree on bits [StartBit, BitWidth) and all bits of RHS
+///    in [StartBit, EndBit) are 1, the lower bound can be raised to
+///       LowerBound = LLo & Keep;
+///    where Keep clears the bits below StartBit (the low 3 bits in the
+///    example).
+/// 4. Repeat steps 2 and 3 with LHS and RHS swapped and keep the larger of
+///    the two bounds. If either range is full or unsigned-wrapped, it
+///    contains 0 and the lower bound is 0.
+static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS,
+                                            const ConstantRange &RHS) {
+  unsigned BitWidth = LHS.getBitWidth();
+  APInt Zero = APInt::getZero(BitWidth);
+
+  // A full or unsigned-wrapped range contains 0, and 0 & X == 0, so no lower
+  // bound above 0 exists.
+  if (LHS.isFullSet() || RHS.isFullSet() || LHS.isWrappedSet() ||
+      RHS.isWrappedSet())
+    return Zero;
+
+  // Inclusive bounds of both (non-wrapped) ranges.
+  APInt LLo = LHS.getLower();
+  APInt LHi = LHS.getUpper() - 1;
+  APInt RLo = RHS.getLower();
+  APInt RHi = RHS.getUpper() - 1;
+
+  // Mask covers all bits at or below the highest bit where the four bounds
+  // differ; every value of both ranges agrees on the bits above Mask.
+  APInt Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+  Mask.setLowBits(BitWidth - Mask.countLeadingZeros());
+
+  // Lower bound of A & B for A in [ALo, AHi], B in [BLo, BHi]; Zero if none.
+  auto EstimateBound = [BitWidth, &Mask, &Zero](const APInt &ALo,
+                                                const APInt &AHi,
+                                                const APInt &BLo,
+                                                const APInt &BHi) -> APInt {
+    unsigned MaskedClz = (ALo & Mask).countLeadingZeros();
+    if (MaskedClz == BitWidth)
+      return Zero; // No set bit of ALo lies within Mask.
+
+    // Every A & Mask lies in [ALo & Mask, AHi & Mask], so its leading one is
+    // in [StartBit, EndBit).
+    unsigned StartBit = BitWidth - MaskedClz - 1;
+    unsigned EndBit = BitWidth - (AHi & Mask).countLeadingZeros();
+
+    // BLo and BHi (hence every B) must agree on bits [StartBit, BitWidth) ...
+    if ((BLo ^ BHi).getActiveBits() > StartBit)
+      return Zero;
+
+    // ... and have all of bits [StartBit, EndBit) set. Since BLo and BHi
+    // agree there, testing BLo alone is enough.
+    if (!BLo.extractBits(EndBit - StartBit, StartBit).isAllOnes())
+      return Zero;
+
+    // A & B then equals A on bits [StartBit, BitWidth), and A >= ALo, so ALo
+    // with the bits below StartBit cleared is a lower bound of A & B.
+    return ALo & APInt::getHighBitsSet(BitWidth, BitWidth - StartBit);
+  };
+
+  APInt ByLHS = EstimateBound(LLo, LHi, RLo, RHi);
+  APInt ByRHS = EstimateBound(RLo, RHi, LLo, LHi);
+  return APIntOps::umax(ByLHS, ByRHS);
+}
+
 ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
 
   ConstantRange KnownBitsRange =
       fromKnownBits(toKnownBits() & Other.toKnownBits(), false);
-  ConstantRange UMinUMaxRange =
-      getNonEmpty(APInt::getZero(getBitWidth()),
-                  APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
+  auto LowerBound = estimateBitMaskedAndLowerBound(*this, Other);
+  ConstantRange UMinUMaxRange = getNonEmpty(
+      LowerBound, APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
   return KnownBitsRange.intersectWith(UMinUMaxRange);
 }
 
@@ -1538,10 +1624,17 @@ ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const {
 
   ConstantRange KnownBitsRange =
       fromKnownBits(toKnownBits() | Other.toKnownBits(), false);
+
+  //      ~a & ~b    >= x
+  // <=>  ~(~a & ~b) <= ~x
+  // <=>  a | b      <= ~x
+  // <=>  a | b      <  ~x + 1 = -x
+  // thus, UpperBound(a | b) == -LowerBound(~a & ~b)
+  auto UpperBound =
+      -estimateBitMaskedAndLowerBound(binaryNot(), Other.binaryNot());
   // Upper wrapped range.
-  ConstantRange UMaxUMinRange =
-      getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()),
-                  APInt::getZero(getBitWidth()));
+  ConstantRange UMaxUMinRange = getNonEmpty(
+      APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), UpperBound);
   return KnownBitsRange.intersectWith(UMaxUMinRange);
 }
 
diff --git a/llvm/test/Transforms/SCCP/range-and-or-bit-masked.ll b/llvm/test/Transforms/SCCP/range-and-or-bit-masked.ll
new file mode 100644
index 00000000000000..e81c5d739c6d29
--- /dev/null
+++ b/llvm/test/Transforms/SCCP/range-and-or-bit-masked.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=ipsccp %s | FileCheck %s
+
+declare void @use(i1)
+
+define i1 @test1(i64 %x) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp ugt i64 [[X:%.*]], 65535
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[MASK:%.*]] = and i64 [[X]], -65521
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %cond = icmp ugt i64 %x, 65535 ; x >= 0x10000
+  call void @llvm.assume(i1 %cond)
+  %mask = and i64 %x, -65521 ; 0xFFFFFFFFFFFF000F
+  %cmp = icmp eq i64 %mask, 0 ; false: x has a set bit in [16, 64), all kept by the mask
+  ret i1 %cmp
+}
+
+define void @test.and(i64 %x, i64 %y) {
+; CHECK-LABEL: @test.and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[C0:%.*]] = icmp uge i64 [[X:%.*]], 138
+; CHECK-NEXT:    [[C1:%.*]] = icmp ule i64 [[X]], 161
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C0]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C1]])
+; CHECK-NEXT:    [[C2:%.*]] = icmp uge i64 [[Y:%.*]], 186
+; CHECK-NEXT:    [[C3:%.*]] = icmp ule i64 [[Y]], 188
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C2]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C3]])
+; CHECK-NEXT:    [[AND:%.*]] = and i64 [[X]], [[Y]]
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    [[R1:%.*]] = icmp ult i64 [[AND]], 137
+; CHECK-NEXT:    call void @use(i1 [[R1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %c0 = icmp uge i64 %x, 138 ; 0b10001010
+  %c1 = icmp ule i64 %x, 161 ; 0b10100001
+  call void @llvm.assume(i1 %c0)
+  call void @llvm.assume(i1 %c1)
+  %c2 = icmp uge i64 %y, 186 ; 0b10111010
+  %c3 = icmp ule i64 %y, 188 ; 0b10111100
+  call void @llvm.assume(i1 %c2)
+  call void @llvm.assume(i1 %c3)
+  %and = and i64 %x, %y ; lower bound is 136 (0b10001000)
+  %r0 = icmp ult i64 %and, 136 ; 0b10001000
+  call void @use(i1 %r0) ; false
+  %r1 = icmp ult i64 %and, 137 ; 0b10001001
+  call void @use(i1 %r1) ; unknown
+  ret void
+}
+
+define void @test.or(i64 %x, i64 %y) {
+; CHECK-LABEL: @test.or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[C0:%.*]] = icmp ule i64 [[X:%.*]], 117
+; CHECK-NEXT:    [[C1:%.*]] = icmp uge i64 [[X]], 95
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C0]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C1]])
+; CHECK-NEXT:    [[C2:%.*]] = icmp ule i64 [[Y:%.*]], 69
+; CHECK-NEXT:    [[C3:%.*]] = icmp uge i64 [[Y]], 67
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C2]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C3]])
+; CHECK-NEXT:    [[OR:%.*]] = or i64 [[X]], [[Y]]
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    [[R1:%.*]] = icmp ugt i64 [[OR]], 118
+; CHECK-NEXT:    call void @use(i1 [[R1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %c0 = icmp ule i64 %x, 117 ; 0b01110101
+  %c1 = icmp uge i64 %x, 95  ; 0b01011111
+  call void @llvm.assume(i1 %c0)
+  call void @llvm.assume(i1 %c1)
+  %c2 = icmp ule i64 %y, 69  ; 0b01000101
+  %c3 = icmp uge i64 %y, 67  ; 0b01000011
+  call void @llvm.assume(i1 %c2)
+  call void @llvm.assume(i1 %c3)
+  %or = or i64 %x, %y ; upper bound is 119 (0b01110111)
+  %r0 = icmp ugt i64 %or, 119 ; 0b01110111
+  call void @use(i1 %r0) ; false
+  %r1 = icmp ugt i64 %or, 118 ; 0b01110110
+  call void @use(i1 %r1) ; unknown
+  ret void
+}



More information about the llvm-commits mailing list