[Mlir-commits] [mlir] [mlir][intrange] Fix `arith.shl` inference in case of overflow (PR #91737)
llvmlistbot at llvm.org
Fri May 10 06:12:30 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-arith
Author: Felix Schneider (ubfx)
<details>
<summary>Changes</summary>
When a left shift overflows, i.e. the sign bit or the most significant data bit gets shifted out, the current approach of inferring the result range no longer works.
This patch checks for possible overflow and returns the max range in that case.
This bug shows up most clearly when we perform `arith.shli` on a full-range integer, for example an i8 between 0 and 0xff:
```mlir
// mlir-opt --int-range-optimizations
func.func @test() -> index {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : index, umax = 255 : index, smin = -128 : index, smax = 127 : index }
%i8val = arith.index_cast %0 : index to i8
%shifted = arith.shli %i8val, %cst1 : i8
%si = arith.index_cast %shifted : i8 to index
%1 = test.reflect_bounds %si
// gets optimized to
// %4 = test.reflect_bounds {smax = 0 : index, smin = -2 : index, umax = -1 : index, umin = 0 : index} %3
return %1: index
}
```
The signed result range gets vastly underestimated as [-2, 0], which leads to wrong optimizations.
Fix https://github.com/llvm/llvm-project/issues/82158
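The underestimate can be reproduced outside of MLIR. The following is a minimal standalone sketch, not the code from `InferIntRangeCommon.cpp`: it models i8 with `int8_t`, fixes the shift amount to a constant, and compares an endpoint-only estimate against a brute-force scan. The names `shlI8`, `cornerRange` and `exhaustiveRange` are hypothetical and exist only for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// {min, max} of a signed result range. Hypothetical helper type.
using Range = std::pair<int, int>;

// Shift an 8-bit value left with wraparound, as a fixed-width i8 shift would.
inline int8_t shlI8(int8_t value, unsigned amount) {
  unsigned wide = static_cast<uint8_t>(value);
  return static_cast<int8_t>(static_cast<uint8_t>(wide << amount));
}

// Endpoint-only estimate: shift just the smallest and largest input and take
// the min/max of the two results. This is only valid if nothing overflows.
inline Range cornerRange(int8_t lo, int8_t hi, unsigned amount) {
  int a = shlI8(lo, amount);
  int b = shlI8(hi, amount);
  return {std::min(a, b), std::max(a, b)};
}

// Ground truth: shift every value in [lo, hi] and track the real min/max.
inline Range exhaustiveRange(int8_t lo, int8_t hi, unsigned amount) {
  int mn = 127, mx = -128;
  for (int v = lo; v <= hi; ++v) {
    int r = shlI8(static_cast<int8_t>(v), amount);
    mn = std::min(mn, r);
    mx = std::max(mx, r);
  }
  return {mn, mx};
}
```

For the full i8 range shifted by 1, `cornerRange(-128, 127, 1)` yields {-2, 0}, the same bounds as in the reflected output above, while `exhaustiveRange(-128, 127, 1)` yields {-128, 126}. For a small input range such as [0, 12] the two agree on {0, 24}.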
---
Full diff: https://github.com/llvm/llvm-project/pull/91737.diff
2 Files Affected:
- (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+12)
- (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+29)
``````````diff
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..12fae495b761e 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -548,6 +548,18 @@ mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
const APInt &r) -> std::optional<APInt> {
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
};
+
+ // The minMax inference does not work when there is danger of overflow. In the
+ // signed case, this leads to the obvious problem that the sign bit might
+ // change. In the unsigned case, it also leads to problems because the largest
+ // LHS shifted by the largest RHS does not necessarily result in the largest
+ // result anymore.
+ bool signbitSafe =
+ (lhs.smin().getNumSignBits() > rhs.umax().getZExtValue()) &&
+ (lhs.smax().getNumSignBits() > rhs.umax().getZExtValue());
+ if (!signbitSafe)
+ return ConstantIntRanges::maxRange(lhs.umax().getBitWidth());
+
ConstantIntRanges urange =
minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
/*isSigned=*/false);
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
%1 = arith.cmpi sle, %0, %cst1 : index
return %1: i1
}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
``````````
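To see why the `signbitSafe` guard in the diff separates the two new test cases, here is a rough standalone model of the condition for 8-bit values. It is a sketch under assumptions: `numSignBitsI8` re-implements by hand the quantity that `APInt::getNumSignBits` reports (the number of leading bits equal to the sign bit), and `shiftIsSignbitSafe` is a hypothetical name, not a function from the patch.

```cpp
#include <cassert>
#include <cstdint>

// Count the leading bits of an 8-bit value that equal its sign bit,
// including the sign bit itself.
inline unsigned numSignBitsI8(int8_t value) {
  unsigned bits = static_cast<uint8_t>(value);
  unsigned sign = (bits >> 7) & 1u;
  unsigned count = 0;
  for (int i = 7; i >= 0 && ((bits >> i) & 1u) == sign; --i)
    ++count;
  return count;
}

// Hypothetical mirror of the patch's condition: the shift is considered safe
// only if both signed bounds have more redundant sign bits than the largest
// possible shift amount, so no significant bit can reach or leave the sign
// position.
inline bool shiftIsSignbitSafe(int8_t smin, int8_t smax, unsigned maxShift) {
  return numSignBitsI8(smin) > maxShift && numSignBitsI8(smax) > maxShift;
}
```

With a shift amount of 1, the range [0, 12] passes the check (12 has four leading sign bits), so the precise bounds [0, 24] are kept. The range [0, 127] fails it (127 has only one), which corresponds to the second test expecting the maximal range.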
</details>
https://github.com/llvm/llvm-project/pull/91737