[Mlir-commits] [mlir] [mlir][intrange] Use `nsw`, `nuw` flags in inference (PR #92642)

Felix Schneider llvmlistbot at llvm.org
Sun May 19 02:33:02 PDT 2024


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/92642

>From 5cebb095abb336a34fffbb430494599815eeea1b Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 18 May 2024 11:36:17 +0200
Subject: [PATCH 1/4] [mlir][intrange] Represent bounds of `ReflectBoundsOp` as
 `si`/`ui`

This patch adapts the `test.reflect_bounds` test Op to represent the signed
bounds of `IntegerType`s as explicitly signed (`si`) attributes and the
unsigned bounds as explicitly unsigned (`ui`) attributes.

This is mostly a cosmetic change as the internal representation of the
ranges is unchanged. However, it improves readability of tests.

Example:
```mlir
// old:
test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -56 : i8, umin = 100 : i8}
// new:
test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
```
---
 .../Dialect/Arith/int-range-interface.mlir    |  2 +-
 mlir/test/Dialect/Arith/int-range-opts.mlir   |  4 ++--
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 19 ++++++++++++++-----
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 16524b3634723..17d3fcfc13ce6 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -758,7 +758,7 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
 }
 
 // CHECK-LABEL: func @test_i8_bounds
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
 func.func @test_i8_bounds() -> i8 {
   %cst1 = arith.constant 1 : i8
   %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 6179003ab4e74..71174f1c5ef0c 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -75,7 +75,7 @@ func.func @test() -> i1 {
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 24 : i8, smin = 0 : i8, umax = 24 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 24 : si8, smin = 0 : si8, umax = 24 : ui8, umin = 0 : ui8}
 func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
   %i8val = test.with_bounds { umin = 0 : i8, umax = 12 : i8, smin = 0 : i8, smax = 12 : i8 } : i8
@@ -87,7 +87,7 @@ func.func @test() -> i8 {
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
 func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
   %i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index bfee0391f6708..b058a8e1abbcb 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -706,11 +706,20 @@ void TestReflectBoundsOp::inferResultRanges(
   const ConstantIntRanges &range = argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
-  auto intTy = getType();
-  setUminAttr(b.getIntegerAttr(intTy, range.umin()));
-  setUmaxAttr(b.getIntegerAttr(intTy, range.umax()));
-  setSminAttr(b.getIntegerAttr(intTy, range.smin()));
-  setSmaxAttr(b.getIntegerAttr(intTy, range.smax()));
+  Type sIntTy, uIntTy;
+  // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
+  // Types for the Attributes.
+  if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
+    unsigned bitwidth = intTy.getWidth();
+    sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
+    uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
+  } else
+    sIntTy = uIntTy = getType();
+
+  setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
+  setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
+  setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
+  setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
   setResultRanges(getResult(), range);
 }
 

>From e501da0d7e17995ec0037aeb1212889d56269d5d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 18 May 2024 10:50:04 +0200
Subject: [PATCH 2/4] [mlir][intrange] Use `nsw`,`nuw` flags in inference

This patch makes the integer range inference take the "no signed wrap"
(`nsw`) and "no unsigned wrap" (`nuw`) flags into account. These flags
can be used to annotate some Ops in the `arith` dialect and also in
LLVMIR.

The general approach is to use saturating arithmetic operations to infer
bounds which are assumed to not wrap and use overflowing arithmetic
operations in the normal case. If overflow is detected in the normal case,
special handling makes sure that we don't underestimate the result range.
---
 .../Interfaces/Utils/InferIntRangeCommon.h    |  25 +++-
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  |  22 ++-
 .../Index/IR/InferIntRangeInterfaceImpls.cpp  |  22 ++-
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  98 +++++++------
 .../Dialect/Arith/int-range-interface.mlir    | 133 ++++++++++++++++++
 mlir/test/Dialect/Arith/int-range-opts.mlir   |   2 +-
 6 files changed, 249 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
+enum class OverflowFlags : uint32_t {
+  None = 0,
+  Nsw = 1,
+  Nuw = 2,
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+
+/// Function that performs inference on an array of `ConstantIntRanges` while
+/// taking special overflow behavior into account.
+using InferRangeWithOvfFlagsFn =
+    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
 ConstantIntRanges truncRange(const ConstantIntRanges &range,
                              unsigned destWidth);
 
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
 
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
 
 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
 
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..6024f37359ad5 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
+intrange::OverflowFlags
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+  intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+    retFlags |= intrange::OverflowFlags::Nsw;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+    retFlags |= intrange::OverflowFlags::Nuw;
+  return retFlags;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..ffa74f0c0faae 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // we take the 64-bit result).
 //===----------------------------------------------------------------------===//
 
+// Some arithmetic inference functions allow specifying special overflow / wrap
+// behavior. We do not require this for the IndexOps and use this helper to call
+// the inference function without any `OverflowFlags`.
+std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+    return inferWithOvfFn(argRanges, OverflowFlags::None);
+  };
+}
+
 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..a1422db70fa37 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,23 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn uadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.uadd_sat(b)
+                       : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn sadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.sadd_sat(b)
+                       : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -205,19 +210,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn usub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.usub_sat(b)
+                       : a.usub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn ssub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.ssub_sat(b)
+                       : a.ssub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
   ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +242,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn umul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.umul_sat(b)
+                       : a.umul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn smul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.smul_sat(b)
+                       : a.smul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -542,32 +557,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
-              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
-              &rhsUMax = rhs.umax();
+  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
 
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
+  // 2^rhs.
+  ConstArithFn ushl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? l.ushl_sat(r)
+                       : l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn sshl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? l.sshl_sat(r)
+                       : l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
   };
-
-  // The minMax inference does not work when there is danger of overflow. In the
-  // signed case, this leads to the obvious problem that the sign bit might
-  // change. In the unsigned case, it also leads to problems because the largest
-  // LHS shifted by the largest RHS does not necessarily result in the largest
-  // result anymore.
-  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
-  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
-      rhsUMax.uge(lhsSMax.getNumSignBits()))
-    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 17d3fcfc13ce6..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // nsw flag stops smax from overflowing
+  %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap1() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap2() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // nuw flag stops umin from underflowing
+  %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.muli %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, umax = 127 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // nsw stops overflow
+  %1 = arith.muli %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap1() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.shli %0, %cst3 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap2() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.shli %0, %cst3 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 80 : si8, umax = 127 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_nowrap() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : ui8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // nsw stops smax overflow
+  %1 = arith.shli %0, %cst3 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 71174f1c5ef0c..dd62a481a1246 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -87,7 +87,7 @@ func.func @test() -> i8 {
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 254 : ui8, umin = 0 : ui8}
 func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
   %i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8

>From fbf0ab2186359a63e8274d86155d85302f97adaf Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 18 May 2024 18:27:01 +0200
Subject: [PATCH 3/4] address review comments

---
 mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp | 2 +-
 mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 6024f37359ad5..fbe2ecab8adca 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,7 +19,7 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
-intrange::OverflowFlags
+static intrange::OverflowFlags
 convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
   intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
   if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index ffa74f0c0faae..64adb6b850524 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -47,7 +47,7 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // Some arithmetic inference functions allow specifying special overflow / wrap
 // behavior. We do not require this for the IndexOps and use this helper to call
 // the inference function without any `OverflowFlags`.
-std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
   return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
     return inferWithOvfFn(argRanges, OverflowFlags::None);

>From 5f2b4ea0b1f16182453b6035a7c332a78f7671bf Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 19 May 2024 11:32:15 +0200
Subject: [PATCH 4/4] fix clang build

---
 mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index a1422db70fa37..4e42d8e5d7937 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -181,18 +181,21 @@ ConstantIntRanges
 mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
                          OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [ovfFlags](const APInt &a,
+
+  bool saturateUnsigned = any(ovfFlags & OverflowFlags::Nuw);
+  bool saturateSigned = any(ovfFlags & OverflowFlags::Nsw);
+  ConstArithFn uadd = [saturateUnsigned](const APInt &a,
                                  const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+    APInt result = saturateUnsigned
                        ? a.uadd_sat(b)
                        : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [ovfFlags](const APInt &a,
+  ConstArithFn sadd = [saturateSigned](const APInt &a,
                                  const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+    APInt result = saturateSigned
                        ? a.sadd_sat(b)
                        : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;



More information about the Mlir-commits mailing list