[Mlir-commits] [mlir] [mlir][intrange] Fix `arith.shl` inference in case of overflow (PR #91737)

Felix Schneider llvmlistbot at llvm.org
Fri May 10 12:04:57 PDT 2024


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/91737

>From f7019ea20f5967aad080d3f7beb5286a83dbf3c5 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 10 May 2024 14:58:46 +0200
Subject: [PATCH 1/2] [mlir][intrange] Fix `arith.shl` inference in case of
 overflow

When an overflow happens during a left shift, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the result range from the shifted bounds is no longer valid.

This patch checks for possible overflow and returns the max range in that
case.

Fix https://github.com/llvm/llvm-project/issues/82158
---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 12 ++++++++
 mlir/test/Dialect/Arith/int-range-opts.mlir   | 29 +++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..12fae495b761e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
                         const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
   };
+
+  // The minMax inference does not work when there is danger of overflow. In the
+  // signed case, this leads to the obvious problem that the sign bit might
+  // change. In the unsigned case, it also leads to problems because the largest
+  // LHS shifted by the largest RHS does not necessarily result in the largest
+  // result anymore.
+  bool signbitSafe =
+      (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
+      (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
+  if (!signbitSafe)
+    return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+
   ConstantIntRanges urange =
       minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
                /*isSigned=*/false);
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+

>From bf81f6206b5db4c7bd9fed9de4783915fd697dbf Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 10 May 2024 20:51:53 +0200
Subject: [PATCH 2/2] Address review

---
 .../Interfaces/Utils/InferIntRangeCommon.cpp    | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 12fae495b761e..6af229cae10ab 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -544,6 +544,10 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
+              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
+              &rhsUMax = rhs.umax();
+
   ConstArithFn shl = [](const APInt &l,
                         const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
@@ -554,17 +558,16 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
   // change. In the unsigned case, it also leads to problems because the largest
   // LHS shifted by the largest RHS does not necessarily result in the largest
   // result anymore.
-  bool signbitSafe =
-      (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
-      (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
-  if (!signbitSafe)
-    return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
+  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
+      rhsUMax.uge(lhsSMax.getNumSignBits()))
+    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }



More information about the Mlir-commits mailing list