[Mlir-commits] [mlir] 2034f2f - [mlir][intrange] Use `nsw`, `nuw` flags in inference (#92642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 22 00:02:35 PDT 2024


Author: Felix Schneider
Date: 2024-05-22T09:02:31+02:00
New Revision: 2034f2fc8729bd4645ef7caa3c5c6efa284d2d3f

URL: https://github.com/llvm/llvm-project/commit/2034f2fc8729bd4645ef7caa3c5c6efa284d2d3f
DIFF: https://github.com/llvm/llvm-project/commit/2034f2fc8729bd4645ef7caa3c5c6efa284d2d3f.diff

LOG: [mlir][intrange] Use `nsw`, `nuw` flags in inference (#92642)

This patch makes the integer range inference take the "no signed wrap"
(`nsw`) and "no unsigned wrap" (`nuw`) flags into account. These flags
can annotate some ops in the `arith` dialect as well as in LLVM IR.

The general approach is to use saturating arithmetic to infer bounds for
operations that are flagged as non-wrapping, and overflow-checking
arithmetic in the normal (unflagged) case. If overflow is detected in
the normal case, special handling makes sure that we don't
underestimate the result range.

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
    mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
    mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
    mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
    mlir/test/Dialect/Arith/int-range-interface.mlir
    mlir/test/Dialect/Arith/int-range-opts.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 97c97c23ba82a..851bb534bc7ee 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -31,6 +32,18 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
+enum class OverflowFlags : uint32_t {
+  None = 0, ///< No flag set.
+  Nsw = 1,  ///< "No signed wrap" (`nsw`).
+  Nuw = 2,  ///< "No unsigned wrap" (`nuw`).
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw) // Nuw is the largest enumerator.
+};
+
+/// Callback type for inference functions that take operand ranges and the
+/// op's `OverflowFlags` (nsw/nuw), and return the inferred result range.
+using InferRangeWithOvfFlagsFn =
+    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
@@ -60,11 +73,14 @@ ConstantIntRanges extSIRange(const ConstantIntRanges &range,
 ConstantIntRanges truncRange(const ConstantIntRanges &range,
                              unsigned destWidth);
 
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
 
@@ -94,7 +110,8 @@ ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
 
 ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
 
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                           OverflowFlags ovfFlags = OverflowFlags::None);
 
 ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
 

diff  --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 71eb36bb07a6e..fbe2ecab8adca 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,16 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::intrange;
 
+static intrange::OverflowFlags // Maps arith nsw/nuw onto intrange flags.
+convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
+  intrange::OverflowFlags retFlags = intrange::OverflowFlags::None;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
+    retFlags |= intrange::OverflowFlags::Nsw;
+  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
+    retFlags |= intrange::OverflowFlags::Nuw;
+  return retFlags; // Only nsw and nuw are translated.
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -38,7 +48,8 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +58,8 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,7 +68,8 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,7 +315,8 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index b6b8a136791c7..64adb6b850524 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -44,19 +44,32 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // we take the 64-bit result).
 //===----------------------------------------------------------------------===//
 
+// inferAdd/inferSub/inferMul/inferShl take `OverflowFlags`, which the index
+// ops do not use. This helper wraps such a function so it can be passed to
+// `inferIndexOp`, always calling it with `OverflowFlags::None`.
+static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
+inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+    return inferWithOvfFn(argRanges, OverflowFlags::None);
+  }; // Copies a non-owning function_ref: don't outlive the wrapped callee.
+}
+
 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
@@ -127,7 +140,8 @@ void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,

diff  --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 6af229cae10ab..fe1a67d628738 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -178,18 +178,24 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+
+  std::function uadd = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.uadd_sat(b)
+                       : a.uadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  std::function sadd = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.sadd_sat(b)
+                       : a.sadd_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -205,19 +211,24 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  std::function usub = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.usub_sat(b)
+                       : a.usub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  std::function ssub = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.ssub_sat(b)
+                       : a.ssub_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
   ConstantIntRanges urange = computeBoundsBy(
@@ -232,19 +243,24 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  std::function umul = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? a.umul_sat(b)
+                       : a.umul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
+  std::function smul = [=](const APInt &a,
+                           const APInt &b) -> std::optional<APInt> {
     bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? a.smul_sat(b)
+                       : a.smul_ov(b, overflowed);
     return overflowed ? std::optional<APInt>() : result;
   };
 
@@ -542,32 +558,35 @@ mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges,
+                         OverflowFlags ovfFlags) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsSMin = lhs.smin(), &lhsSMax = lhs.smax(),
-              &lhsUMax = lhs.umax(), &rhsUMin = rhs.umin(),
-              &rhsUMax = rhs.umax();
+  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
 
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  // shl by `rhs` is treated like mul by 2^rhs: with nuw/nsw the bounds come
+  // from saturating shifts, otherwise an overflowing shift yields std::nullopt.
+  std::function ushl = [=](const APInt &l,
+                           const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false; // Only the *_ov path below can set this.
+    APInt result = any(ovfFlags & OverflowFlags::Nuw)
+                       ? l.ushl_sat(r)
+                       : l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  std::function sshl = [=](const APInt &l,
+                           const APInt &r) -> std::optional<APInt> {
+    bool overflowed = false; // Only the *_ov path below can set this.
+    APInt result = any(ovfFlags & OverflowFlags::Nsw)
+                       ? l.sshl_sat(r)
+                       : l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
   };
-
-  // The minMax inference does not work when there is danger of overflow. In the
-  // signed case, this leads to the obvious problem that the sign bit might
-  // change. In the unsigned case, it also leads to problems because the largest
-  // LHS shifted by the largest RHS does not necessarily result in the largest
-  // result anymore.
-  assert(rhsUMax.isNonNegative() && "Unexpected negative shift count");
-  if (rhsUMax.uge(lhsSMin.getNumSignBits()) ||
-      rhsUMax.uge(lhsSMax.getNumSignBits()))
-    return ConstantIntRanges::maxRange(lhsUMax.getBitWidth());
 
   ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhsUMax}, {rhsUMin, rhsUMax},
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/false);
   ConstantIntRanges srange =
-      minMaxBy(shl, {lhsSMin, lhsSMax}, {rhsUMin, rhsUMax},
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
                /*isSigned=*/true);
   return urange.intersection(srange);
 }

diff  --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 17d3fcfc13ce6..5b538197a0c11 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -766,3 +766,136 @@ func.func @test_i8_bounds() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+// CHECK-LABEL: func @test_add_1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+func.func @test_add_1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// Tests below check inference with overflow flags.
+
+// CHECK-LABEL: func @test_add_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap1() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_wrap2() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // smax overflow
+  %1 = arith.addi %0, %cst1 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_add_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+func.func @test_add_i8_nowrap() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
+  // nsw flag stops smax from overflowing
+  %1 = arith.addi %0, %cst1 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap1() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+func.func @test_sub_i8_wrap2() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // umin underflows
+  %1 = arith.subi %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_sub_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+func.func @test_sub_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
+  // nuw flag stops umin from underflowing
+  %1 = arith.subi %0, %cst10 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_wrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_wrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.muli %0, %cst10 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_mul_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, umax = 127 : ui8, umin = 100 : ui8}
+func.func @test_mul_i8_nowrap() -> i8 {
+  %cst10 = arith.constant 10 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // nsw stops overflow
+  %1 = arith.muli %0, %cst10 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap1
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap1() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.shli %0, %cst3 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_wrap2
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 160 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_wrap2() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // smax overflows
+  %1 = arith.shli %0, %cst3 overflow<nuw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
+
+// CHECK-LABEL: func @test_shl_i8_nowrap
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 80 : si8, umax = 127 : ui8, umin = 80 : ui8}
+func.func @test_shl_i8_nowrap() -> i8 {
+  %cst3 = arith.constant 3 : i8
+  %0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
+  // nsw stops smax overflow
+  %1 = arith.shli %0, %cst3 overflow<nsw> : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}

diff  --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 71174f1c5ef0c..dd62a481a1246 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -87,7 +87,7 @@ func.func @test() -> i8 {
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 254 : ui8, umin = 0 : ui8}
 func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
   %i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8


        


More information about the Mlir-commits mailing list