[Mlir-commits] [mlir] 0f79066 - [mlir][intrange] Fix `arith.shl` inference in case of overflow (#91737)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 13 10:27:43 PDT 2024


Author: Felix Schneider
Date: 2024-05-13T19:27:38+02:00
New Revision: 0f7906645d18a38a6b80a1e8e1d425396f6ab353

URL: https://github.com/llvm/llvm-project/commit/0f7906645d18a38a6b80a1e8e1d425396f6ab353
DIFF: https://github.com/llvm/llvm-project/commit/0f7906645d18a38a6b80a1e8e1d425396f6ab353.diff

LOG: [mlir][intrange] Fix `arith.shl` inference in case of overflow (#91737)

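When an overflow happens during a left shift, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the result range from the shifted bounds no longer works: the
sign bit may change, and the largest LHS shifted by the largest RHS is
not necessarily the largest result anymore.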

This patch checks for possible overflow and returns the max range in
that case.

Fixes https://github.com/llvm/llvm-project/issues/82158

Added: 
    

Modified: 
    mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
    mlir/test/Dialect/Arith/int-range-opts.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..6af229cae10ab 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -544,15 +544,30 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
+              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
+              &rhsUMax = rhs.umax();
+
   ConstArithFn shl = [](const APInt &l,
                         const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
   };
+
+  // The minMax inference does not work when there is danger of overflow. In the
+  // signed case, this leads to the obvious problem that the sign bit might
+  // change. In the unsigned case, it also leads to problems because the largest
+  // LHS shifted by the largest RHS does not necessarily result in the largest
+  // result anymore.
+  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
+  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
+      rhsUMax.uge(lhsSMax.getNumSignBits()))
+    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
+
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }

diff  --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+


        


More information about the Mlir-commits mailing list