[Mlir-commits] [mlir] [mlir][intrange] Fix `arith.shl` inference in case of overflow (PR #91737)

Felix Schneider llvmlistbot at llvm.org
Fri May 10 06:12:01 PDT 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/91737

When a left shift overflows, i.e. the last sign bit or the most significant data bit gets shifted out, the current approach to inferring the result range (evaluating the shift only at the corners of the operand ranges) no longer works.

This patch checks whether overflow is possible and, if so, returns the maximal range for the bit width.
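To illustrate the guard on plain 8-bit integers, here is a minimal stand-alone sketch. The names `numSignBits8` and `shlSignbitSafe8` are hypothetical and not part of MLIR; `numSignBits8` is meant to match what `llvm::APInt::getNumSignBits()` reports for an 8-bit value. No value in `[smin, smax]` overflows (in the signed sense) when shifted left by up to `maxShift` bits exactly when every value has more than `maxShift` sign bits. Checking only the two endpoints is enough because the sign-bit count shrinks as a value moves away from zero in either direction, so its minimum over the range is reached at `smin` or `smax`.

```cpp
#include <cassert>
#include <cstdint>

// Counts the leading bits equal to the sign bit, including the sign bit
// itself (what llvm::APInt::getNumSignBits() reports, specialized to 8 bits).
unsigned numSignBits8(int8_t v) {
  uint8_t bits = static_cast<uint8_t>(v);
  unsigned sign = bits >> 7;
  unsigned count = 0;
  for (int i = 7; i >= 0 && ((bits >> i) & 1u) == sign; --i)
    ++count;
  return count;
}

// Hypothetical stand-alone version of the patch's `signbitSafe` check:
// true iff no value in [smin, smax] changes sign (overflows) when shifted
// left by any amount up to maxShift.
bool shlSignbitSafe8(int8_t smin, int8_t smax, unsigned maxShift) {
  assert(smin <= smax && "expected a non-empty signed range");
  return numSignBits8(smin) > maxShift && numSignBits8(smax) > maxShift;
}
```

With these definitions, `shlSignbitSafe8(0, 12, 1)` is true (12 has four sign bits), while `shlSignbitSafe8(0, 127, 1)` and `shlSignbitSafe8(-128, 127, 1)` are false (127 and -128 each have only one), so in the latter two cases the patched inference falls back to the maximal range.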

This bug shows up most clearly when we perform `arith.shli` on a full-range integer, for example an i8 that can take any value from 0 to 0xff:
```mlir
// mlir-opt --int-range-optimizations
func.func @test() -> index {
  %cst1 = arith.constant 1 : i8
  %0 = test.with_bounds { umin = 0 : index, umax = 255 : index, smin = -128 : index, smax = 127 : index }
  %i8val = arith.index_cast %0 : index to i8
  %shifted = arith.shli %i8val, %cst1 : i8
  %si = arith.index_cast %shifted : i8 to index
  %1 = test.reflect_bounds %si
// Before this patch, this gets optimized to:
// %4 = test.reflect_bounds {smax = 0 : index, smin = -2 : index, umax = -1 : index, umin = 0 : index} %3
  return %1: index
}
```
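The signed result range is vastly underestimated as [-2, 0]: the only corner values considered are -128 << 1 = 0 (the sign bit is shifted out) and 127 << 1 = -2, while the true result can be any even value in [-128, 126]. This leads to wrong optimizations.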
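The underestimation can be reproduced outside MLIR with a small brute-force comparison. This is an illustrative sketch only: `shl8`, `cornerRange` and `exactRange` are hypothetical helpers. `cornerRange` mimics the pre-patch strategy of shifting only the range endpoints (with a single, constant shift amount), and `exactRange` shifts every value in the range.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// 8-bit left shift with wraparound, like arith.shli on i8
// (only defined here for shift amounts below the bit width).
int8_t shl8(int8_t v, unsigned amt) {
  assert(amt < 8 && "shift amount must be less than the bit width");
  return static_cast<int8_t>(static_cast<uint8_t>(static_cast<uint8_t>(v) << amt));
}

// Pre-patch style: evaluate the shift only at the endpoints of the signed
// LHS range and take the min/max of those results.
std::pair<int, int> cornerRange(int8_t smin, int8_t smax, unsigned amt) {
  int a = shl8(smin, amt);
  int b = shl8(smax, amt);
  return {std::min(a, b), std::max(a, b)};
}

// Ground truth: shift every value in [smin, smax] and record the extremes.
std::pair<int, int> exactRange(int8_t smin, int8_t smax, unsigned amt) {
  int lo = 127, hi = -128;
  for (int v = smin; v <= smax; ++v) {
    int r = shl8(static_cast<int8_t>(v), amt);
    lo = std::min(lo, r);
    hi = std::max(hi, r);
  }
  return {lo, hi};
}
```

For the full i8 range and a shift of 1, `cornerRange(-128, 127, 1)` yields {-2, 0} while `exactRange(-128, 127, 1)` yields {-128, 126}. For a range without overflow such as [0, 12], both return {0, 24}.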

Fix https://github.com/llvm/llvm-project/issues/82158

From f7019ea20f5967aad080d3f7beb5286a83dbf3c5 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 10 May 2024 14:58:46 +0200
Subject: [PATCH] [mlir][intrange] Fix `arith.shl` inference in case of
 overflow

When an overflow happens during shift left, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the range of results does not work anymore.

This patch checks for possible overflow and returns the max range in that
case.

Fix https://github.com/llvm/llvm-project/issues/82158
---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 12 ++++++++
 mlir/test/Dialect/Arith/int-range-opts.mlir   | 29 +++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..12fae495b761e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
                         const APInt &r) -> std::optional<APInt> {
     return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
   };
+
+  // The minMax inference does not work when there is danger of overflow. In the
+  // signed case, this leads to the obvious problem that the sign bit might
+  // change. In the unsigned case, it also leads to problems because the largest
+  // LHS shifted by the largest RHS does not necessarily result in the largest
+  // result anymore.
+  bool signbitSafe =
+      (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
+      (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
+  if (!signbitSafe)
+    return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+
   ConstantIntRanges urange =
       minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
                /*isSigned=*/false);
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+  %i8val = arith.index_cast %0 : index to i8
+  %shifted = arith.shli %i8val, %cst1 : i8
+  %si = arith.index_cast %shifted : i8 to index
+  %1 = test.reflect_bounds %si
+  return %1: index
+}
+