[Mlir-commits] [mlir] 9f59aff - Revert "[mlir][Index] Implement InferIntRangeInterface"

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Jan 19 10:43:18 PST 2023


Author: Krzysztof Drewniak
Date: 2023-01-19T18:43:12Z
New Revision: 9f59affa244c1b15d9980c9abcaff433514a9d85

URL: https://github.com/llvm/llvm-project/commit/9f59affa244c1b15d9980c9abcaff433514a9d85
DIFF: https://github.com/llvm/llvm-project/commit/9f59affa244c1b15d9980c9abcaff433514a9d85.diff

LOG: Revert "[mlir][Index] Implement InferIntRangeInterface"

This reverts commit 455305624884cf9237143e2ba0635fcc5ba5206a.

Linker error, unbreak build while I work out how to fix it.

Differential Revision: https://reviews.llvm.org/D142142

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Index/IR/IndexOps.h
    mlir/include/mlir/Dialect/Index/IR/IndexOps.td
    mlir/lib/Dialect/Arith/IR/CMakeLists.txt
    mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
    mlir/lib/Dialect/Index/IR/CMakeLists.txt
    mlir/lib/Interfaces/CMakeLists.txt

Removed: 
    mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
    mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
    mlir/lib/Interfaces/Utils/CMakeLists.txt
    mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
    mlir/test/Dialect/Index/int-range-inference.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
index d8debfb731323..85a0549edd4dd 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -13,7 +13,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
-#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 8fbccc4ba94fc..76008a17364f9 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -12,7 +12,6 @@
 include "mlir/Dialect/Index/IR/IndexDialect.td"
 include "mlir/Dialect/Index/IR/IndexEnums.td"
 include "mlir/Interfaces/CastInterfaces.td"
-include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -24,8 +23,7 @@ include "mlir/IR/OpBase.td"
 
 /// Base class for Index dialect operations.
 class IndexOp<string mnemonic, list<Trait> traits = []>
-    : Op<IndexDialect, mnemonic,
-      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
+    : Op<IndexDialect, mnemonic, [Pure] # traits>;
 
 //===----------------------------------------------------------------------===//
 // IndexBinaryOp

diff  --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
deleted file mode 100644
index 7ee059cf342ce..0000000000000
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ /dev/null
@@ -1,126 +0,0 @@
-//===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares implementations of range inference for operations that are
-// common to both the `arith` and `index` dialects to facilitate reuse.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
-#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
-
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "llvm/ADT/ArrayRef.h"
-
-namespace mlir {
-namespace intrange {
-/// Function that performs inference on an array of `ConstantIntRanges`,
-/// abstracted away here to permit writing the function that handles both
-/// 64- and 32-bit index types.
-using InferRangeFn =
-    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
-
-static constexpr unsigned indexMinWidth = 32;
-static constexpr unsigned indexMaxWidth = 64;
-
-enum class CmpMode : uint32_t { Both, Signed, Unsigned };
-
-/// Compute `inferFn` on `ranges`, whose size should be the index storage
-/// bitwidth. Then, compute the function on `argRanges` again after truncating
-/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
-/// equal to the 32-bit result, use it (to preserve compatibility with folders
-/// and inference precision), and take the union of the results otherwise.
-///
-/// The `mode` argument specifies if the unsigned, signed, or both results of
-/// the inference computation should be used when comparing the results.
-ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
-                               ArrayRef<ConstantIntRanges> argRanges,
-                               CmpMode mode);
-
-/// Independently zero-extend the unsigned values and sign-extend the signed
-/// values in `range` to `destWidth` bits, returning the resulting range.
-ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth);
-
-/// Use the unsigned values in `range` to zero-extend it to `destWidth`.
-ConstantIntRanges extUIRange(const ConstantIntRanges &range,
-                             unsigned destWidth);
-
-/// Use the signed values in `range` to sign-extend it to `destWidth`.
-ConstantIntRanges extSIRange(const ConstantIntRanges &range,
-                             unsigned destWidth);
-
-/// Truncate `range` to `destWidth` bits, taking care to handle cases such as
-/// the truncation of [255, 256] to i8 not being a uniform range.
-ConstantIntRanges truncRange(const ConstantIntRanges &range,
-                             unsigned destWidth);
-
-ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
-
-ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges);
-
-/// Copy of the enum from `arith` and `index` to allow the common integer range
-/// infrastructure to not depend on either dialect.
-enum class CmpPredicate : uint64_t {
-  eq,
-  ne,
-  slt,
-  sle,
-  sgt,
-  sge,
-  ult,
-  ule,
-  ugt,
-  uge,
-};
-
-/// Returns a boolean value if `pred` is statically true or false for
-/// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the
-/// value of the predicate cannot be determined.
-Optional<bool> evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs,
-                            const ConstantIntRanges &rhs);
-
-} // namespace intrange
-} // namespace mlir
-
-#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H

diff  --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index ffbe80105911e..0de17bbfbd12a 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -16,7 +16,6 @@ add_mlir_dialect_library(MLIRArithDialect
 
   LINK_LIBS PUBLIC
   MLIRDialect
-  MLIRInferIntRangeCommon
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR

diff  --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 971477fa94cb9..10d6ef29756c6 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -17,7 +16,48 @@
 
 using namespace mlir;
 using namespace mlir::arith;
-using namespace mlir::intrange;
+
+/// Function that evaluates the result of doing something on arithmetic
+/// constants and returns std::nullopt on overflow.
+using ConstArithFn =
+    function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
+
+/// Return the maximally wide signed or unsigned range for a given bitwidth.
+
+/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
+/// If either computation overflows, make the result unbounded.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
+                                         const APInt &minRight,
+                                         const APInt &maxLeft,
+                                         const APInt &maxRight, bool isSigned) {
+  std::optional<APInt> maybeMin = op(minLeft, minRight);
+  std::optional<APInt> maybeMax = op(maxLeft, maxRight);
+  if (maybeMin && maybeMax)
+    return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
+  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
+}
+
+/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
+/// ignoring unbounded values. Returns the maximal range if `op` overflows.
+static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
+                                  ArrayRef<APInt> rhs, bool isSigned) {
+  unsigned width = lhs[0].getBitWidth();
+  APInt min =
+      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
+  APInt max =
+      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
+  for (const APInt &left : lhs) {
+    for (const APInt &right : rhs) {
+      std::optional<APInt> maybeThisResult = op(left, right);
+      if (!maybeThisResult)
+        return ConstantIntRanges::maxRange(width);
+      APInt result = std::move(*maybeThisResult);
+      min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
+      max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
+    }
+  }
+  return ConstantIntRanges::range(min, max, isSigned);
+}
 
 //===----------------------------------------------------------------------===//
 // ConstantOp
@@ -38,7 +78,25 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  ConstArithFn uadd = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.uadd_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn sadd = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.sadd_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange = computeBoundsBy(
+      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
 }
 
 //===----------------------------------------------------------------------===//
@@ -47,7 +105,25 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn usub = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.usub_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn ssub = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.ssub_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstantIntRanges urange = computeBoundsBy(
+      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
 }
 
 //===----------------------------------------------------------------------===//
@@ -56,25 +132,96 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn umul = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.umul_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  ConstArithFn smul = [](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.smul_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange =
+      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+               /*isSigned=*/true);
+
+  setResultRange(getResult(), urange.intersection(srange));
 }
 
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
+/// Fix up division results (ex. for ceiling and floor), returning an APInt
+/// if there has been no overflow
+using DivisionFixupFn = function_ref<std::optional<APInt>(
+    const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
+              &rhsMax = rhs.umax();
+
+  if (!rhsMin.isZero()) {
+    auto udiv = [&fixup](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+      return fixup(a, b, a.udiv(b));
+    };
+    return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/false);
+  }
+  // Otherwise, it's possible we might divide by 0.
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivU(argRanges));
+  setResultRange(getResult(),
+                 inferDivUIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
 }
 
 //===----------------------------------------------------------------------===//
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
+
+  if (canDivide) {
+    auto sdiv = [&fixup](const APInt &a,
+                         const APInt &b) -> std::optional<APInt> {
+      bool overflowed = false;
+      APInt result = a.sdiv_ov(b, overflowed);
+      return overflowed ? std::optional<APInt>() : fixup(a, b, result);
+    };
+    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/true);
+  }
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivS(argRanges));
+  setResultRange(getResult(),
+                 inferDivSIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,7 +230,20 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::CeilDivUIOp::inferResultRanges(
     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivU(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn ceilDivUIFix =
+      [](const APInt &lhs, const APInt &rhs,
+         const APInt &result) -> std::optional<APInt> {
+    if (!lhs.urem(rhs).isZero()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
 }
 
 //===----------------------------------------------------------------------===//
@@ -92,7 +252,20 @@ void arith::CeilDivUIOp::inferResultRanges(
 
 void arith::CeilDivSIOp::inferResultRanges(
     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn ceilDivSIFix =
+      [](const APInt &lhs, const APInt &rhs,
+         const APInt &result) -> std::optional<APInt> {
+    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
 }
 
 //===----------------------------------------------------------------------===//
@@ -101,7 +274,20 @@ void arith::CeilDivSIOp::inferResultRanges(
 
 void arith::FloorDivSIOp::inferResultRanges(
     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  return setResultRange(getResult(), inferFloorDivS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  DivisionFixupFn floorDivSIFix =
+      [](const APInt &lhs, const APInt &rhs,
+         const APInt &result) -> std::optional<APInt> {
+    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
 }
 
 //===----------------------------------------------------------------------===//
@@ -110,7 +296,29 @@ void arith::FloorDivSIOp::inferResultRanges(
 
 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemU(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+
+  unsigned width = rhsMin.getBitWidth();
+  APInt umin = APInt::getZero(width);
+  APInt umax = APInt::getMaxValue(width);
+
+  if (!rhsMin.isZero()) {
+    umax = rhsMax - 1;
+    // Special case: sweeping out a contiguous range in N/[modulus]
+    if (rhsMin == rhsMax) {
+      const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
+      if ((lhsMax - lhsMin).ult(rhsMax)) {
+        APInt minRem = lhsMin.urem(rhsMax);
+        APInt maxRem = lhsMax.urem(rhsMax);
+        if (minRem.ule(maxRem)) {
+          umin = minRem;
+          umax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
 }
 
 //===----------------------------------------------------------------------===//
@@ -119,16 +327,67 @@ void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+
+  unsigned width = rhsMax.getBitWidth();
+  APInt smin = APInt::getSignedMinValue(width);
+  APInt smax = APInt::getSignedMaxValue(width);
+  // No bounds if zero could be a divisor.
+  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
+  if (canBound) {
+    APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
+    bool canNegativeDividend = lhsMin.isNegative();
+    bool canPositiveDividend = lhsMax.isStrictlyPositive();
+    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
+    APInt maxPositiveResult = maxDivisor - 1;
+    APInt minNegativeResult = -maxPositiveResult;
+    smin = canNegativeDividend ? minNegativeResult : zero;
+    smax = canPositiveDividend ? maxPositiveResult : zero;
+    // Special case: sweeping out a contiguous range in N/[modulus].
+    if (rhsMin == rhsMax) {
+      if ((lhsMax - lhsMin).ult(maxDivisor)) {
+        APInt minRem = lhsMin.srem(maxDivisor);
+        APInt maxRem = lhsMax.srem(maxDivisor);
+        if (minRem.sle(maxRem)) {
+          smin = minRem;
+          smax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
 }
 
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
 
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations.
+static std::tuple<APInt, APInt>
+widenBitwiseBounds(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+  leftVal.clearLowBits(differingBits);
+  rightVal.setLowBits(differingBits);
+  return std::make_tuple(std::move(leftVal), std::move(rightVal));
+}
+
 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAnd(argRanges));
+  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+  auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    return a & b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
 }
 
 //===----------------------------------------------------------------------===//
@@ -137,7 +396,14 @@ void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferOr(argRanges));
+  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+  auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    return a | b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
 }
 
 //===----------------------------------------------------------------------===//
@@ -146,7 +412,14 @@ void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferXor(argRanges));
+  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+  auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    return a ^ b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
 }
 
 //===----------------------------------------------------------------------===//
@@ -155,7 +428,11 @@ void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
 }
 
 //===----------------------------------------------------------------------===//
@@ -164,7 +441,11 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxU(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
 }
 
 //===----------------------------------------------------------------------===//
@@ -173,7 +454,11 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
 }
 
 //===----------------------------------------------------------------------===//
@@ -182,40 +467,94 @@ void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinU(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt umin = range.umin().zext(destWidth);
+  APInt umax = range.umax().zext(destWidth);
+  return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  unsigned destWidth =
-      ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+  Type destType = getResult().getType();
+  setResultRange(getResult(), extUIRange(argRanges[0], destType));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt smin = range.smin().sext(destWidth);
+  APInt smax = range.smax().sext(destWidth);
+  return ConstantIntRanges::fromSigned(smin, smax);
+}
+
 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  unsigned destWidth =
-      ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+  Type destType = getResult().getType();
+  setResultRange(getResult(), extSIRange(argRanges[0], destType));
 }
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
+static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
+                                     Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
+  // the range of the resulting value is not contiguous and includes 0.
+  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
+  // but you can't truncate [255, 257] similarly.
+  bool hasUnsignedRollover =
+      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
+                                   : range.umin().trunc(destWidth);
+  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
+                                   : range.umax().trunc(destWidth);
+
+  // Signed post-truncation rollover will not occur when either:
+  // - The high parts of the min and max, plus the sign bit, are the same
+  // - The high halves + sign bit of the min and max are either all 1s or all 0s
+  //  and you won't create a [positive, negative] range by truncating.
+  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
+  // but not [255, 257]_i16 to a range of i8s. You can also truncate
+  // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
+  // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
+  // will truncate to 0x7e, which is greater than 0
+  APInt sminHighPart = range.smin().ashr(destWidth - 1);
+  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
+  bool hasSignedOverflow =
+      (sminHighPart != smaxHighPart) &&
+      !(sminHighPart.isAllOnes() &&
+        (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
+      !(sminHighPart.isZero() && smaxHighPart.isZero());
+  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
+                                 : range.smin().trunc(destWidth);
+  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
+                                 : range.smax().trunc(destWidth);
+  return {umin, umax, smin, smax};
+}
+
 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
-  unsigned destWidth =
-      ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+  Type destType = getResult().getType();
+  setResultRange(getResult(), truncIRange(argRanges[0], destType));
 }
 
 //===----------------------------------------------------------------------===//
@@ -230,9 +569,9 @@ void arith::IndexCastOp::inferResultRanges(
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extSIRange(argRanges[0], destType));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
   else
     setResultRange(getResult(), argRanges[0]);
 }
@@ -249,9 +588,9 @@ void arith::IndexCastUIOp::inferResultRanges(
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extUIRange(argRanges[0], destType));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
   else
     setResultRange(getResult(), argRanges[0]);
 }
@@ -260,19 +599,51 @@ void arith::IndexCastUIOp::inferResultRanges(
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
+bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
+                      const ConstantIntRanges &rhs) {
+  switch (pred) {
+  case arith::CmpIPredicate::sle:
+  case arith::CmpIPredicate::slt:
+    return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult:
+    return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
+  case arith::CmpIPredicate::sge:
+  case arith::CmpIPredicate::sgt:
+    return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+    return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
+  case arith::CmpIPredicate::eq: {
+    std::optional<APInt> lhsConst = lhs.getConstantValue();
+    std::optional<APInt> rhsConst = rhs.getConstantValue();
+    return lhsConst && rhsConst && lhsConst == rhsConst;
+  }
+  case arith::CmpIPredicate::ne: {
+    // While equality requires that there is an interpretation of the
+    // preceding computations that produces equal constants, whether that be
+    // signed or unsigned, statically determining inequality requires that
+    // neither interpretation produce potentially overlapping ranges.
+    bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
+    bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
+    return sne && une;
+  }
+  }
+  return false;
+}
+
 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  arith::CmpIPredicate arithPred = getPredicate();
-  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
+  arith::CmpIPredicate pred = getPredicate();
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
   APInt min = APInt::getZero(1);
   APInt max = APInt::getAllOnesValue(1);
-
-  Optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
-  if (truthValue.has_value() && *truthValue)
+  if (isStaticallyTrue(pred, lhs, rhs))
     min = max;
-  else if (truthValue.has_value() && !(*truthValue))
+  else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
     max = min;
 
   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
@@ -302,7 +673,18 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  ConstArithFn shl = [](const APInt &l,
+                        const APInt &r) -> std::optional<APInt> {
+    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+  };
+  ConstantIntRanges urange =
+      minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
 }
 
 //===----------------------------------------------------------------------===//
@@ -311,7 +693,15 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrU(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn lshr = [](const APInt &l,
+                         const APInt &r) -> std::optional<APInt> {
+    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
+  };
+  setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
+                                       {rhs.umin(), rhs.umax()},
+                                       /*isSigned=*/false));
 }
 
 //===----------------------------------------------------------------------===//
@@ -320,5 +710,14 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrS(argRanges));
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  ConstArithFn ashr = [](const APInt &l,
+                         const APInt &r) -> std::optional<APInt> {
+    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
+  };
+
+  setResultRange(getResult(),
+                 minMaxBy(ashr, {lhs.smin(), lhs.smax()},
+                          {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
 }

diff  --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
index e820eececa483..53321f1ea3f25 100644
--- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRIndexDialect
   IndexAttrs.cpp
   IndexDialect.cpp
   IndexOps.cpp
-  InferIntRangeInterfaceImpls.cpp
 
   DEPENDS
   MLIRIndexOpsIncGen
@@ -11,7 +10,6 @@ add_mlir_dialect_library(MLIRIndexDialect
   MLIRDialect
   MLIRIR
   MLIRCastInterfaces
-  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   )

diff  --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
deleted file mode 100644
index 6daa7640b017e..0000000000000
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ /dev/null
@@ -1,252 +0,0 @@
-//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Index/IR/IndexOps.h"
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
-
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "int-range-analysis"
-
-using namespace mlir;
-using namespace mlir::index;
-using namespace mlir::intrange;
-
-//===----------------------------------------------------------------------===//
-// Constants
-//===----------------------------------------------------------------------===//
-
-void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                   SetIntRangeFn setResultRange) {
-  const APInt &value = getValue();
-  setResultRange(getResult(), ConstantIntRanges::constant(value));
-}
-
-void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                       SetIntRangeFn setResultRange) {
-  bool value = getValue();
-  APInt asInt(/*numBits=*/1, value);
-  setResultRange(getResult(), ConstantIntRanges::constant(asInt));
-}
-
-//===----------------------------------------------------------------------===//
-// Arithmec operations. All of these operations will have their results inferred
-// using both the 64-bit values and truncated 32-bit values of their inputs,
-// with the results being the union of those inferences, except where the
-// truncation of the 64-bit result is equal to the 32-bit result (at which time
-// we take the 64-bit result).
-//===----------------------------------------------------------------------===//
-
-void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
-}
-
-void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
-}
-
-void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
-}
-
-void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
-}
-
-void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
-}
-
-void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                   SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
-}
-
-void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                   SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
-}
-
-void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                    SetIntRangeFn setResultRange) {
-  return setResultRange(
-      getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
-}
-
-void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
-}
-
-void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
-}
-
-void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
-}
-
-void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
-}
-
-void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
-}
-
-void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
-}
-
-void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
-}
-
-void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
-}
-
-void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
-}
-
-void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
-}
-
-void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                             SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
-}
-
-void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
-}
-
-//===----------------------------------------------------------------------===//
-// Casts
-//===----------------------------------------------------------------------===//
-
-static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
-                                      unsigned srcWidth, unsigned destWidth,
-                                      bool isSigned) {
-  if (srcWidth < destWidth)
-    return isSigned ? extSIRange(range, destWidth)
-                    : extUIRange(range, destWidth);
-  if (srcWidth > destWidth)
-    return truncRange(range, destWidth);
-  return range;
-}
-
-// When casting to `index`, we will take the union of the possible fixed-width
-// casts.
-static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
-                                        Type sourceType, Type destType,
-                                        bool isSigned) {
-  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
-  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
-  if (sourceType.isIndex())
-    return makeLikeDest(range, srcWidth, destWidth, isSigned);
-  // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
-  ConstantIntRanges storageRange =
-      makeLikeDest(range, srcWidth, destWidth, isSigned);
-  ConstantIntRanges minWidthRange =
-      makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
-  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
-  ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
-  return ret;
-}
-
-void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                SetIntRangeFn setResultRange) {
-  Type sourceType = getOperand().getType();
-  Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/true));
-}
-
-void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                SetIntRangeFn setResultRange) {
-  Type sourceType = getOperand().getType();
-  Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/false));
-}
-
-//===----------------------------------------------------------------------===//
-// CmpOp
-//===----------------------------------------------------------------------===//
-
-void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                              SetIntRangeFn setResultRange) {
-  index::IndexCmpPredicate indexPred = getPred();
-  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  APInt min = APInt::getZero(1);
-  APInt max = APInt::getAllOnesValue(1);
-
-  Optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
-  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
-                    rhsTrunc = truncRange(rhs, indexMinWidth);
-  Optional<bool> truthValue32 =
-      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
-  if (truthValue64 == truthValue32) {
-    if (truthValue64.has_value() && *truthValue64)
-      min = max;
-    else if (truthValue64.has_value() && !(*truthValue64))
-      max = min;
-  }
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
-}
-
-//===----------------------------------------------------------------------===//
-// SizeOf, which is bounded between the two supported bitwidth (32 and 64).
-//===----------------------------------------------------------------------===//
-
-void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
-                                 SetIntRangeFn setResultRange) {
-  unsigned storageWidth =
-      ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  APInt min(/*numBits=*/storageWidth, indexMinWidth);
-  APInt max(/*numBits=*/storageWidth, indexMaxWidth);
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
-}

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 38ad0e4a2231c..a7cdbb5b3a6fe 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -51,5 +51,3 @@ add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)
 add_mlir_interface_library(ViewLikeInterface)
-
-add_subdirectory(Utils)

diff  --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
deleted file mode 100644
index ece6c8e46ffea..0000000000000
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-add_mlir_library(MLIRInferIntRangeCommon
-    InferIntRangeCommon.cpp
-
-    ADDITIONAL_HEADER_DIRS
-    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
-
-    DEPENDS
-    MLIRInferIntRangeInterfaceIncGen
-
-    LINK_LIBS PUBLIC
-    MLIRInferIntRangeInterface
-    MLIRIR
-)

diff  --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
deleted file mode 100644
index c81f004ecf5f9..0000000000000
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ /dev/null
@@ -1,663 +0,0 @@
-//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains implementations of range inference for operations that are
-// common to both the `arith` and `index` dialects to facilitate reuse.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
-
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-
-#include "llvm/Support/Debug.h"
-
-#include <iterator>
-#include <optional>
-
-using namespace mlir;
-
-#define DEBUG_TYPE "int-range-analysis"
-
-//===----------------------------------------------------------------------===//
-// General utilities
-//===----------------------------------------------------------------------===//
-
-/// Function that evaluates the result of doing something on arithmetic
-/// constants and returns std::nullopt on overflow.
-using ConstArithFn =
-    function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
-
-/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
-/// If either computation overflows, make the result unbounded.
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
-                                         const APInt &minRight,
-                                         const APInt &maxLeft,
-                                         const APInt &maxRight, bool isSigned) {
-  std::optional<APInt> maybeMin = op(minLeft, minRight);
-  std::optional<APInt> maybeMax = op(maxLeft, maxRight);
-  if (maybeMin && maybeMax)
-    return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
-  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
-}
-
-/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
-/// ignoring unbounded values. Returns the maximal range if `op` overflows.
-static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
-                                  ArrayRef<APInt> rhs, bool isSigned) {
-  unsigned width = lhs[0].getBitWidth();
-  APInt min =
-      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
-  APInt max =
-      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
-  for (const APInt &left : lhs) {
-    for (const APInt &right : rhs) {
-      std::optional<APInt> maybeThisResult = op(left, right);
-      if (!maybeThisResult)
-        return ConstantIntRanges::maxRange(width);
-      APInt result = std::move(*maybeThisResult);
-      min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
-      max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
-    }
-  }
-  return ConstantIntRanges::range(min, max, isSigned);
-}
-
-//===----------------------------------------------------------------------===//
-// Ext, trunc, index op handling
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferIndexOp(InferRangeFn inferFn,
-                             ArrayRef<ConstantIntRanges> argRanges,
-                             intrange::CmpMode mode) {
-  ConstantIntRanges sixtyFour = inferFn(argRanges);
-  SmallVector<ConstantIntRanges, 2> truncated;
-  llvm::transform(argRanges, std::back_inserter(truncated),
-                  [](const ConstantIntRanges &range) {
-                    return truncRange(range, /*destWidth=*/indexMinWidth);
-                  });
-  ConstantIntRanges thirtyTwo = inferFn(truncated);
-  ConstantIntRanges thirtyTwoAsSixtyFour =
-      extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
-  ConstantIntRanges sixtyFourAsThirtyTwo =
-      truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
-
-  LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
-                          << " 32-bit = " << thirtyTwo << "\n");
-  bool truncEqual = false;
-  switch (mode) {
-  case intrange::CmpMode::Both:
-    truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
-    break;
-  case intrange::CmpMode::Signed:
-    truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
-                  thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
-    break;
-  case intrange::CmpMode::Unsigned:
-    truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
-                  thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
-    break;
-  }
-  if (truncEqual)
-    // Returing the 64-bit result preserves more information.
-    return sixtyFour;
-  ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
-}
-
-ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
-                                           unsigned int destWidth) {
-  APInt umin = range.umin().zext(destWidth);
-  APInt umax = range.umax().zext(destWidth);
-  APInt smin = range.smin().sext(destWidth);
-  APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
-}
-
-ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
-                                             unsigned destWidth) {
-  APInt umin = range.umin().zext(destWidth);
-  APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
-ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
-                                             unsigned destWidth) {
-  APInt smin = range.smin().sext(destWidth);
-  APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
-}
-
-ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
-                                             unsigned int destWidth) {
-  // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
-  // the range of the resulting value is not contiguous ind includes 0.
-  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
-  // but you can't truncate [255, 257] similarly.
-  bool hasUnsignedRollover =
-      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
-  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
-                                   : range.umin().trunc(destWidth);
-  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
-                                   : range.umax().trunc(destWidth);
-
-  // Signed post-truncation rollover will not occur when either:
-  // - The high parts of the min and max, plus the sign bit, are the same
-  // - The high halves + sign bit of the min and max are either all 1s or all 0s
-  //  and you won't create a [positive, negative] range by truncating.
-  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
-  // but not [255, 257]_i16 to a range of i8s. You can also truncate
-  // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
-  // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
-  // will truncate to 0x7e, which is greater than 0
-  APInt sminHighPart = range.smin().ashr(destWidth - 1);
-  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
-  bool hasSignedOverflow =
-      (sminHighPart != smaxHighPart) &&
-      !(sminHighPart.isAllOnes() &&
-        (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
-      !(sminHighPart.isZero() && smaxHighPart.isZero());
-  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
-                                 : range.smin().trunc(destWidth);
-  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
-                                 : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
-}
-
-//===----------------------------------------------------------------------===//
-// Addition
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn uadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.uadd_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-  ConstArithFn sadd = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.sadd_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-
-  ConstantIntRanges urange = computeBoundsBy(
-      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
-  ConstantIntRanges srange = computeBoundsBy(
-      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
-}
-
-//===----------------------------------------------------------------------===//
-// Subtraction
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  ConstArithFn usub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.usub_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-  ConstArithFn ssub = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.ssub_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-  ConstantIntRanges urange = computeBoundsBy(
-      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
-  ConstantIntRanges srange = computeBoundsBy(
-      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
-}
-
-//===----------------------------------------------------------------------===//
-// Multiplication
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  ConstArithFn umul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.umul_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-  ConstArithFn smul = [](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-    bool overflowed = false;
-    APInt result = a.smul_ov(b, overflowed);
-    return overflowed ? std::optional<APInt>() : result;
-  };
-
-  ConstantIntRanges urange =
-      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
-               /*isSigned=*/false);
-  ConstantIntRanges srange =
-      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
-               /*isSigned=*/true);
-  return urange.intersection(srange);
-}
-
-//===----------------------------------------------------------------------===//
-// DivU, CeilDivU (Unsigned division)
-//===----------------------------------------------------------------------===//
-
-/// Fix up division results (ex. for ceiling and floor), returning an APInt
-/// if there has been no overflow
-using DivisionFixupFn = function_ref<std::optional<APInt>(
-    const APInt &lhs, const APInt &rhs, const APInt &result)>;
-
-static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
-                                        const ConstantIntRanges &rhs,
-                                        DivisionFixupFn fixup) {
-  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
-              &rhsMax = rhs.umax();
-
-  if (!rhsMin.isZero()) {
-    auto udiv = [&fixup](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-      return fixup(a, b, a.udiv(b));
-    };
-    return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/false);
-  }
-  // Otherwise, it's possible we might divide by 0.
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
-ConstantIntRanges
-mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
-  return inferDivURange(argRanges[0], argRanges[1],
-                        [](const APInt &lhs, const APInt &rhs,
-                           const APInt &result) { return result; });
-}
-
-ConstantIntRanges
-mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  DivisionFixupFn ceilDivUIFix =
-      [](const APInt &lhs, const APInt &rhs,
-         const APInt &result) -> std::optional<APInt> {
-    if (!lhs.urem(rhs).isZero()) {
-      bool overflowed = false;
-      APInt corrected =
-          result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
-      return overflowed ? std::optional<APInt>() : corrected;
-    }
-    return result;
-  };
-  return inferDivURange(lhs, rhs, ceilDivUIFix);
-}
-
-//===----------------------------------------------------------------------===//
-// DivS, CeilDivS, FloorDivS (Signed division)
-//===----------------------------------------------------------------------===//
-
-static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
-                                        const ConstantIntRanges &rhs,
-                                        DivisionFixupFn fixup) {
-  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
-              &rhsMax = rhs.smax();
-  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
-
-  if (canDivide) {
-    auto sdiv = [&fixup](const APInt &a,
-                         const APInt &b) -> std::optional<APInt> {
-      bool overflowed = false;
-      APInt result = a.sdiv_ov(b, overflowed);
-      return overflowed ? std::optional<APInt>() : fixup(a, b, result);
-    };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
-  }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
-ConstantIntRanges
-mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
-  return inferDivSRange(argRanges[0], argRanges[1],
-                        [](const APInt &lhs, const APInt &rhs,
-                           const APInt &result) { return result; });
-}
-
-ConstantIntRanges
-mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  DivisionFixupFn ceilDivSIFix =
-      [](const APInt &lhs, const APInt &rhs,
-         const APInt &result) -> std::optional<APInt> {
-    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
-      bool overflowed = false;
-      APInt corrected =
-          result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
-      return overflowed ? std::optional<APInt>() : corrected;
-    }
-    return result;
-  };
-  return inferDivSRange(lhs, rhs, ceilDivSIFix);
-}
-
-ConstantIntRanges
-mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  DivisionFixupFn floorDivSIFix =
-      [](const APInt &lhs, const APInt &rhs,
-         const APInt &result) -> std::optional<APInt> {
-    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
-      bool overflowed = false;
-      APInt corrected =
-          result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
-      return overflowed ? std::optional<APInt>() : corrected;
-    }
-    return result;
-  };
-  return inferDivSRange(lhs, rhs, floorDivSIFix);
-}
-
-//===----------------------------------------------------------------------===//
-// Signed remainder (RemS)
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
-              &rhsMax = rhs.smax();
-
-  unsigned width = rhsMax.getBitWidth();
-  APInt smin = APInt::getSignedMinValue(width);
-  APInt smax = APInt::getSignedMaxValue(width);
-  // No bounds if zero could be a divisor.
-  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
-  if (canBound) {
-    APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
-    bool canNegativeDividend = lhsMin.isNegative();
-    bool canPositiveDividend = lhsMax.isStrictlyPositive();
-    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
-    APInt maxPositiveResult = maxDivisor - 1;
-    APInt minNegativeResult = -maxPositiveResult;
-    smin = canNegativeDividend ? minNegativeResult : zero;
-    smax = canPositiveDividend ? maxPositiveResult : zero;
-    // Special case: sweeping out a contiguous range in N/[modulus].
-    if (rhsMin == rhsMax) {
-      if ((lhsMax - lhsMin).ult(maxDivisor)) {
-        APInt minRem = lhsMin.srem(maxDivisor);
-        APInt maxRem = lhsMax.srem(maxDivisor);
-        if (minRem.sle(maxRem)) {
-          smin = minRem;
-          smax = maxRem;
-        }
-      }
-    }
-  }
-  return ConstantIntRanges::fromSigned(smin, smax);
-}
-
-//===----------------------------------------------------------------------===//
-// Unsigned remainder (RemU)
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
-
-  unsigned width = rhsMin.getBitWidth();
-  APInt umin = APInt::getZero(width);
-  APInt umax = APInt::getMaxValue(width);
-
-  if (!rhsMin.isZero()) {
-    umax = rhsMax - 1;
-    // Special case: sweeping out a contiguous range in N/[modulus]
-    if (rhsMin == rhsMax) {
-      const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
-      if ((lhsMax - lhsMin).ult(rhsMax)) {
-        APInt minRem = lhsMin.urem(rhsMax);
-        APInt maxRem = lhsMax.urem(rhsMax);
-        if (minRem.ule(maxRem)) {
-          umin = minRem;
-          umax = maxRem;
-        }
-      }
-    }
-  }
-  return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
-//===----------------------------------------------------------------------===//
-// Max and min (MaxS, MaxU, MinS, MinU)
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
-  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
-  return ConstantIntRanges::fromSigned(smin, smax);
-}
-
-ConstantIntRanges
-mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
-  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
-  return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
-ConstantIntRanges
-mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
-  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
-  return ConstantIntRanges::fromSigned(smin, smax);
-}
-
-ConstantIntRanges
-mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
-  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
-  return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
-//===----------------------------------------------------------------------===//
-// Bitwise operators (And, Or, Xor)
-//===----------------------------------------------------------------------===//
-
-/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
-/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
-/// that both bonuds have in common. This gives us a consertive approximation
-/// for what values can be passed to bitwise operations.
-static std::tuple<APInt, APInt>
-widenBitwiseBounds(const ConstantIntRanges &bound) {
-  APInt leftVal = bound.umin(), rightVal = bound.umax();
-  unsigned bitwidth = leftVal.getBitWidth();
-  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
-  leftVal.clearLowBits(differingBits);
-  rightVal.setLowBits(differingBits);
-  return std::make_tuple(std::move(leftVal), std::move(rightVal));
-}
-
-ConstantIntRanges
-mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a & b;
-  };
-  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
-}
-
-ConstantIntRanges
-mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a | b;
-  };
-  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
-}
-
-ConstantIntRanges
-mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a ^ b;
-  };
-  return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Shifts (Shl, ShrS, ShrU)
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges
-mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-  ConstArithFn shl = [](const APInt &l,
-                        const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
-  };
-  ConstantIntRanges urange =
-      minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
-               /*isSigned=*/false);
-  ConstantIntRanges srange =
-      minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
-               /*isSigned=*/true);
-  return urange.intersection(srange);
-}
-
-ConstantIntRanges
-mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  ConstArithFn ashr = [](const APInt &l,
-                         const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
-  };
-
-  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/true);
-}
-
-ConstantIntRanges
-mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  ConstArithFn lshr = [](const APInt &l,
-                         const APInt &r) -> std::optional<APInt> {
-    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
-  };
-  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
-                  /*isSigned=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Comparisons (Cmp)
-//===----------------------------------------------------------------------===//
-
-static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
-  switch (pred) {
-  case intrange::CmpPredicate::eq:
-    return intrange::CmpPredicate::ne;
-  case intrange::CmpPredicate::ne:
-    return intrange::CmpPredicate::eq;
-  case intrange::CmpPredicate::slt:
-    return intrange::CmpPredicate::sge;
-  case intrange::CmpPredicate::sle:
-    return intrange::CmpPredicate::sgt;
-  case intrange::CmpPredicate::sgt:
-    return intrange::CmpPredicate::sle;
-  case intrange::CmpPredicate::sge:
-    return intrange::CmpPredicate::slt;
-  case intrange::CmpPredicate::ult:
-    return intrange::CmpPredicate::uge;
-  case intrange::CmpPredicate::ule:
-    return intrange::CmpPredicate::ugt;
-  case intrange::CmpPredicate::ugt:
-    return intrange::CmpPredicate::ule;
-  case intrange::CmpPredicate::uge:
-    return intrange::CmpPredicate::ult;
-  }
-  llvm_unreachable("unknown cmp predicate value");
-}
-
-static bool isStaticallyTrue(intrange::CmpPredicate pred,
-                             const ConstantIntRanges &lhs,
-                             const ConstantIntRanges &rhs) {
-  switch (pred) {
-  case intrange::CmpPredicate::sle:
-    return lhs.smax().sle(rhs.smin());
-  case intrange::CmpPredicate::slt:
-    return lhs.smax().slt(rhs.smin());
-  case intrange::CmpPredicate::ule:
-    return lhs.umax().ule(rhs.umin());
-  case intrange::CmpPredicate::ult:
-    return lhs.umax().ult(rhs.umin());
-  case intrange::CmpPredicate::sge:
-    return lhs.smin().sge(rhs.smax());
-  case intrange::CmpPredicate::sgt:
-    return lhs.smin().sgt(rhs.smax());
-  case intrange::CmpPredicate::uge:
-    return lhs.umin().uge(rhs.umax());
-  case intrange::CmpPredicate::ugt:
-    return lhs.umin().ugt(rhs.umax());
-  case intrange::CmpPredicate::eq: {
-    std::optional<APInt> lhsConst = lhs.getConstantValue();
-    std::optional<APInt> rhsConst = rhs.getConstantValue();
-    return lhsConst && rhsConst && lhsConst == rhsConst;
-  }
-  case intrange::CmpPredicate::ne: {
-    // While equality requires that there is an interpration of the preceeding
-    // computations that produces equal constants, whether that be signed or
-    // unsigned, statically determining inequality requires that neither
-    // interpretation produce potentially overlapping ranges.
-    bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
-               isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
-    bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
-               isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
-    return sne && une;
-  }
-  }
-  return false;
-}
-
-std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
-                                                 const ConstantIntRanges &lhs,
-                                                 const ConstantIntRanges &rhs) {
-  if (isStaticallyTrue(pred, lhs, rhs))
-    return true;
-  if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
-    return false;
-  return std::nullopt;
-}

diff  --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
deleted file mode 100644
index 2784d5fd5cf70..0000000000000
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ /dev/null
@@ -1,66 +0,0 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
-
-// Most operations are covered by the `arith` tests, which use the same code
-// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
-// code is operating as expected.
-
-// CHECK-LABEL: func @add_same_for_both
-// CHECK: %[[true:.*]] = index.bool.constant true
-// CHECK: return %[[true]]
-func.func @add_same_for_both(%arg0 : index) -> i1 {
-  %c1 = index.constant 1
-  %calmostBig = index.constant 0xfffffffe
-  %0 = index.minu %arg0, %calmostBig
-  %1 = index.add %0, %c1
-  %2 = index.cmp uge(%1, %c1)
-  func.return %2 : i1
-}
-
-// CHECK-LABEL: func @add_unsigned_ov
-// CHECK: %[[uge:.*]] = index.cmp uge
-// CHECK: return %[[uge]]
-func.func @add_unsigned_ov(%arg0 : index) -> i1 {
-  %c1 = index.constant 1
-  %cu32_max = index.constant 0xffffffff
-  %0 = index.minu %arg0, %cu32_max
-  %1 = index.add %0, %c1
-  // On 32-bit, the add could wrap, so the result doesn't have to be >= 1
-  %2 = index.cmp uge(%1, %c1)
-  func.return %2 : i1
-}
-
-// CHECK-LABEL: func @add_signed_ov
-// CHECK: %[[sge:.*]] = index.cmp sge
-// CHECK: return %[[sge]]
-func.func @add_signed_ov(%arg0 : index) -> i1 {
-  %c0 = index.constant 0
-  %c1 = index.constant 1
-  %ci32_max = index.constant 0x7fffffff
-  %0 = index.minu %arg0, %ci32_max
-  %1 = index.add %0, %c1
-  // On 32-bit, the add could wrap, so the result doesn't have to be positive
-  %2 = index.cmp sge(%1, %c0)
-  func.return %2 : i1
-}
-
-// CHECK-LABEL: func @add_big
-// CHECK: %[[true:.*]] = index.bool.constant true
-// CHECK: return %[[true]]
-func.func @add_big(%arg0 : index) -> i1 {
-  %c1 = index.constant 1
-  %cmin = index.constant 0x300000000
-  %cmax = index.constant 0x30000ffff
-  // Note: the order of the clamps matters.
-  // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff]
-  // and then [0x30...0000, 0x30...ffff]
-  // If you switch the order of the below operations, you instead first infer
-  // the range [0,0x3...ffff]. Then, the min inference can't constraint
-  // this intermediate, since in the 32-bit case we could have, for example
-  // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff
-  // which means we can't do any inference.
-  %0 = index.maxu %arg0, %cmin
-  %1 = index.minu %0, %cmax
-  %2 = index.add %1, %c1
-  %3 = index.cmp uge(%1, %cmin)
-  func.return %3 : i1
-}


        


More information about the Mlir-commits mailing list