[Mlir-commits] [mlir] 0f79066 - [mlir][intrange] Fix `arith.shl` inference in case of overflow (#91737)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 13 10:27:43 PDT 2024
Author: Felix Schneider
Date: 2024-05-13T19:27:38+02:00
New Revision: 0f7906645d18a38a6b80a1e8e1d425396f6ab353
URL: https://github.com/llvm/llvm-project/commit/0f7906645d18a38a6b80a1e8e1d425396f6ab353
DIFF: https://github.com/llvm/llvm-project/commit/0f7906645d18a38a6b80a1e8e1d425396f6ab353.diff
LOG: [mlir][intrange] Fix `arith.shl` inference in case of overflow (#91737)
When an overflow happens during a left shift, i.e. the last sign bit or
the most significant data bit gets shifted out, the current approach of
inferring the range of results no longer works.
This patch checks for possible overflow and returns the max range in
that case.
Fixes https://github.com/llvm/llvm-project/issues/82158
Added:
Modified:
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
mlir/test/Dialect/Arith/int-range-opts.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2b2d937d55d80..6af229cae10ab 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -544,15 +544,30 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
ConstantIntRanges
mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
+ &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
+ &rhsUMax = rhs.umax();
+
ConstArithFn shl = [](const APInt &l,
const APInt &r) -> std::optional<APInt> {
return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
};
+
+ // The minMax inference does not work when there is danger of overflow. In the
+ // signed case, this leads to the obvious problem that the sign bit might
+ // change. In the unsigned case, it also leads to problems because the largest
+ // LHS shifted by the largest RHS does not necessarily result in the largest
+ // result anymore.
+ assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
+ if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
+ rhsUMax.uge(lhsSMax.getNumSignBits()))
+ return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
+
ConstantIntRanges urange =
- minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+ minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
/*isSigned=*/false);
ConstantIntRanges srange =
- minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+ minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
/*isSigned=*/true);
return urange.intersection(srange);
}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..4c3c0854ed026 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -71,3 +71,32 @@ func.func @test() -> i1 {
%1 = arith.cmpi sle, %0, %cst1 : index
return %1: i1
}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
+func.func @test() -> index {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
+ %i8val = arith.index_cast %0 : index to i8
+ %shifted = arith.shli %i8val, %cst1 : i8
+ %si = arith.index_cast %shifted : i8 to index
+ %1 = test.reflect_bounds %si
+ return %1: index
+}
+
More information about the Mlir-commits
mailing list