[Mlir-commits] [mlir] 75bfc6f - [mlir][Arith] Implement InferIntRangeInterface for arithmetic ops

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Jun 14 11:30:39 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-06-14T18:30:34Z
New Revision: 75bfc6f29579b3787a93dff6c4125b614ddfc0b1

URL: https://github.com/llvm/llvm-project/commit/75bfc6f29579b3787a93dff6c4125b614ddfc0b1
DIFF: https://github.com/llvm/llvm-project/commit/75bfc6f29579b3787a93dff6c4125b614ddfc0b1.diff

LOG: [mlir][Arith] Implement InferIntRangeInterface for arithmetic ops

Depends on D124023

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D124022

Added: 
    mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp
    mlir/test/Dialect/Arithmetic/int-range-interface.mlir

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/Interfaces/InferIntRangeInterface.h
    mlir/lib/Analysis/IntRangeAnalysis.cpp
    mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
    mlir/lib/Interfaces/InferIntRangeInterface.cpp
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
index f74ace109a50a..e6dd5f64081f1 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index acdb3f200051b..75710d60c6d45 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -49,7 +50,8 @@ class Arith_TernaryOp<string mnemonic, list<Trait> traits = []> :
 
 // Base class for integer binary operations.
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
-    Arith_BinaryOp<mnemonic, traits>,
+    Arith_BinaryOp<mnemonic, traits #
+      [DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
     Results<(outs SignlessIntegerLike:$result)>;
 
@@ -70,7 +72,7 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
                    list<Trait> traits = []> :
     Arith_Op<mnemonic, traits # [SameOperandsAndResultShape,
-    DeclareOpInterfaceMethods<CastOpInterface>]>,
+      DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
@@ -87,7 +89,9 @@ def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
 // Cast from an integer type to another integer type.
 class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
     Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
-                           SignlessFixedWidthIntegerLike, traits>;
+                           SignlessFixedWidthIntegerLike,
+                           traits #
+                           [DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
 // Cast from an integer type to a floating point type.
 class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
     Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
@@ -124,7 +128,8 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
 def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
     [ConstantLike, NoSideEffect,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     AllTypesMatch<["value", "result"]>]> {
+     AllTypesMatch<["value", "result"]>,
+     DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
   let summary = "integer or floating point constant";
   let description = [{
     The `constant` operation produces an SSA value equal to some integer or
@@ -971,8 +976,9 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
         MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
     "signless-integer-like or memref of signless-integer">;
 
-def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint,
-                                                   IndexCastTypeConstraint> {
+def Arith_IndexCastOp
+  : Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
   let summary = "cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1024,7 +1030,9 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
+def Arith_CmpIOp
+  : Arith_CompareOpOfAnyRank<"cmpi",
+                             [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1165,7 +1173,8 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
 //===----------------------------------------------------------------------===//
 
 def SelectOp : Arith_Op<"select", [
-    AllTypesMatch<["true_value", "false_value", "result"]>
+    AllTypesMatch<["true_value", "false_value", "result"]>,
+    DeclareOpInterfaceMethods<InferIntRangeInterface>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
@@ -1205,7 +1214,7 @@ def SelectOp : Arith_Op<"select", [
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   // FIXME: Switch this to use the declarative assembly format.
   let hasCustomAssemblyFormat = 1;
 }

diff  --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 9a393855d05ff..131807dddf239 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -56,25 +56,40 @@ class ConstantIntRanges {
   /// non-integer types this is 0.
   static unsigned getStorageBitwidth(Type type);
 
-  /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
-  /// minimum and `max` is both the signed and unsigned maximum.
-  static ConstantIntRanges range(const APInt &min, const APInt &max);
-
-  /// Create an `IntRangeAttrs` with the signed minimum and maximum equal
+  /// Create a `ConstantIntRanges` with the maximum bounds for the width
+  /// `bitwidth`, that is - [0, uint_max(width)]/[sint_min(width),
+  /// sint_max(width)].
+  static ConstantIntRanges maxRange(unsigned bitwidth);
+
+  /// Create a `ConstantIntRanges` with a constant value - that is, with the
+  /// bounds [value, value] for both its signed and unsigned interpretations.
+  static ConstantIntRanges constant(const APInt &value);
+
+  /// Create a `ConstantIntRanges` whose minimum is `min` and maximum is `max`
+  /// with `isSigned` specifying if the min and max should be interpreted as
+  /// signed or unsigned.
+  static ConstantIntRanges range(const APInt &min, const APInt &max,
+                                 bool isSigned);
+
+  /// Create a `ConstantIntRanges` with the signed minimum and maximum equal
   /// to `smin` and `smax`, where the unsigned bounds are constructed from the
   /// signed ones if they correspond to a contigious range of bit patterns when
   /// viewed as unsigned values and are left at [0, int_max()] otherwise.
   static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
 
-  /// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal
+  /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal
   /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
   /// unless the sign bit changes between the minimum and maximum.
   static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
 
   /// Returns the union (computed separately for signed and unsigned bounds)
-  /// of `a` and `b`.
+  /// of this range and `other`.
   ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
 
+  /// Returns the intersection (computed separately for signed and unsigned
+  /// bounds) of this range and `other`.
+  ConstantIntRanges intersection(const ConstantIntRanges &other) const;
+
   /// If either the signed or unsigned interpretations of the range
   /// indicate that the value it bounds is a constant, return that constant
   /// value.

diff  --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
index fc01607c92ee3..83e1c3748caea 100644
--- a/mlir/lib/Analysis/IntRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -43,7 +43,7 @@ struct IntRangeLattice {
   /// value being marked overdefined is even an integer.
   static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
     APInt noIntValue = APInt::getZeroWidth();
-    return ConstantIntRanges::range(noIntValue, noIntValue);
+    return ConstantIntRanges(noIntValue, noIntValue, noIntValue, noIntValue);
   }
 
   /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
index dc34db3169d6c..e23504bee3302 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(MLIRArithmeticCanonicalizationIncGen)
 add_mlir_dialect_library(MLIRArithmeticDialect
   ArithmeticOps.cpp
   ArithmeticDialect.cpp
+  InferIntRangeInterfaceImpls.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
@@ -14,6 +15,7 @@ add_mlir_dialect_library(MLIRArithmeticDialect
 
   LINK_LIBS PUBLIC
   MLIRDialect
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
   )

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp
new file mode 100644
index 0000000000000..5e870c0357646
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp
@@ -0,0 +1,660 @@
+//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+/// Function that evaluates the result of doing something on arithmetic
+/// constants and returns None on overflow.
+using ConstArithFn =
+    function_ref<Optional<APInt>(const APInt &, const APInt &)>;
+
+/// Return the maximally wide signed or unsigned range for a given bitwidth.
+
+/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible.
+/// If either computation overflows, make the result unbounded.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
+                                         const APInt &minRight,
+                                         const APInt &maxLeft,
+                                         const APInt &maxRight, bool isSigned) {
+  Optional<APInt> maybeMin = op(minLeft, minRight);
+  Optional<APInt> maybeMax = op(maxLeft, maxRight);
+  if (maybeMin.hasValue() && maybeMax.hasValue())
+    return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
+  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
+}
+
+/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
+/// ignoring unbounded values. Returns the maximal range if `op` overflows.
+static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
+                                  ArrayRef<APInt> rhs, bool isSigned) {
+  unsigned width = lhs[0].getBitWidth();
+  APInt min =
+      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
+  APInt max =
+      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
+  for (const APInt &left : lhs) {
+    for (const APInt &right : rhs) {
+      Optional<APInt> maybeThisResult = op(left, right);
+      if (!maybeThisResult)
+        return ConstantIntRanges::maxRange(width);
+      APInt result = std::move(*maybeThisResult);
+      min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
+      max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
+    }
+  }
+  return ConstantIntRanges::range(min, max, isSigned);
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                          SetIntRangeFn setResultRange) {
+  auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
+  if (constAttr) {
+    const APInt &value = constAttr.getValue();
+    setResultRange(getResult(), ConstantIntRanges::constant(value));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.uadd_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.sadd_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange = computeBoundsBy(
+      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.usub_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.ssub_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  ConstantIntRanges urange = computeBoundsBy(
+      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.umul_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.smul_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange =
+      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+               /*isSigned=*/true);
+
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// DivUIOp
+//===----------------------------------------------------------------------===//
+
+/// Fix up division results (ex. for ceiling and floor), returning an APInt
+/// if there has been no overflow.
+using DivisionFixupFn = function_ref<Optional<APInt>(
+    const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
+              &rhsMax = rhs.umax();
+
+  if (!rhsMin.isZero()) {
+    auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      return fixup(a, b, a.udiv(b));
+    };
+    return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/false);
+  }
+  // Otherwise, it's possible we might divide by 0.
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferDivUIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
+}
+
+//===----------------------------------------------------------------------===//
+// DivSIOp
+//===----------------------------------------------------------------------===//
+
+static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
+
+  if (canDivide) {
+    auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      bool overflowed = false;
+      APInt result = a.sdiv_ov(b, overflowed);
+      return overflowed ? Optional<APInt>() : fixup(a, b, result);
+    };
+    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/true);
+  }
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferDivSIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
+}
+
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::CeilDivUIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) -> Optional<APInt> {
+    if (!lhs.urem(rhs).isZero()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// CeilDivSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::CeilDivSIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) -> Optional<APInt> {
+    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// FloorDivSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::FloorDivSIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
+                                     const APInt &result) -> Optional<APInt> {
+    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// RemUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+
+  unsigned width = rhsMin.getBitWidth();
+  APInt umin = APInt::getZero(width);
+  APInt umax = APInt::getMaxValue(width);
+
+  if (!rhsMin.isZero()) {
+    umax = rhsMax - 1;
+    // Special case: sweeping out a contiguous range in N/[modulus]
+    if (rhsMin == rhsMax) {
+      const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
+      if ((lhsMax - lhsMin).ult(rhsMax)) {
+        APInt minRem = lhsMin.urem(rhsMax);
+        APInt maxRem = lhsMax.urem(rhsMax);
+        if (minRem.ule(maxRem)) {
+          umin = minRem;
+          umax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+
+  unsigned width = rhsMax.getBitWidth();
+  APInt smin = APInt::getSignedMinValue(width);
+  APInt smax = APInt::getSignedMaxValue(width);
+  // No bounds if zero could be a divisor.
+  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
+  if (canBound) {
+    APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
+    bool canNegativeDividend = lhsMin.isNegative();
+    bool canPositiveDividend = lhsMax.isStrictlyPositive();
+    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
+    APInt maxPositiveResult = maxDivisor - 1;
+    APInt minNegativeResult = -maxPositiveResult;
+    smin = canNegativeDividend ? minNegativeResult : zero;
+    smax = canPositiveDividend ? maxPositiveResult : zero;
+    // Special case: sweeping out a contiguous range in N/[modulus].
+    if (rhsMin == rhsMax) {
+      if ((lhsMax - lhsMin).ult(maxDivisor)) {
+        APInt minRem = lhsMin.srem(maxDivisor);
+        APInt maxRem = lhsMax.srem(maxDivisor);
+        if (minRem.sle(maxRem)) {
+          smin = minRem;
+          smax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// AndIOp
+//===----------------------------------------------------------------------===//
+
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations.
+static std::tuple<APInt, APInt>
+widenBitwiseBounds(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+  leftVal.clearLowBits(differingBits);
+  rightVal.setLowBits(differingBits);
+  return {leftVal, rightVal};
+}
+
+void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    return a & b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// OrIOp
+//===----------------------------------------------------------------------===//
+
+void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                     SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    return a | b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// XOrIOp
+//===----------------------------------------------------------------------===//
+
+void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  auto xori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    return a ^ b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// ExtUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  // Zero extension preserves unsigned order: widen both unsigned bounds.
+  const ConstantIntRanges &src = argRanges[0];
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getResult().getType());
+  APInt lower = src.umin().zext(width), upper = src.umax().zext(width);
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(lower, upper));
+}
+
+//===----------------------------------------------------------------------===//
+// ExtSIOp
+//===----------------------------------------------------------------------===//
+
+static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  // Sign extension preserves signed order: widen both signed bounds.
+  unsigned width = ConstantIntRanges::getStorageBitwidth(destType);
+  return ConstantIntRanges::fromSigned(range.smin().sext(width),
+                                       range.smax().sext(width));
+}
+
+void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  // Sign-extend the operand's range to the width of the result type.
+  setResultRange(getResult(), extSIRange(argRanges[0], getResult().getType()));
+}
+
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
+                                     Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+  // Truncating the bounds is only sound when every value in the range keeps
+  // its relative order. If the discarded high bits of the unsigned bounds
+  // differ, the values wrap around after truncation (e.g. [255, 257] : i16
+  // truncates to {255, 0, 1} : i8), so fall back to the full unsigned range.
+  bool hasUnsignedRollover =
+      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
+                                   : range.umin().trunc(destWidth);
+  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
+                                   : range.umax().trunc(destWidth);
+
+  // For the signed bounds, look at the discarded bits together with the new
+  // sign bit. Order is preserved when those "high parts" agree, or when the
+  // range only spans values that already fit in the destination (high part
+  // all ones for the minimum and zero for the maximum). A zero-width
+  // destination has no sign bit to inspect, so it is truncated as before.
+  bool hasSignedOverflow = false;
+  if (destWidth != 0) {
+    APInt sminHighPart = range.smin().ashr(destWidth - 1);
+    APInt smaxHighPart = range.smax().ashr(destWidth - 1);
+    bool fitsInDest = sminHighPart.isAllOnes() && smaxHighPart.isZero();
+    hasSignedOverflow = sminHighPart != smaxHighPart && !fitsInDest;
+  }
+  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
+                                 : range.smin().trunc(destWidth);
+  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
+                                 : range.smax().trunc(destWidth);
+  return {umin, umax, smin, smax};
+}
+
+void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRange) {
+  // Narrow the operand's range to the width of the result type.
+  setResultRange(getResult(), truncIRange(argRanges[0], getResult().getType()));
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+void arith::IndexCastOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  // Widening sign-extends, narrowing truncates, equal widths pass through.
+  Type resultTy = getResult().getType();
+  unsigned fromBits =
+      ConstantIntRanges::getStorageBitwidth(getOperand().getType());
+  unsigned toBits = ConstantIntRanges::getStorageBitwidth(resultTy);
+  if (fromBits == toBits)
+    setResultRange(getResult(), argRanges[0]);
+  else if (fromBits < toBits)
+    setResultRange(getResult(), extSIRange(argRanges[0], resultTy));
+  else
+    setResultRange(getResult(), truncIRange(argRanges[0], resultTy));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+// Returns true only when `lhs pred rhs` holds for every pair of values drawn
+// from the two ranges; a false result means "could not prove it".
+bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
+                      const ConstantIntRanges &rhs) {
+  switch (pred) {
+  case arith::CmpIPredicate::sle:
+  case arith::CmpIPredicate::slt:
+    return applyCmpPredicate(pred, lhs.smax(), rhs.smin());
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult:
+    return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
+  case arith::CmpIPredicate::sge:
+  case arith::CmpIPredicate::sgt:
+    return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+    return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
+  case arith::CmpIPredicate::eq: {
+    Optional<APInt> l = lhs.getConstantValue(), r = rhs.getConstantValue();
+    return l && r && *l == *r;
+  }
+  case arith::CmpIPredicate::ne: {
+    // Equality only needs one interpretation (signed or unsigned) of the
+    // preceding computations to yield equal constants; proving inequality
+    // requires the ranges to be disjoint under both interpretations.
+    bool sDisjoint = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
+                     isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
+    bool uDisjoint = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
+                     isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
+    return sDisjoint && uDisjoint;
+  }
+  }
+  return false;
+}
+
+void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  arith::CmpIPredicate pred = getPredicate();
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // Start from the full i1 range [0, 1] and collapse it to a single value
+  // only when the comparison (or its inverse) is provable from the ranges.
+  APInt min = APInt::getZero(1);
+  APInt max = APInt::getAllOnes(1);
+  if (isStaticallyTrue(pred, lhs, rhs))
+    min = max;
+  else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+    max = min;
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &ifTrue = argRanges[1], &ifFalse = argRanges[2];
+
+  // A condition known to be constant picks exactly one of the two operands.
+  if (Optional<APInt> cond = argRanges[0].getConstantValue()) {
+    setResultRange(getResult(), cond->isZero() ? ifFalse : ifTrue);
+    return;
+  }
+
+  // Otherwise either operand may be produced, so the result covers both.
+  setResultRange(getResult(), ifTrue.rangeUnion(ifFalse));
+}
+
+//===----------------------------------------------------------------------===//
+// ShLIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // A left shift drops the bits it shifts out, so the extremes of the result
+  // are only found at the corners of the operand ranges while no corner
+  // overflows. Each callback therefore yields no value -- the same signal
+  // already used for shift amounts >= the bitwidth -- when shifting the
+  // result back does not restore the input, i.e. significant bits were lost.
+  auto ushl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return Optional<APInt>();
+    APInt shifted = l.shl(r);
+    return shifted.lshr(r) == l ? shifted : Optional<APInt>();
+  };
+  auto sshl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return Optional<APInt>();
+    APInt shifted = l.shl(r);
+    return shifted.ashr(r) == l ? shifted : Optional<APInt>();
+  };
+
+  // The shift amount is always read as unsigned; only the shifted value is
+  // interpreted both ways, and the two results are then combined.
+  ConstantIntRanges urange =
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// ShRUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &val = argRanges[0], &amt = argRanges[1];
+  auto lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    // Shift amounts >= the bitwidth give no value.
+    return r.ult(r.getBitWidth()) ? l.lshr(r) : Optional<APInt>();
+  };
+  setResultRange(getResult(), minMaxBy(lshr, {val.umin(), val.umax()},
+                                       {amt.umin(), amt.umax()},
+                                       /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// ShRSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &val = argRanges[0], &amt = argRanges[1];
+  auto ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    // Shift amounts >= the bitwidth give no value.
+    return r.ult(r.getBitWidth()) ? l.ashr(r) : Optional<APInt>();
+  };
+
+  setResultRange(getResult(),
+                 minMaxBy(ashr, {val.smin(), val.smax()},
+                          {amt.umin(), amt.umax()}, /*isSigned=*/true));
+}

diff  --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 777ea18456551..d81674c276520 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -35,8 +35,19 @@ unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   return 0;
 }
 
-ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
-  return {min, max, min, max};
+ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
+  return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
+}
+
+ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
+  return {value, value, value, value};
+}
+
+ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
+                                           bool isSigned) {
+  // `isSigned` selects how the bounds are interpreted: signed bounds go to
+  // fromSigned(), unsigned ones to fromUnsigned().
+  return isSigned ? fromSigned(min, max) : fromUnsigned(min, max);
 }
 
 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
@@ -84,6 +95,23 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
 
+ConstantIntRanges
+ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
+  // A zero-width ("not an integer") range cannot be compared, so it wins:
+  // whichever side is zero-width is returned unchanged.
+  if (umin().getBitWidth() == 0)
+    return *this;
+  if (other.umin().getBitWidth() == 0)
+    return other;
+
+  // Keep the larger lower bound and the smaller upper bound of each kind.
+  const APInt &uLow = umin().ule(other.umin()) ? other.umin() : umin();
+  const APInt &uHigh = umax().uge(other.umax()) ? other.umax() : umax();
+  const APInt &sLow = smin().sle(other.smin()) ? other.smin() : smin();
+  const APInt &sHigh = smax().sge(other.smax()) ? other.smax() : smax();
+  return {uLow, uHigh, sLow, sHigh};
+}
+
 Optional<APInt> ConstantIntRanges::getConstantValue() const {
   // Note: we need to exclude the trivially-equal width 0 values here.
   if (umin() == umax() && umin().getBitWidth() != 0)

diff  --git a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
new file mode 100644
index 0000000000000..72d3dbbb6326c
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
@@ -0,0 +1,647 @@
+// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @add_min_max
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: return %[[c3]]
+func.func @add_min_max(%a: index, %b: index) -> index {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %0 = arith.minsi %a, %c1 : index
+    %1 = arith.maxsi %0, %c1 : index
+    %2 = arith.minui %b, %c2 : index
+    %3 = arith.maxui %2, %c2 : index
+    %4 = arith.addi %1, %3 : index
+    func.return %4 : index
+}
+
+// CHECK-LABEL: func @add_lower_bound
+// CHECK: %[[sge:.*]] = arith.cmpi sge
+// CHECK: return %[[sge]]
+func.func @add_lower_bound(%a : i32, %b : i32) -> i1 {
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %0 = arith.maxsi %a, %c1 : i32
+    %1 = arith.maxsi %b, %c1 : i32
+    %2 = arith.addi %0, %1 : i32
+    %3 = arith.cmpi sge, %2, %c2 : i32
+    %4 = arith.cmpi uge, %2, %c2 : i32
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @sub_signed_vs_unsigned
+// CHECK-NOT: arith.cmpi sle
+// CHECK: %[[unsigned:.*]] = arith.cmpi ule
+// CHECK: return %[[unsigned]] : i1
+func.func @sub_signed_vs_unsigned(%v : i64) -> i1 {
+    %c0 = arith.constant 0 : i64
+    %c2 = arith.constant 2 : i64
+    %c-5 = arith.constant -5 : i64
+    %0 = arith.minsi %v, %c2 : i64
+    %1 = arith.maxsi %0, %c-5 : i64
+    %2 = arith.subi %1, %c2 : i64
+    %3 = arith.cmpi sle, %2, %c0 : i64
+    %4 = arith.cmpi ule, %2, %c0 : i64
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @multiply_negatives
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
+func.func @multiply_negatives(%a : index, %b : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %c_1 = arith.constant -1 : index
+    %c_2 = arith.constant -2 : index
+    %c_4 = arith.constant -4 : index
+    %c_12 = arith.constant -12 : index
+    %0 = arith.maxsi %a, %c2 : index
+    %1 = arith.minsi %0, %c3 : index
+    %2 = arith.minsi %b, %c_1 : index
+    %3 = arith.maxsi %2, %c_4 : index
+    %4 = arith.muli %1, %3 : index
+    %5 = arith.cmpi slt, %4, %c_12 : index
+    %6 = arith.cmpi slt, %c_1, %4 : index
+    %7 = arith.ori %5, %6 : i1
+    func.return %7 : i1
+}
+
+// CHECK-LABEL: func @multiply_unsigned_bounds
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @multiply_unsigned_bounds(%a : i16, %b : i16) -> i1 {
+    %c0 = arith.constant 0 : i16
+    %c4 = arith.constant 4 : i16
+    %c_mask = arith.constant 0x3fff : i16
+    %c_bound = arith.constant 0xfffc : i16
+    %0 = arith.andi %a, %c_mask : i16
+    %1 = arith.minui %b, %c4 : i16
+    %2 = arith.muli %0, %1 : i16
+    %3 = arith.cmpi uge, %2, %c0 : i16
+    %4 = arith.cmpi ule, %2, %c_bound : i16
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: @for_loop_with_increasing_arg
+// CHECK: %[[ret:.*]] = arith.cmpi ule
+// CHECK: return %[[ret]]
+func.func @for_loop_with_increasing_arg() -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c16 = arith.constant 16 : index
+    %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %c0) -> index {
+        %10 = arith.addi %arg0, %arg1 : index
+        scf.yield %10 : index
+    }
+    %1 = arith.cmpi ule, %0, %c16 : index
+    func.return %1 : i1
+}
+
+// CHECK-LABEL: @for_loop_with_constant_result
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @for_loop_with_constant_result() -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %true = arith.constant true
+    %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %true) -> i1 {
+        %10 = arith.cmpi ule, %arg0, %c4 : index
+        %11 = arith.andi %10, %arg1 : i1
+        scf.yield %11 : i1
+    }
+    func.return %0 : i1
+}
+
+// Test to catch a bug present in some versions of the data flow analysis
+// CHECK-LABEL: func @while_false
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: scf.condition(%[[false]])
+func.func @while_false(%arg0 : index) -> index {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %0 = arith.divui %arg0, %c2 : index
+    %1 = scf.while (%arg1 = %0) : (index) -> index {
+        %2 = arith.cmpi slt, %arg1, %c0 : index
+        scf.condition(%2) %arg1 : index
+    } do {
+    ^bb0(%arg2 : index):
+        scf.yield %c2 : index
+    }
+    func.return %1 : index
+}
+
+// CHECK-LABEL: func @div_bounds_positive
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @div_bounds_positive(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %0 = arith.maxsi %arg0, %c2 : index
+    %1 = arith.divsi %c4, %0 : index
+    %2 = arith.divui %c4, %0 : index
+    // Both the signed and the unsigned quotient must be proven within [0, 2].
+    %3 = arith.cmpi sge, %1, %c0 : index
+    %4 = arith.cmpi sle, %1, %c2 : index
+    %5 = arith.cmpi sge, %2, %c0 : index
+    %6 = arith.cmpi sle, %2, %c2 : index
+
+    %7 = arith.andi %3, %4 : i1
+    %8 = arith.andi %7, %5 : i1
+    %9 = arith.andi %8, %6 : i1
+    func.return %9 : i1
+}
+
+// CHECK-LABEL: func @div_bounds_negative
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @div_bounds_negative(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c_2 = arith.constant -2 : index
+    %c4 = arith.constant 4 : index
+    %0 = arith.minsi %arg0, %c_2 : index
+    %1 = arith.divsi %c4, %0 : index
+    %2 = arith.divui %c4, %0 : index
+
+    %3 = arith.cmpi sle, %1, %c0 : index
+    %4 = arith.cmpi sge, %1, %c_2 : index
+    %5 = arith.cmpi eq, %2, %c0 : index
+
+    %7 = arith.andi %3, %4 : i1
+    %8 = arith.andi %7, %5 : i1
+    func.return %8 : i1
+}
+
+// CHECK-LABEL: func @div_zero_undefined
+// CHECK: %[[ret:.*]] = arith.cmpi ule
+// CHECK: return %[[ret]]
+func.func @div_zero_undefined(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %0 = arith.andi %arg0, %c1 : index
+    %1 = arith.divui %c4, %0 : index
+    %2 = arith.cmpi ule, %1, %c4 : index
+    func.return %2 : i1
+}
+
+// CHECK-LABEL: func @ceil_divui
+// CHECK: %[[ret:.*]] = arith.cmpi eq
+// CHECK: return %[[ret]]
+func.func @ceil_divui(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = arith.minui %arg0, %c3 : index
+    %1 = arith.maxui %0, %c1 : index
+    %2 = arith.ceildivui %1, %c4 : index
+    %3 = arith.cmpi eq, %2, %c1 : index
+
+    %4 = arith.maxui %0, %c0 : index
+    %5 = arith.ceildivui %4, %c4 : index
+    %6 = arith.cmpi eq, %5, %c1 : index
+    %7 = arith.andi %3, %6 : i1
+    func.return %7 : i1
+}
+
+// CHECK-LABEL: func @ceil_divsi
+// CHECK: %[[ret:.*]] = arith.cmpi eq
+// CHECK: return %[[ret]]
+func.func @ceil_divsi(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c4 = arith.constant 4 : index
+    %c-4 = arith.constant -4 : index
+
+    %0 = arith.minsi %arg0, %c3 : index
+    %1 = arith.maxsi %0, %c1 : index
+    %2 = arith.ceildivsi %1, %c4 : index
+    %3 = arith.cmpi eq, %2, %c1 : index
+    %4 = arith.ceildivsi %1, %c-4 : index
+    %5 = arith.cmpi eq, %4, %c0 : index
+    %6 = arith.andi %3, %5 : i1
+
+    %7 = arith.maxsi %0, %c0 : index
+    %8 = arith.ceildivsi %7, %c4 : index
+    %9 = arith.cmpi eq, %8, %c1 : index
+    %10 = arith.andi %6, %9 : i1
+    func.return %10 : i1
+}
+
+// CHECK-LABEL: func @floor_divsi
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @floor_divsi(%arg0 : index) -> i1 {
+    %c4 = arith.constant 4 : index
+    %c-1 = arith.constant -1 : index
+    %c-3 = arith.constant -3 : index
+    %c-4 = arith.constant -4 : index
+
+    %0 = arith.minsi %arg0, %c-1 : index
+    %1 = arith.maxsi %0, %c-4 : index
+    %2 = arith.floordivsi %1, %c4 : index
+    %3 = arith.cmpi eq, %2, %c-1 : index
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @remui_base
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = arith.minui %arg1, %c4 : index
+    %1 = arith.maxui %0, %c2 : index
+    %2 = arith.remui %arg0, %1 : index
+    %3 = arith.cmpi ult, %2, %c4 : index
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @remsi_base
+// CHECK: %[[ret:.*]] = arith.cmpi sge
+// CHECK: return %[[ret]]
+func.func @remsi_base(%arg0 : index, %arg1 : index ) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c-4 = arith.constant -4 : index
+    %true = arith.constant true
+
+    %0 = arith.minsi %arg1, %c4 : index
+    %1 = arith.maxsi %0, %c2 : index
+    %2 = arith.remsi %arg0, %1 : index
+    %3 = arith.cmpi sgt, %2, %c-4 : index
+    %4 = arith.cmpi slt, %2, %c4 : index
+    %5 = arith.cmpi sge, %2, %c0 : index
+    %6 = arith.andi %3, %4 : i1
+    %7 = arith.andi %5, %6 : i1
+    func.return %7 : i1
+}
+
+// CHECK-LABEL: func @remsi_positive
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remsi_positive(%arg0 : index, %arg1 : index ) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %true = arith.constant true
+
+    %0 = arith.minsi %arg1, %c4 : index
+    %1 = arith.maxsi %0, %c2 : index
+    %2 = arith.maxsi %arg0, %c0 : index
+    %3 = arith.remsi %2, %1 : index
+    %4 = arith.cmpi sge, %3, %c0 : index
+    %5 = arith.cmpi slt, %3, %c4 : index
+    %6 = arith.andi %4, %5 : i1
+    func.return %6 : i1
+}
+
+// CHECK-LABEL: func @remui_restricted
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remui_restricted(%arg0 : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %c4 = arith.constant 4 : index
+
+    %0 = arith.minui %arg0, %c3 : index
+    %1 = arith.maxui %0, %c2 : index
+    %2 = arith.remui %1, %c4 : index
+    %3 = arith.cmpi ule, %2, %c3 : index
+    %4 = arith.cmpi uge, %2, %c2 : index
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @remsi_restricted
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remsi_restricted(%arg0 : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %c-4 = arith.constant -4 : index
+
+    %0 = arith.minsi %arg0, %c3 : index
+    %1 = arith.maxsi %0, %c2 : index
+    %2 = arith.remsi %1, %c-4 : index
+    %3 = arith.cmpi ule, %2, %c3 : index
+    %4 = arith.cmpi uge, %2, %c2 : index
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @remui_restricted_fails
+// CHECK: %[[ret:.*]] = arith.cmpi ne
+// CHECK: return %[[ret]]
+func.func @remui_restricted_fails(%arg0 : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %c4 = arith.constant 4 : index
+    %c5 = arith.constant 5 : index
+
+    %0 = arith.minui %arg0, %c5 : index
+    %1 = arith.maxui %0, %c3 : index
+    %2 = arith.remui %1, %c4 : index
+    %3 = arith.cmpi ne, %2, %c2 : index
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @remsi_restricted_fails
+// CHECK: %[[ret:.*]] = arith.cmpi ne
+// CHECK: return %[[ret]]
+func.func @remsi_restricted_fails(%arg0 : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %c5 = arith.constant 5 : index
+    %c-4 = arith.constant -4 : index
+
+    %0 = arith.minsi %arg0, %c5 : index
+    %1 = arith.maxsi %0, %c3 : index
+    %2 = arith.remsi %1, %c-4 : index
+    %3 = arith.cmpi ne, %2, %c2 : index
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @andi
+// CHECK: %[[ret:.*]] = arith.cmpi ugt
+// CHECK: return %[[ret]]
+func.func @andi(%arg0 : index) -> i1 {
+    %c2 = arith.constant 2 : index
+    %c5 = arith.constant 5 : index
+    %c7 = arith.constant 7 : index
+
+    %0 = arith.minsi %arg0, %c5 : index
+    %1 = arith.maxsi %0, %c2 : index
+    %2 = arith.andi %1, %c7 : index
+    %3 = arith.cmpi ugt, %2, %c5 : index
+    %4 = arith.cmpi ule, %2, %c7 : index
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @andi_doesnt_make_nonnegative
+// CHECK: %[[ret:.*]] = arith.cmpi sge
+// CHECK: return %[[ret]]
+func.func @andi_doesnt_make_nonnegative(%arg0 : index) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg0, %c1 : index
+    %1 = arith.andi %arg0, %0 : index
+    %2 = arith.cmpi sge, %1, %c0 : index
+    func.return %2 : i1
+}
+
+
+// CHECK-LABEL: func @ori
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @ori(%arg0 : i128, %arg1 : i128) -> i1 {
+    %c-1 = arith.constant -1 : i128
+    %c0 = arith.constant 0 : i128
+
+    %0 = arith.minsi %arg1, %c-1 : i128
+    %1 = arith.ori %arg0, %0 : i128
+    %2 = arith.cmpi slt, %1, %c0 : i128
+    func.return %2 : i1
+}
+
+// CHECK-LABEL: func @xori
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
+func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
+    %c0 = arith.constant 0 : i64
+    %c7 = arith.constant 7 : i64
+    %c15 = arith.constant 15 : i64
+    %true = arith.constant true
+
+    %0 = arith.minui %arg0, %c7 : i64
+    %1 = arith.minui %arg1, %c15 : i64
+    %2 = arith.xori %0, %1 : i64
+    %3 = arith.cmpi sle, %2, %c15 : i64
+    %4 = arith.xori %3, %true : i1
+    func.return %4 : i1
+}
+
+// CHECK-LABEL: func @extui
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @extui(%arg0 : i16) -> i1 {
+    %ci16_max = arith.constant 0xffff : i32
+    %0 = arith.extui %arg0 : i16 to i32
+    %1 = arith.cmpi ule, %0, %ci16_max : i32
+    func.return %1 : i1
+}
+
+// CHECK-LABEL: func @extsi
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @extsi(%arg0 : i16) -> i1 {
+    %ci16_smax = arith.constant 0x7fff : i32
+    %ci16_smin = arith.constant 0xffff8000 : i32
+    %0 = arith.extsi %arg0 : i16 to i32
+    %1 = arith.cmpi sle, %0, %ci16_smax : i32
+    %2 = arith.cmpi sge, %0, %ci16_smin : i32
+    %3 = arith.andi %1, %2 : i1
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @trunci
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+// The input is clamped to [i16 smin, -14] before truncating: without the
+// lower clamp, values below i16 smin wrap to arbitrary i16 values, and the
+// comparisons below would not hold for every input.
+func.func @trunci(%arg0 : i32) -> i1 {
+    %c-14_i32 = arith.constant -14 : i32
+    %c-14_i16 = arith.constant -14 : i16
+    %ci16_smin = arith.constant 0xffff8000 : i32
+    %0 = arith.minsi %arg0, %c-14_i32 : i32
+    %1 = arith.maxsi %0, %ci16_smin : i32
+    %2 = arith.trunci %1 : i32 to i16
+    %3 = arith.cmpi sle, %2, %c-14_i16 : i16
+    %4 = arith.extsi %2 : i16 to i32
+    %5 = arith.cmpi sle, %4, %c-14_i32 : i32
+    %6 = arith.cmpi sge, %4, %ci16_smin : i32
+    %7 = arith.andi %3, %5 : i1
+    %8 = arith.andi %7, %6 : i1
+    func.return %8 : i1
+}
+
+// CHECK-LABEL: func @index_cast
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @index_cast(%arg0 : index) -> i1 {
+    %ci32_smin = arith.constant 0xffffffff80000000 : i64
+    %0 = arith.index_cast %arg0 : index to i32
+    %1 = arith.index_cast %0 : i32 to index
+    %2 = arith.index_cast %ci32_smin : i64 to index
+    %3 = arith.cmpi sge, %1, %2 : index
+    func.return %3 : i1
+}
+
+// CHECK-LABEL: func @shli
+// CHECK: %[[ret:.*]] = arith.cmpi sgt
+// CHECK: return %[[ret]]
+func.func @shli(%arg0 : i32, %arg1 : i1) -> i1 {
+    %c2 = arith.constant 2 : i32
+    %c4 = arith.constant 4 : i32
+    %c8 = arith.constant 8 : i32
+    %c32 = arith.constant 32 : i32
+    %c-1 = arith.constant -1 : i32
+    %c-16 = arith.constant -16 : i32
+    %0 = arith.maxsi %arg0, %c-1 : i32
+    %1 = arith.minsi %0, %c2 : i32
+    %2 = arith.select %arg1, %c2, %c4 : i32
+    %3 = arith.shli %1, %2 : i32
+    %4 = arith.cmpi sge, %3, %c-16 : i32
+    %5 = arith.cmpi sle, %3, %c32 : i32
+    %6 = arith.cmpi sgt, %3, %c8 : i32
+    %7 = arith.andi %4, %5 : i1
+    %8 = arith.andi %7, %6 : i1
+    func.return %8 : i1
+}
+
+// CHECK-LABEL: func @shrui
+// CHECK: %[[ret:.*]] = arith.cmpi uge
+// CHECK: return %[[ret]]
+func.func @shrui(%arg0 : i1) -> i1 {
+    %c2 = arith.constant 2 : i32
+    %c4 = arith.constant 4 : i32
+    %c8 = arith.constant 8 : i32
+    %c32 = arith.constant 32 : i32
+    %0 = arith.select %arg0, %c2, %c4 : i32
+    %1 = arith.shrui %c32, %0 : i32
+    %2 = arith.cmpi ule, %1, %c8 : i32
+    %3 = arith.cmpi uge, %1, %c2 : i32
+    %4 = arith.cmpi uge, %1, %c8 : i32
+    %5 = arith.andi %2, %3 : i1
+    %6 = arith.andi %5, %4 : i1
+    func.return %6 : i1
+}
+
+// CHECK-LABEL: func @shrsi
+// CHECK: %[[ret:.*]] = arith.cmpi slt
+// CHECK: return %[[ret]]
+func.func @shrsi(%arg0 : i32, %arg1 : i1) -> i1 {
+    %c2 = arith.constant 2 : i32
+    %c4 = arith.constant 4 : i32
+    %c8 = arith.constant 8 : i32
+    %c32 = arith.constant 32 : i32
+    %c-8 = arith.constant -8 : i32
+    %c-32 = arith.constant -32 : i32
+    %0 = arith.maxsi %arg0, %c-32 : i32
+    %1 = arith.minsi %0, %c32 : i32
+    %2 = arith.select %arg1, %c2, %c4 : i32
+    %3 = arith.shrsi %1, %2 : i32
+    %4 = arith.cmpi sge, %3, %c-8 : i32
+    %5 = arith.cmpi sle, %3, %c8 : i32
+    %6 = arith.cmpi slt, %3, %c2 : i32
+    %7 = arith.andi %4, %5 : i1
+    %8 = arith.andi %7, %6 : i1
+    func.return %8 : i1
+}
+
+// CHECK-LABEL: func @no_aggressive_eq
+// CHECK: %[[ret:.*]] = arith.cmpi eq
+// CHECK: return %[[ret]]
+func.func @no_aggressive_eq(%arg0 : index) -> i1 {
+    %c1 = arith.constant 1 : index
+    %0 = arith.andi %arg0, %c1 : index
+    %1 = arith.minui %arg0, %c1 : index
+    %2 = arith.cmpi eq, %0, %1 : index
+    func.return %2 : i1
+}
+
+// CHECK-LABEL: func @select_union
+// CHECK: %[[ret:.*]] = arith.cmpi ne
+// CHECK: return %[[ret]]
+
+func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 {
+    %c64 = arith.constant 64 : index
+    %c100 = arith.constant 100 : index
+    %c128 = arith.constant 128 : index
+    %c192 = arith.constant 192 : index
+    %0 = arith.remui %arg0, %c64 : index
+    %1 = arith.addi %0, %c128 : index
+    %2 = arith.select %arg1, %0, %1 : index
+    %3 = arith.cmpi slt, %2, %c192 : index
+    %4 = arith.cmpi ne, %c100, %2 : index
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @if_union
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @if_union(%arg0 : index, %arg1 : i1) -> i1 {
+    %c4 = arith.constant 4 : index
+    %c16 = arith.constant 16 : index
+    %c-1 = arith.constant -1 : index
+    %c-4 = arith.constant -4 : index
+    %0 = arith.minui %arg0, %c4 : index
+    %1 = scf.if %arg1 -> index {
+        %10 = arith.muli %0, %0 : index
+        scf.yield %10 : index
+    } else {
+        %20 = arith.muli %0, %c-1 : index
+        scf.yield %20 : index
+    }
+    %2 = arith.cmpi sle, %1, %c16 : index
+    %3 = arith.cmpi sge, %1, %c-4 : index
+    %4 = arith.andi %2, %3 : i1
+    func.return %4 : i1
+}
+
+// CHECK-LABEL: func @branch_union
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @branch_union(%arg0 : index, %arg1 : i1) -> i1 {
+    %c4 = arith.constant 4 : index
+    %c16 = arith.constant 16 : index
+    %c-1 = arith.constant -1 : index
+    %c-4 = arith.constant -4 : index
+    %0 = arith.minui %arg0, %c4 : index
+    cf.cond_br %arg1, ^bb1, ^bb2
+^bb1 :
+    %1 = arith.muli %0, %0 : index
+    cf.br ^bb3(%1 : index)
+^bb2 :
+    %2 = arith.muli %0, %c-1 : index
+    cf.br ^bb3(%2 : index)
+^bb3(%3 : index) :
+    %4 = arith.cmpi sle, %3, %c16 : index
+    %5 = arith.cmpi sge, %3, %c-4 : index
+    %6 = arith.andi %4, %5 : i1
+    func.return %6 : i1
+}
+
+// CHECK-LABEL: func @loop_bound_not_inferred_with_branch
+// CHECK-DAG: %[[min:.*]] = arith.cmpi sge
+// CHECK-DAG: %[[max:.*]] = arith.cmpi slt
+// CHECK-DAG: %[[ret:.*]] = arith.andi %[[min]], %[[max]]
+// CHECK: return %[[ret]]
+func.func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1 {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %0 = arith.minui %arg0, %c4 : index
+    cf.br ^bb2(%c0 : index)
+^bb1(%1 : index) :
+    %2 = arith.addi %1, %c1 : index
+    cf.br ^bb2(%2 : index)
+^bb2(%3 : index):
+    %4 = arith.cmpi ult, %3, %c4 : index
+    cf.cond_br %4, ^bb1(%3 : index), ^bb3(%3 : index)
+^bb3(%5 : index) :
+    %6 = arith.cmpi sge, %5, %c0 : index
+    %7 = arith.cmpi slt, %5, %c4 : index
+    %8 = arith.andi %6, %7 : i1
+    func.return %8 : i1
+}
+

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index f9c551c0b9929..81c1531484d43 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -101,16 +101,16 @@ func.func @func_args_unbound(%arg0 : index) -> index {
   func.return %0 : index
 }
 
-// CHECK-LABEL: func @propagate_across_while_loop()
-func.func @propagate_across_while_loop() -> index {
+// CHECK-LABEL: func @propagate_across_while_loop_false()
+func.func @propagate_across_while_loop_false() -> index {
   // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
   // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index }
   %1 = scf.while : () -> index {
-    %true = arith.constant true
+    %false = arith.constant false
     // CHECK: scf.condition(%{{.*}}) %[[C0]]
-    scf.condition(%true) %0 : index
+    scf.condition(%false) %0 : index
   } do {
   ^bb0(%i1: index):
     scf.yield
@@ -119,3 +119,42 @@ func.func @propagate_across_while_loop() -> index {
   %2 = test.increment %1
   return %2 : index
 }
+
+// CHECK-LABEL: func @propagate_across_while_loop
+func.func @propagate_across_while_loop(%arg0 : i1) -> index {
+  // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
+  // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
+  %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                          smin = 0 : index, smax = 0 : index }
+  %1 = scf.while : () -> index {
+    // CHECK: scf.condition(%{{.*}}) %[[C0]]
+    scf.condition(%arg0) %0 : index
+  } do {
+  ^bb0(%i1: index):
+    scf.yield
+  }
+  // CHECK: return %[[C1]]
+  %2 = test.increment %1
+  return %2 : index
+}
+
+// CHECK-LABEL: func @dont_propagate_across_infinite_loop()
+func.func @dont_propagate_across_infinite_loop() -> index {
+  // CHECK: %[[C0:.*]] = "test.constant"() {value = 0
+  %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                          smin = 0 : index, smax = 0 : index }
+  // CHECK: %[[loopRes:.*]] = scf.while
+  %1 = scf.while : () -> index {
+    %true = arith.constant true
+    // CHECK: scf.condition(%{{.*}}) %[[C0]]
+    scf.condition(%true) %0 : index
+  } do {
+  ^bb0(%i1: index):
+    scf.yield
+  }
+  // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]]
+  %2 = test.reflect_bounds %1
+  // CHECK: return %[[ret]]
+  return %2 : index
+}
+


        


More information about the Mlir-commits mailing list