[Mlir-commits] [mlir] [mlir][intrange] Use `nsw`, `nuw` flags in inference (PR #92642)
Felix Schneider
llvmlistbot at llvm.org
Sat May 18 02:43:27 PDT 2024
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/92642
This patch makes integer range inference honor the "no signed wrap" (`nsw`)
and "no unsigned wrap" (`nuw`) flags, which can annotate some Ops in the
`arith` dialect and in LLVM IR.
The general approach is to use saturating arithmetic to infer the bounds of
operations that are known not to wrap, and overflow-checking arithmetic in
the normal case. If overflow is detected in the normal case, special handling
ensures that the result range is never underestimated.
Stacked on https://github.com/llvm/llvm-project/pull/92641
>From 5cebb095abb336a34fffbb430494599815eeea1b Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 18 May 2024 11:36:17 +0200
Subject: [PATCH 1/2] [mlir][intrange] Represent bounds of `ReflectBoundsOp` as
`si`/`ui`
This patch adapts the `test.reflect_bounds` test Op to use explicitly
signed and unsigned representation for signed and unsigned bounds of
`IntegerType`s.
This is mostly a cosmetic change as the internal representation of the
ranges is unchanged. However, it improves readability of tests.
Example:
```mlir
// old:
test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -56 : i8, umin = 100 : i8}
// new:
test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
```
---
.../Dialect/Arith/int-range-interface.mlir | 2 +-
mlir/test/Dialect/Arith/int-range-opts.mlir | 4 ++--
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 19 ++++++++++++++-----
3 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 16524b3634723..17d3fcfc13ce6 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -758,7 +758,7 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
}
// CHECK-LABEL: func @test_i8_bounds
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
func.func @test_i8_bounds() -> i8 {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 6179003ab4e74..71174f1c5ef0c 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -75,7 +75,7 @@ func.func @test() -> i1 {
// -----
// CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 24 : i8, smin = 0 : i8, umax = 24 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 24 : si8, smin = 0 : si8, umax = 24 : ui8, umin = 0 : ui8}
func.func @test() -> i8 {
%cst1 = arith.constant 1 : i8
%i8val = test.with_bounds { umin = 0 : i8, umax = 12 : i8, smin = 0 : i8, smax = 12 : i8 } : i8
@@ -87,7 +87,7 @@ func.func @test() -> i8 {
// -----
// CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
func.func @test() -> i8 {
%cst1 = arith.constant 1 : i8
%i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index bfee0391f6708..b058a8e1abbcb 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -706,11 +706,20 @@ void TestReflectBoundsOp::inferResultRanges(
const ConstantIntRanges &range = argRanges[0];
MLIRContext *ctx = getContext();
Builder b(ctx);
- auto intTy = getType();
- setUminAttr(b.getIntegerAttr(intTy, range.umin()));
- setUmaxAttr(b.getIntegerAttr(intTy, range.umax()));
- setSminAttr(b.getIntegerAttr(intTy, range.smin()));
- setSmaxAttr(b.getIntegerAttr(intTy, range.smax()));
+ Type sIntTy, uIntTy;
+ // Attach plain `IntegerType` bounds with explicitly signed (si) / unsigned
+ // (ui) types; other types (e.g. `index`) are used as-is for all bounds.
+ if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
+ unsigned bitwidth = intTy.getWidth();
+ sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
+ uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
+ } else
+ sIntTy = uIntTy = getType();
+
+ setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
+ setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
+ setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
+ setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
setResultRanges(getResult(), range);
}
>From e501da0d7e17995ec0037aeb1212889d56269d5d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 18 May 2024 10:50:04 +0200
Subject: [PATCH 2/2] [mlir][intrange] Use `nsw`,`nuw` flags in inference
This patch makes integer range inference honor the "no signed wrap" (`nsw`)
and "no unsigned wrap" (`nuw`) flags, which can annotate some Ops in the
`arith` dialect and in LLVM IR.
The general approach is to use saturating arithmetic to infer the bounds of
operations that are known not to wrap, and overflow-checking arithmetic in
the normal case. If overflow is detected in the normal case, special handling
ensures that the result range is never underestimated.
---
.../Interfaces/Utils/InferIntRangeCommon.h | 25 +++-
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 22 ++-
.../Index/IR/InferIntRangeInterfaceImpls.cpp | 22 ++-
.../Interfaces/Utils/InferIntRangeCommon.cpp | 98 +++++++------
.../Dialect/Arith/int-range-interface.mlir | 133 ++++++++++++++++++
mlir/test/Dialect/Arith/int-range-opts.mlir | 2 +-
6 files changed, 249 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include <optional>
namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
+enum class OverflowFlags : uint32_t {
+ None = 0,
+ Nsw = 1,
+ Nuw = 2,
+ LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+
+/// Function that performs inference on an array of `ConstantIntRanges` while
+/// honoring the no-wrap guarantees (`nsw`/`nuw`) given by its `OverflowFlags`.
+using InferRangeWithOvfFlagsFn =
+ function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
/// Compute `inferFn` on `ranges`, whose size should be the index storage
/// bitwidth. Then, compute the function on `argRanges` again after truncating
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
ConstantIntRanges truncRange(const ConstantIntRanges &range,
unsigned destWidth);
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags = OverflowFlags::None);
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags = OverflowFlags::None);
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags = OverflowFlags::None);
ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags = OverflowFlags::None);
ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..6024f37359ad5 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::intrange;
+static intrange::OverflowFlags
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+ intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+ if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+ retFlags |= intrange::OverflowFlags::Nsw;
+ if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+ retFlags |= intrange::OverflowFlags::Nuw;
+ return retFlags;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferAdd(argRanges));
+ setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+ getOverflowFlags())));
}
//===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferSub(argRanges));
+ setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+ getOverflowFlags())));
}
//===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMul(argRanges));
+ setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+ getOverflowFlags())));
}
//===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferShl(argRanges));
+ setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+ getOverflowFlags())));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..ffa74f0c0faae 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// we take the 64-bit result).
//===----------------------------------------------------------------------===//
+// Adapt an inference function taking `OverflowFlags` to the flag-less form
+// `inferIndexOp` expects; IndexOps do not need the flags. Only pass named
+// functions: the non-owning `function_ref` is captured by the result.
+static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+ return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+ return inferWithOvfFn(argRanges, OverflowFlags::None);
+ };
+}
+
void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+ setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+ argRanges, CmpMode::Both));
}
void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+ setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+ argRanges, CmpMode::Both));
}
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+ setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+ argRanges, CmpMode::Both));
}
void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+ setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+ argRanges, CmpMode::Both));
}
void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..a1422db70fa37 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,23 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn uadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto uadd = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.uadd_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nuw)
+ ? a.uadd_sat(b)
+ : a.uadd_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
- ConstArithFn sadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto sadd = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.sadd_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nsw)
+ ? a.sadd_sat(b)
+ : a.sadd_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
@@ -205,19 +210,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn usub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto usub = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.usub_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nuw)
+ ? a.usub_sat(b)
+ : a.usub_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
- ConstArithFn ssub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto ssub = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.ssub_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nsw)
+ ? a.ssub_sat(b)
+ : a.ssub_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +242,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn umul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto umul = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.umul_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nuw)
+ ? a.umul_sat(b)
+ : a.umul_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
- ConstArithFn smul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
+ auto smul = [ovfFlags](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
bool overflowed = false;
- APInt result = a.smul_ov(b, overflowed);
+ APInt result = any(ovfFlags & OverflowFlags::Nsw)
+ ? a.smul_sat(b)
+ : a.smul_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : result;
};
@@ -542,32 +557,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags ovfFlags) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
- &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
- &rhsUMax = rhs.umax();
+ const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
- ConstArithFn shl = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+ // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
+ // 2^rhs.
+ auto ushl = [ovfFlags](const APInt &l,
+ const APInt &r) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = any(ovfFlags & OverflowFlags::Nuw)
+ ? l.ushl_sat(r)
+ : l.ushl_ov(r, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+ auto sshl = [ovfFlags](const APInt &l,
+ const APInt &r) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = any(ovfFlags & OverflowFlags::Nsw)
+ ? l.sshl_sat(r)
+ : l.sshl_ov(r, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
};
-
- // The minMax inference does not work when there is danger of overflow. In the
- // signed case, this leads to the obvious problem that the sign bit might
- // change. In the unsigned case, it also leads to problems because the largest
- // LHS shifted by the largest RHS does not necessarily result in the largest
- // result anymore.
- assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
- if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
- rhsUMax.uge(lhsSMax.getNumSignBits()))
- return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
ConstantIntRanges urange =
- minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+ minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
/*isSigned=*/false);
ConstantIntRanges srange =
- minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+ minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
/*isSigned=*/true);
return urange.intersection(srange);
}
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 17d3fcfc13ce6..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
%2 = test.reflect_bounds %1 : i8
return %2: i8
}
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+ %1 = arith.addi %0, %cst1 : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+ // smax overflow
+ %1 = arith.addi %0, %cst1 : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+ // smax overflow
+ %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+ %cst1 = arith.constant 1 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+ // nsw flag stops smax from overflowing
+ %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_wrap1() -> i8 {
+ %cst10 = arith.constant 10 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+ // umin underflows
+ %1 = arith.subi %0, %cst10 : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_wrap2() -> i8 {
+ %cst10 = arith.constant 10 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+ // umin underflows
+ %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+ %cst10 = arith.constant 10 : i8
+ %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+ // nuw flag stops umin from underflowing
+ %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+ %cst10 = arith.constant 10 : i8
+ %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+ // smax overflows
+ %1 = arith.muli %0, %cst10 : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, umax = 127 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_nowrap() -> i8 {
+ %cst10 = arith.constant 10 : i8
+ %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+ // nsw stops overflow
+ %1 = arith.muli %0, %cst10 overflow<nsw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap1() -> i8 {
+ %cst3 = arith.constant 3 : i8
+ %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+ // smax overflows
+ %1 = arith.shli %0, %cst3 : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap2() -> i8 {
+ %cst3 = arith.constant 3 : i8
+ %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+ // smax overflows
+ %1 = arith.shli %0, %cst3 overflow<nuw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 80 : si8, umax = 127 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_nowrap() -> i8 {
+ %cst3 = arith.constant 3 : i8
+ %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+ // nsw stops smax overflow
+ %1 = arith.shli %0, %cst3 overflow<nsw> : i8
+ %2 = test.reflect_bounds %1 : i8
+ return %2: i8
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 71174f1c5ef0c..dd62a481a1246 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -87,7 +87,7 @@ func.func @test() -> i8 {
// -----
// CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 254 : ui8, umin = 0 : ui8}
func.func @test() -> i8 {
%cst1 = arith.constant 1 : i8
%i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
More information about the Mlir-commits
mailing list