[Mlir-commits] [mlir] [mlir][intrange] Use `nsw`, `nuw` flags in inference (PR #92642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 18 02:43:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

This patch makes the integer range inference take into account the
"no signed wrap" (`nsw`) and "no unsigned wrap" (`nuw`) flags, which can
be used to annotate some ops in the `arith` dialect and also in LLVM IR.

The general approach is to use saturating arithmetic operations to infer
bounds for operations that are assumed not to wrap, and overflow-checking
arithmetic operations in the normal case. If overflow is detected in the
normal case, special handling makes sure that the result range is not
underestimated.

Stacked on https://github.com/llvm/llvm-project/pull/92641

---

Patch is 24.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92642.diff


7 Files Affected:

- (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+21-4) 
- (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+18-4) 
- (modified) mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp (+18-4) 
- (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+58-40) 
- (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+134-1) 
- (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+2-2) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+14-5) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
+enum class OverflowFlags : uint32_t {
+  None = 0,
+  Nsw = 1,
+  Nuw = 2,
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+
+/// Function that performs inference on an array of `ConstantIntRanges` while
+/// taking special overflow behavior into account.
+using InferRangeWithOvfFlagsFn =
+    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
 ConstantIntRanges truncRange(const ConstantIntRanges &range,
                              unsigned destWidth);
 
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
 
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
 
 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
 
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
 
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..6024f37359ad5 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
+intrange::OverflowFlags
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+  intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+    retFlags |= intrange::OverflowFlags::Nsw;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+    retFlags |= intrange::OverflowFlags::Nuw;
+  return retFlags;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..ffa74f0c0faae 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // we take the 64-bit result).
 //===----------------------------------------------------------------------===//
 
+// Some arithmetic inference functions allow specifying special overflow / wrap
+// behavior. We do not require this for the IndexOps and use this helper to call
+// the inference function without any `OverflowFlags`.
+std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+    return inferWithOvfFn(argRanges, OverflowFlags::None);
+  };
+}
+
 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..a1422db70fa37 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,23 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn uadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.uadd_sat(b)
+                       : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn sadd = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.sadd_sat(b)
+                       : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -205,19 +210,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn usub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.usub_sat(b)
+                       : a.usub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn ssub = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.ssub_sat(b)
+                       : a.ssub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
   ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +242,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn umul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.umul_sat(b)
+                       : a.umul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  ConstArithFn smul = [ovfFlags](const APInt &a,
+                                 const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.smul_sat(b)
+                       : a.smul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -542,32 +557,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
-              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
-              &rhsUMax = rhs.umax();
+  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
 
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
+  // 2^rhs.
+  ConstArithFn ushl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? l.ushl_sat(r)
+                       : l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn sshl = [ovfFlags](const APInt &l,
+                                 const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? l.sshl_sat(r)
+                       : l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
   };
-
-  // The minMax inference does not work when there is danger of overflow. In the
-  // signed case, this leads to the obvious problem that the sign bit might
-  // change. In the unsigned case, it also leads to problems because the largest
-  // LHS shifted by the largest RHS does not necessarily result in the largest
-  // result anymore.
-  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
-  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
-      rhsUMax.uge(lhsSMax.getNumSignBits()))
-    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 16524b3634723..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -758,7 +758,7 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
 }
 
 // CHECK-LABEL: func @test_i8_bounds
-// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
 func.func @test_i8_bounds() -> i8 {
   %cst1 = arith.constant 1 : i8
   %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // nsw flag stops smax from overflowing
+  %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap1() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap2() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // nuw flag stops umin from underflowing
+  %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.muli %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, uma...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/92642


More information about the Mlir-commits mailing list