[Mlir-commits] [mlir] [mlir][dataflow] Fix for integer range analysis propagation bug (PR #93199)

Spenser Bauman llvmlistbot at llvm.org
Sat May 25 12:59:08 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/93199

>From 377db1af51d8b92053467f03974b4a9d823fac54 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Thu, 23 May 2024 08:18:55 -0400
Subject: [PATCH 1/4] [mlir][dataflow] Fix for integer range analysis
 propagation bug

Integer range analysis does not update the range of an operation's
results when any of the inferred operand lattices is uninitialized.
Under the current behavior, the lattice values for all non-integer
types are uninitialized.
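The bail-out rule can be modeled in a few lines. The sketch below is a hypothetical toy, not MLIR code. `Interval`, `Lattice`, and `visitComparison` are invented names, and `std::nullopt` stands in for an uninitialized lattice value. The transfer function models an i1-valued comparison such as arith.cmpf, whose result always lies in [0, 1] no matter what its operands are.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy lattice value: an unsigned interval [lo, hi].
struct Interval {
  uint64_t lo;
  uint64_t hi;
};

// std::nullopt models an uninitialized lattice value.
using Lattice = std::optional<Interval>;

// Models the pre-fix rule for an i1-producing comparison (e.g. arith.cmpf).
// If any operand lattice is uninitialized, the result is left uninitialized,
// even though the result is always in [0, 1] regardless of the operands.
Lattice visitComparison(const std::vector<Lattice> &operands) {
  for (const Lattice &operand : operands)
    if (!operand)
      return std::nullopt;
  return Interval{0, 1};
}
```

Because f32 operands never receive an initialized lattice in this model, `visitComparison` returns `std::nullopt` for them. Any consumer that applies the same rule stays uninitialized as well.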

For operations like arith.cmpf

```mlir
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
```

the range of the result is also left uninitialized, and the same holds
for every transitive consumer of the arith.cmpf result. When
control-flow ops are involved, this lack of propagation produces
incorrect ranges, because the back edges for loop-carried values are
never joined with the definitions from the body region.

For example, consider an scf.while loop whose body region produces a
value that depends, through an arith.cmpf operation, on some
floating-point values:

```mlir
func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) {
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index

  %3 = arith.cmpf ugt, %arg0, %arg1 : f32

  %1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) {
    %2 = arith.cmpi ult, %arg2, %c4 : index
    scf.condition(%2) %arg2, %arg3 : index, index
  } do {
  ^bb0(%arg2: index, %arg3: index):
    %4 = arith.select %3, %arg3, %arg3 : index
    %5 = arith.addi %arg2, %c1 : index
    scf.yield %5, %4 : index, index
  }

  return %1#0, %1#1 : index, index
}
```

Under the existing behavior, the loop condition %2 is optimized to
true, turning the while loop into an infinite loop. The update to %arg2
in the body region is never factored into the range calculation,
because the ranges of the body ops all remain uninitialized.
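The following simplified C++ model illustrates why a missing back-edge contribution makes %2 fold to true. It is not the MLIR analysis. The names `join`, `loopCarriedRange`, and `foldUlt` are hypothetical. The model also ignores that the body only runs when %2 holds. The final widening step stands in for the analysis's conservative [-inf, inf] treatment of loop-variant bounds.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

struct Interval {
  uint64_t lo;
  uint64_t hi;
};
using Lattice = std::optional<Interval>; // std::nullopt == uninitialized

// Least upper bound; an uninitialized side contributes nothing.
Lattice join(const Lattice &a, const Lattice &b) {
  if (!a)
    return b;
  if (!b)
    return a;
  return Interval{std::min(a->lo, b->lo), std::max(a->hi, b->hi)};
}

// Range of the loop-carried %arg2 after iterating toward a fixed point.
// `bodyPropagates` models whether the back edge (%5 = %arg2 + 1) ever
// delivers an initialized range.
Interval loopCarriedRange(bool bodyPropagates) {
  Lattice arg2 = Interval{0, 0}; // entry value from %c0
  for (int iter = 0; iter < 8; ++iter) {
    Lattice yielded; // uninitialized unless the body propagates
    if (bodyPropagates)
      yielded = Interval{arg2->lo + 1, arg2->hi + 1};
    Lattice next = join(arg2, yielded);
    if (next->lo == arg2->lo && next->hi == arg2->hi)
      return *arg2; // fixed point reached
    arg2 = next;
  }
  // Hypothetical widening: the bounds keep growing, so fall back to the
  // full unsigned range.
  return Interval{0, UINT64_MAX};
}

// Folds `%2 = arith.cmpi ult, %arg2, %c4` when the range decides it.
std::optional<bool> foldUlt(Interval lhs, uint64_t rhs) {
  if (lhs.hi < rhs)
    return true;
  if (lhs.lo >= rhs)
    return false;
  return std::nullopt; // undecidable from the range alone
}
```

Without the back-edge contribution, %arg2 stays at [0, 0], so `foldUlt` reports the condition as always true. Once the body's update is joined in, the range grows and the comparison can no longer be folded.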

This change gives every value initialized via setToEntryState an
initialized range, even when the value is not an integer.
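The diff achieves this by dropping the `width == 0` early return from `IntegerValueRange::getMaxRange`. The sketch below models that one change with a hypothetical `ToyRanges` struct in place of `ConstantIntRanges`. It assumes, as the `width == 0` check suggests, that `ConstantIntRanges::getStorageBitwidth` reports 0 for non-integer types such as f32. Only unsigned bounds are modeled; the real code builds `APInt` bounds and guards the signed ones with `width != 0 ? ... : umin`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for ConstantIntRanges: unsigned bounds plus a width.
struct ToyRanges {
  unsigned width;
  uint64_t umin;
  uint64_t umax;
};

// Largest unsigned value representable in `width` bits (width <= 64).
// A zero-width value can only hold 0.
uint64_t maxUnsigned(unsigned width) {
  if (width == 0)
    return 0;
  if (width >= 64)
    return UINT64_MAX;
  return (uint64_t{1} << width) - 1;
}

// Before the patch: non-integer types (storage width 0) got no range at all.
std::optional<ToyRanges> maxRangeBefore(unsigned width) {
  if (width == 0)
    return std::nullopt;
  return ToyRanges{width, 0, maxUnsigned(width)};
}

// After the patch: every value, including non-integers, gets an initialized
// maximal range (a degenerate one when the width is 0).
std::optional<ToyRanges> maxRangeAfter(unsigned width) {
  return ToyRanges{width, 0, maxUnsigned(width)};
}
```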
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.cpp |  2 --
 .../Dialect/Arith/int-range-interface.mlir     | 18 ++++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a82c30717e275..b69b2e0416209 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -38,8 +38,6 @@ using namespace mlir::dataflow;
 
 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
-  if (width == 0)
-    return {};
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 5b538197a0c11..fdeb8a2e6c935 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -899,3 +899,21 @@ func.func @test_shl_i8_nowrap() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+/// A test case to ensure that the ranges for unsupported ops are initialized
+/// properly to maxRange, rather than left uninitialized.
+/// In this test case, the previous behavior would leave the ranges for %a and
+/// %b uninitialized, resulting in arith.cmpf's range not being updated, even
+/// though it has an integer valued result.
+
+// CHECK-LABEL: func @test_cmpf_propagates
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @test_cmpf_propagates(%a: f32, %b: f32) -> index {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+
+  %0 = arith.cmpf ueq, %a, %b : f32
+  %1 = arith.select %0, %c1, %c2 : index
+  %2 = test.reflect_bounds %1 : index
+  func.return %2 : index
+}

>From 7410f331bef2c7c968a9d6d29be5b07cfde61aaf Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Thu, 23 May 2024 15:21:25 -0400
Subject: [PATCH 2/4] Rework integer range analysis interfaces

Modify the integer range analysis interfaces to handle uninitialized
values by allowing the inferred input ranges to be optional.
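This patch declares `inferFromOptionals(intrange::InferRangeFn)`, which returns an `OptionalRangeFn`, but its body in InferIntRangeCommon.cpp is not shown here. The following is a minimal sketch of the lifting idea. It assumes the helper yields an uninitialized result whenever any input is uninitialized. `Interval`, `liftToOptionals`, and `inferAddToy` are hypothetical names, and a toy interval replaces `ConstantIntRanges`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

struct Interval {
  uint64_t lo;
  uint64_t hi;
};
using OptionalInterval = std::optional<Interval>;
using InferFn = std::function<Interval(const std::vector<Interval> &)>;
using OptionalInferFn =
    std::function<OptionalInterval(const std::vector<OptionalInterval> &)>;

// Lifts a transfer function over known ranges to one over optional ranges.
// Any uninitialized input makes the result uninitialized.
OptionalInferFn liftToOptionals(InferFn inferFn) {
  return [inferFn](const std::vector<OptionalInterval> &args)
             -> OptionalInterval {
    std::vector<Interval> known;
    known.reserve(args.size());
    for (const OptionalInterval &arg : args) {
      if (!arg)
        return std::nullopt;
      known.push_back(*arg);
    }
    return inferFn(known);
  };
}

// Toy unsigned add with no overflow handling (assumes small inputs).
Interval inferAddToy(const std::vector<Interval> &r) {
  return {r[0].lo + r[1].lo, r[0].hi + r[1].hi};
}
```

In the analysis itself, when an op passes `std::nullopt` to the `setResultRanges` callback, the result is joined with `IntegerValueRange::getMaxRange(v)`. An unknown input therefore still produces an initialized, conservative range downstream.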
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.h  |   2 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  |   3 +-
 .../mlir/Interfaces/InferIntRangeInterface.td |   2 +-
 .../Interfaces/Utils/InferIntRangeCommon.h    |   7 +-
 .../DataFlow/IntegerRangeAnalysis.cpp         |  45 +--
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  | 167 ++++++-----
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    |  32 ++-
 .../Index/IR/InferIntRangeInterfaceImpls.cpp  | 265 ++++++++++++------
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  17 ++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  31 +-
 10 files changed, 366 insertions(+), 205 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 8bd7cf880c6af..fb07013041c0e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -33,7 +33,7 @@ class IntegerValueRange {
   static IntegerValueRange getMaxRange(Value value);
 
   /// Create an integer value range lattice value.
-  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+  IntegerValueRange(OptionalIntRanges value = std::nullopt)
       : value(std::move(value)) {}
 
   /// Whether the range is uninitialized. This happens when the state hasn't
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 05064a72ef02e..3d499b420eadd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,10 +105,11 @@ class ConstantIntRanges {
 
 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
 
+using OptionalIntRanges = std::optional<ConstantIntRanges>;
 /// The type of the `setResultRanges` callback provided to ops implementing
 /// InferIntRangeInterface. It should be called once for each integer result
 /// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
 } // end namespace mlir
 
 #include "mlir/Interfaces/InferIntRangeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index dbdc526c6f10b..f8e2c98d87cdb 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
        APInts in their `argRanges` element.
     }],
     "void", "inferResultRanges", (ins
-      "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+      "::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
       "::mlir::SetIntRangeFn":$setResultRanges)
   >];
 }
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 851bb534bc7ee..9e3b04535dcab 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -25,7 +25,10 @@ namespace intrange {
 /// abstracted away here to permit writing the function that handles both
 /// 64- and 32-bit index types.
 using InferRangeFn =
-    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+    std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+using OptionalRangeFn =
+    std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
 
 static constexpr unsigned indexMinWidth = 32;
 static constexpr unsigned indexMaxWidth = 64;
@@ -44,6 +47,8 @@ enum class OverflowFlags : uint32_t {
 using InferRangeWithOvfFlagsFn =
     function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
 
+OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b69b2e0416209..622d875a63ace 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,8 +36,26 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
+namespace {
+
+OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
+  if (range.isUninitialized())
+    return std::nullopt;
+  return range.getValue();
+}
+
+OptionalIntRanges
+getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
+  return getOptionalRange(lattice->getValue());
+}
+
+} // end namespace
+
 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+  if (width == 0)
+    return {};
+
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
@@ -71,23 +89,14 @@ void IntegerRangeAnalysis::visitOperation(
     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
     ArrayRef<IntegerValueRangeLattice *> results) {
   // If the lattice on any operand is unitialized, bail out.
-  if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
-        return lattice->getValue().isUninitialized();
-      })) {
-    return;
-  }
-
   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
   if (!inferrable)
     return setAllToEntryStates(results);
 
   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
-  SmallVector<ConstantIntRanges> argRanges(
-      llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
-        return val->getValue().getValue();
-      }));
+  auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
 
-  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+  auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
     auto result = dyn_cast<OpResult>(v);
     if (!result)
       return;
@@ -97,7 +106,9 @@ void IntegerRangeAnalysis::visitOperation(
     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
     IntegerValueRange oldRange = lattice->getValue();
 
-    ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+    ChangeResult changed =
+        attrs ? lattice->join(IntegerValueRange{attrs})
+              : lattice->join(IntegerValueRange::getMaxRange(v));
 
     // Catch loop results with loop variant bounds and conservatively make
     // them [-inf, inf] so we don't circle around infinitely often (because
@@ -127,12 +138,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
           return getLatticeElementFor(op, value)->getValue().isUninitialized();
         }))
       return;
-    SmallVector<ConstantIntRanges> argRanges(
+    SmallVector<OptionalIntRanges> argRanges(
         llvm::map_range(op->getOperands(), [&](Value value) {
-          return getLatticeElementFor(op, value)->getValue().getValue();
+          return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
         }));
 
-    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+    auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
       auto arg = dyn_cast<BlockArgument>(v);
       if (!arg)
         return;
@@ -143,7 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
       IntegerValueRange oldRange = lattice->getValue();
 
-      ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+      ChangeResult changed =
+          attrs ? lattice->join(IntegerValueRange{attrs})
+                : lattice->join(IntegerValueRange::getMaxRange(v));
 
       // Catch loop results with loop variant bounds and conservatively make
       // them [-inf, inf] so we don't circle around infinitely often (because
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index fbe2ecab8adca..b59e5f9ec5a3e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
-#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -33,7 +32,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                           SetIntRangeFn setResultRange) {
   auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
   if (constAttr) {
@@ -46,48 +45,57 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -95,8 +103,8 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivUIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivU(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferFromOptionals(inferCeilDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -104,8 +112,8 @@ void arith::CeilDivUIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivSIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivS(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferFromOptionals(inferCeilDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,122 +121,132 @@ void arith::CeilDivSIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::FloorDivSIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  return setResultRange(getResult(), inferFloorDivS(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  return setResultRange(getResult(),
+                        inferFromOptionals(inferFloorDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferRemU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferRemS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAnd(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferAnd)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferOr(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferOr)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferXor(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferXor)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMaxS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMaxU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMinS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMinU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+  setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+  setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+  setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
@@ -236,18 +254,21 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
   else
-    setResultRange(getResult(), argRanges[0]);
+    setResultRange(getResult(), *argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,34 +276,40 @@ void arith::IndexCastOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastUIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
   else
-    setResultRange(getResult(), argRanges[0]);
+    setResultRange(getResult(), *argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
   arith::CmpIPredicate arithPred = getPredicate();
   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const OptionalIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  if (!lhs || !rhs)
+    return;
 
   APInt min = APInt::getZero(1);
   APInt max = APInt::getAllOnes(1);
 
-  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
+  std::optional<bool> truthValue = intrange::evaluatePred(pred, *lhs, *rhs);
   if (truthValue.has_value() && *truthValue)
     min = max;
   else if (truthValue.has_value() && !(*truthValue))
@@ -295,9 +322,10 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SelectOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
-  std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+  std::optional<APInt> mbCondVal =
+      argRanges[0] ? argRanges[0]->getConstantValue() : std::nullopt;
 
   if (mbCondVal) {
     if (mbCondVal->isZero())
@@ -306,33 +334,40 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
       setResultRange(getResult(), argRanges[1]);
     return;
   }
-  setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
+
+  if (argRanges[1] && argRanges[2])
+    setResultRange(getResult(), argRanges[1]->rangeUnion(*argRanges[2]));
 }
 
 //===----------------------------------------------------------------------===//
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrU(argRanges));
+  auto infer = inferFromOptionals(inferShrU);
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrS(argRanges));
+  auto infer = inferFromOptionals(inferShrS);
+  setResultRange(getResult(), infer(argRanges));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 69017efb9a0e6..1342271029fa9 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
-void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
 }
 
-void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                     SetIntRangeFn setResultRange) {
   uint64_t max = kMaxClusterDim;
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
       getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                   SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                   SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
   if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ThreadIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void LaneIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                  SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
 }
 
-void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
 }
 
-void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t blockDimMax =
       getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,24 +146,26 @@ void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
 }
 
-void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
 }
 
-void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
-  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
+  auto setRange = [&](const OptionalIntRanges &argRange, Value dimResult,
                       Value idxResult) {
-    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
+    if (!argRange ||
+        argRange->umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
       return;
+
     ConstantIntRanges dimRange =
-        argRange.intersection(getIndexRange(1, kMaxDim));
+        argRange->intersection(getIndexRange(1, kMaxDim));
     setResultRange(dimResult, dimRange);
     ConstantIntRanges idxRange =
         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index 64adb6b850524..cc6709f1253da 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
-#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -23,13 +22,13 @@ using namespace mlir::intrange;
 // Constants
 //===----------------------------------------------------------------------===//
 
-void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
   const APInt &value = getValue();
   setResultRange(getResult(), ConstantIntRanges::constant(value));
 }
 
-void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
   bool value = getValue();
   APInt asInt(/*numBits=*/1, value);
@@ -49,129 +48,195 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // the inference function without any `OverflowFlags`.
 static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
-  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+  return [inferWithOvfFn](
+             ArrayRef<ConstantIntRanges> argRanges) -> ConstantIntRanges {
     return inferWithOvfFn(argRanges, OverflowFlags::None);
   };
 }
 
-void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AddOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SubOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MulOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferDivU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                     SetIntRangeFn setResultRange) {
-  return setResultRange(
-      getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
+  });
+
+  return setResultRange(getResult(), infer(argRanges));
 }
 
-void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AndOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void OrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -208,56 +273,70 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
   return ret;
 }
 
-void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/true));
+
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/true);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/false));
+
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/false);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//
 
-void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  index::IndexCmpPredicate indexPred = getPred();
-  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  APInt min = APInt::getZero(1);
-  APInt max = APInt::getAllOnes(1);
-
-  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
-  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
-                    rhsTrunc = truncRange(rhs, indexMinWidth);
-  std::optional<bool> truthValue32 =
-      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
-  if (truthValue64 == truthValue32) {
-    if (truthValue64.has_value() && *truthValue64)
-      min = max;
-    else if (truthValue64.has_value() && !(*truthValue64))
-      max = min;
-  }
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    index::IndexCmpPredicate indexPred = getPred();
+    intrange::CmpPredicate pred =
+        static_cast<intrange::CmpPredicate>(indexPred);
+    const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
+
+    APInt min = APInt::getZero(1);
+    APInt max = APInt::getAllOnes(1);
+
+    std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
+
+    ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+                      rhsTrunc = truncRange(rhs, indexMinWidth);
+    std::optional<bool> truthValue32 =
+        intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+    if (truthValue64 == truthValue32) {
+      if (truthValue64.has_value() && *truthValue64)
+        min = max;
+      else if (truthValue64.has_value() && !(*truthValue64))
+        max = min;
+    }
+
+    return ConstantIntRanges::fromUnsigned(min, max);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // SizeOf, which is bounded between the two supported bitwidth (32 and 64).
 //===----------------------------------------------------------------------===//
 
-void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
   unsigned storageWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index fe1a67d628738..78754680ae58d 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,6 +36,23 @@ using namespace mlir;
 using ConstArithFn =
     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
 
+std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>
+mlir::intrange::inferFromOptionals(intrange::InferRangeFn inferFn) {
+  return [inferFn = std::move(inferFn)](
+             ArrayRef<OptionalIntRanges> args) -> OptionalIntRanges {
+    llvm::SmallVector<ConstantIntRanges> unpacked;
+    unpacked.reserve(args.size());
+
+    for (const OptionalIntRanges &arg : args) {
+      if (!arg)
+        return std::nullopt;
+      unpacked.push_back(*arg);
+    }
+
+    return inferFn(unpacked);
+  };
+}
+
 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
 /// If either computation overflows, make the result unbounded.
 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b058a8e1abbcb..145b076c95a76 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,9 +648,10 @@ LogicalResult TestVerifiersOp::verifyRegions() {
 //===----------------------------------------------------------------------===//
 // TestWithBoundsOp
 
-void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                          SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
+  setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
+                                                 getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
@@ -681,29 +682,37 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
 }
 
 void TestWithBoundsRegionOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
   Value arg = getRegion().getArgument(0);
-  setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
+  setResultRanges(
+      arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
 // TestIncrementOp
 
-void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
-  const ConstantIntRanges &range = argRanges[0];
+  if (!argRanges[0])
+    return;
+
+  const ConstantIntRanges &range = *argRanges[0];
   APInt one(range.umin().getBitWidth(), 1);
-  setResultRanges(getResult(),
-                  {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
-                   range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
+  setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
+                                                 range.umax().uadd_sat(one),
+                                                 range.smin().sadd_sat(one),
+                                                 range.smax().sadd_sat(one)});
 }
 
 //===----------------------------------------------------------------------===//
 // TestReflectBoundsOp
 
 void TestReflectBoundsOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
-  const ConstantIntRanges &range = argRanges[0];
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  if (!argRanges[0])
+    return;
+
+  const ConstantIntRanges &range = *argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
   Type sIntTy, uIntTy;

>From b354130d207aadb58f28467f2c485180e6acbbdd Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Fri, 24 May 2024 08:32:14 -0400
Subject: [PATCH 3/4] Convert uses of OptionalIntRange to IntegerValueRange

IntegerValueRange already exists and encodes the exact information that
we want to represent with OptionalIntRanges. Using it makes the APIs
clearer than passing a std::optional everywhere.
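The semantics this change relies on can be illustrated with the minimal, self-contained C++ sketch below. An empty `std::optional` plays the role of an uninitialized lattice value, `join` treats it as the identity, and a lifting helper (analogous to the one that wraps a `ConstantIntRanges` inference function) leaves the result uninitialized whenever any operand is uninitialized. `Range`, `ValueRange`, `inferFromValueRanges`, and `addRanges` are hypothetical, simplified stand-ins that track only a signed interval. They are not the MLIR classes or functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for ConstantIntRanges: a signed interval [smin, smax].
struct Range {
  int64_t smin, smax;
  bool operator==(const Range &o) const {
    return smin == o.smin && smax == o.smax;
  }
};

// Simplified model of IntegerValueRange: an empty optional means
// "uninitialized" (not yet reached by the analysis).
class ValueRange {
public:
  ValueRange(Range r) : value(r) {}
  ValueRange(std::optional<Range> r = std::nullopt) : value(r) {}

  bool isUninitialized() const { return !value.has_value(); }
  const Range &getValue() const {
    assert(!isUninitialized());
    return *value;
  }

  // Least upper bound: an uninitialized side is the identity element.
  static ValueRange join(const ValueRange &lhs, const ValueRange &rhs) {
    if (lhs.isUninitialized())
      return rhs;
    if (rhs.isUninitialized())
      return lhs;
    return Range{std::min(lhs.getValue().smin, rhs.getValue().smin),
                 std::max(lhs.getValue().smax, rhs.getValue().smax)};
  }

private:
  std::optional<Range> value;
};

using InferFn = std::function<Range(const std::vector<Range> &)>;
using InferValueRangeFn =
    std::function<ValueRange(const std::vector<ValueRange> &)>;

// Hypothetical lifter: if any operand is uninitialized, the result stays
// uninitialized; otherwise unwrap the operands and run the inference.
InferValueRangeFn inferFromValueRanges(InferFn inferFn) {
  return [inferFn = std::move(inferFn)](
             const std::vector<ValueRange> &args) -> ValueRange {
    std::vector<Range> unpacked;
    unpacked.reserve(args.size());
    for (const ValueRange &arg : args) {
      if (arg.isUninitialized())
        return ValueRange{};
      unpacked.push_back(arg.getValue());
    }
    return inferFn(unpacked);
  };
}

// Interval addition on two operands, ignoring overflow for brevity.
Range addRanges(const std::vector<Range> &r) {
  return Range{r[0].smin + r[1].smin, r[0].smax + r[1].smax};
}
```

With this model, `inferFromValueRanges(addRanges)` applied to `[0, 3]` and `[1, 1]` yields `[1, 4]`, while any uninitialized operand yields an uninitialized result. Joining that uninitialized value with a known range then simply returns the known range, which is the behavior the loop back-edge example above depends on.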
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.h  |  45 ---
 .../mlir/Interfaces/InferIntRangeInterface.h  |  53 +++-
 .../mlir/Interfaces/InferIntRangeInterface.td |   2 +-
 .../Interfaces/Utils/InferIntRangeCommon.h    |  13 +-
 .../DataFlow/IntegerRangeAnalysis.cpp         |  57 +---
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  | 167 ++++++-----
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    |  35 ++-
 .../Index/IR/InferIntRangeInterfaceImpls.cpp  | 278 ++++++++++--------
 .../lib/Interfaces/InferIntRangeInterface.cpp |  17 ++
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  18 +-
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  16 +-
 11 files changed, 367 insertions(+), 334 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index fb07013041c0e..191c023fb642c 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -24,51 +24,6 @@
 namespace mlir {
 namespace dataflow {
 
-/// This lattice value represents the integer range of an SSA value.
-class IntegerValueRange {
-public:
-  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
-  /// range that is used to mark the value as unable to be analyzed further,
-  /// where `t` is the type of `value`.
-  static IntegerValueRange getMaxRange(Value value);
-
-  /// Create an integer value range lattice value.
-  IntegerValueRange(OptionalIntRanges value = std::nullopt)
-      : value(std::move(value)) {}
-
-  /// Whether the range is uninitialized. This happens when the state hasn't
-  /// been set during the analysis.
-  bool isUninitialized() const { return !value.has_value(); }
-
-  /// Get the known integer value range.
-  const ConstantIntRanges &getValue() const {
-    assert(!isUninitialized());
-    return *value;
-  }
-
-  /// Compare two ranges.
-  bool operator==(const IntegerValueRange &rhs) const {
-    return value == rhs.value;
-  }
-
-  /// Take the union of two ranges.
-  static IntegerValueRange join(const IntegerValueRange &lhs,
-                                const IntegerValueRange &rhs) {
-    if (lhs.isUninitialized())
-      return rhs;
-    if (rhs.isUninitialized())
-      return lhs;
-    return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
-  }
-
-  /// Print the integer value range.
-  void print(raw_ostream &os) const { os << value; }
-
-private:
-  /// The known integer value range.
-  std::optional<ConstantIntRanges> value;
-};
-
 /// This lattice element represents the integer value range of an SSA value.
 /// When this lattice is updated, it automatically updates the constant value
 /// of the SSA value (if the range can be narrowed to one).
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 3d499b420eadd..73013837f1227 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,11 +105,60 @@ class ConstantIntRanges {
 
 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
 
-using OptionalIntRanges = std::optional<ConstantIntRanges>;
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
+public:
+  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+  /// range that is used to mark the value as unable to be analyzed further,
+  /// where `t` is the type of `value`.
+  static IntegerValueRange getMaxRange(Value value);
+
+  /// Create an integer value range lattice value.
+  IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+  /// Create an integer value range lattice value.
+  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+      : value(std::move(value)) {}
+
+  /// Whether the range is uninitialized. This happens when the state hasn't
+  /// been set during the analysis.
+  bool isUninitialized() const { return !value.has_value(); }
+
+  /// Get the known integer value range.
+  const ConstantIntRanges &getValue() const {
+    assert(!isUninitialized());
+    return *value;
+  }
+
+  /// Compare two ranges.
+  bool operator==(const IntegerValueRange &rhs) const {
+    return value == rhs.value;
+  }
+
+  /// Compute the least upper bound of two ranges.
+  static IntegerValueRange join(const IntegerValueRange &lhs,
+                                const IntegerValueRange &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+    return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
+  }
+
+  /// Print the integer value range.
+  void print(raw_ostream &os) const { os << value; }
+
+private:
+  /// The known integer value range.
+  std::optional<ConstantIntRanges> value;
+};
+
+raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
+
 /// The type of the `setResultRanges` callback provided to ops implementing
 /// InferIntRangeInterface. It should be called once for each integer result
 /// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
+using SetIntRangeFn = function_ref<void(Value, const IntegerValueRange &)>;
 } // end namespace mlir
 
 #include "mlir/Interfaces/InferIntRangeInterface.h.inc"
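The `IntegerValueRange::join` added above treats an uninitialized range as the bottom element of the lattice: joining with it returns the other operand unchanged, and only two initialized ranges are combined with `rangeUnion`. The standalone sketch below models that behavior outside MLIR. `ToyRange` (a single signed interval) and `ToyValueRange` are hypothetical stand-ins for `ConstantIntRanges` and `IntegerValueRange`, not MLIR types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for ConstantIntRanges: a single signed interval.
struct ToyRange {
  int64_t smin, smax;
  ToyRange rangeUnion(const ToyRange &o) const {
    return {std::min(smin, o.smin), std::max(smax, o.smax)};
  }
  bool operator==(const ToyRange &o) const {
    return smin == o.smin && smax == o.smax;
  }
};

// Mirrors the shape of IntegerValueRange: std::nullopt means "uninitialized",
// i.e. the bottom element of the lattice.
class ToyValueRange {
public:
  ToyValueRange(std::optional<ToyRange> v = std::nullopt) : value(v) {}
  bool isUninitialized() const { return !value.has_value(); }
  const ToyRange &getValue() const {
    assert(!isUninitialized());
    return *value;
  }
  bool operator==(const ToyValueRange &o) const { return value == o.value; }

  // Least upper bound: bottom is the identity, otherwise take the union.
  static ToyValueRange join(const ToyValueRange &lhs,
                            const ToyValueRange &rhs) {
    if (lhs.isUninitialized())
      return rhs;
    if (rhs.isUninitialized())
      return lhs;
    return ToyValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
  }

private:
  std::optional<ToyRange> value;
};
```

Because bottom is the identity of `join`, a loop-carried value whose back edge is still uninitialized keeps the range of its initial definition instead of blocking propagation.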
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index f8e2c98d87cdb..795e67b8431bd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
        APInts in their `argRanges` element.
     }],
     "void", "inferResultRanges", (ins
-      "::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
+      "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
       "::mlir::SetIntRangeFn":$setResultRanges)
   >];
 }
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 9e3b04535dcab..8746a1cfba85c 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -27,8 +27,9 @@ namespace intrange {
 using InferRangeFn =
     std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
 
-using OptionalRangeFn =
-    std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
+/// Function that performs inference on an array of `IntegerValueRange`.
+using InferIntegerValueRangeFn =
+    std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
 
 static constexpr unsigned indexMinWidth = 32;
 static constexpr unsigned indexMaxWidth = 64;
@@ -47,7 +48,11 @@ enum class OverflowFlags : uint32_t {
 using InferRangeWithOvfFlagsFn =
     function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
 
-OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
+/// Perform a pointwise extension of a function operating on `ConstantIntRanges`
+/// to a function operating on `IntegerValueRange` such that uninitialized input
+/// ranges propagate.
+InferIntegerValueRangeFn
+inferFromIntegerValueRange(intrange::InferRangeFn inferFn);
 
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
@@ -57,7 +62,7 @@ OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
 ///
 /// The `mode` argument specifies if the unsigned, signed, or both results of
 /// the inference computation should be used when comparing the results.
-ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
+ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
                                ArrayRef<ConstantIntRanges> argRanges,
                                CmpMode mode);
 
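The header comment for `inferFromIntegerValueRange` describes a pointwise extension: an inference function over known ranges is lifted to one over possibly-uninitialized ranges, and an uninitialized input yields an uninitialized result. Its body lives in `InferIntRangeCommon.cpp`, which is not part of this excerpt, so the sketch below only assumes the semantics stated in that comment. `ToyRange`, `ToyValueRange`, `liftInfer`, and `inferToyAdd` are hypothetical names, and `std::optional`/`std::vector` stand in for `IntegerValueRange`/`ArrayRef`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for ConstantIntRanges: a single signed interval.
struct ToyRange {
  int64_t smin, smax;
};

// std::nullopt plays the role of an uninitialized IntegerValueRange.
using ToyValueRange = std::optional<ToyRange>;
using ToyInferFn = std::function<ToyRange(const std::vector<ToyRange> &)>;
using ToyLiftedFn =
    std::function<ToyValueRange(const std::vector<ToyValueRange> &)>;

// Lift a function over known ranges to one over possibly-uninitialized
// ranges: if any input is uninitialized the result is too, otherwise unwrap
// every input and call the original function.
ToyLiftedFn liftInfer(ToyInferFn inferFn) {
  return [inferFn = std::move(inferFn)](
             const std::vector<ToyValueRange> &args) -> ToyValueRange {
    if (std::any_of(args.begin(), args.end(),
                    [](const ToyValueRange &r) { return !r.has_value(); }))
      return std::nullopt;
    std::vector<ToyRange> unwrapped;
    unwrapped.reserve(args.size());
    for (const ToyValueRange &r : args)
      unwrapped.push_back(*r);
    return inferFn(unwrapped);
  };
}

// Example inference function: interval addition (overflow ignored for
// brevity).
ToyRange inferToyAdd(const std::vector<ToyRange> &ranges) {
  return {ranges[0].smin + ranges[1].smin, ranges[0].smax + ranges[1].smax};
}
```

Keeping the "any input uninitialized" check in one wrapper means per-op callbacks such as `inferAdd` only ever see fully known `ConstantIntRanges`.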
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 622d875a63ace..b2f8b5a72d0ba 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,33 +36,6 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
-namespace {
-
-OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
-  if (range.isUninitialized())
-    return std::nullopt;
-  return range.getValue();
-}
-
-OptionalIntRanges
-getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
-  return getOptionalRange(lattice->getValue());
-}
-
-} // end namespace
-
-IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
-  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
-  if (width == 0)
-    return {};
-
-  APInt umin = APInt::getMinValue(width);
-  APInt umax = APInt::getMaxValue(width);
-  APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
-  APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
-  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
-}
-
 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   Lattice::onUpdate(solver);
 
@@ -94,9 +67,12 @@ void IntegerRangeAnalysis::visitOperation(
     return setAllToEntryStates(results);
 
   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
-  auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
+  auto argRanges = llvm::map_to_vector(
+      operands, [](const IntegerValueRangeLattice *lattice) {
+        return lattice->getValue();
+      });
 
-  auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
+  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
     auto result = dyn_cast<OpResult>(v);
     if (!result)
       return;
@@ -106,9 +82,7 @@ void IntegerRangeAnalysis::visitOperation(
     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
     IntegerValueRange oldRange = lattice->getValue();
 
-    ChangeResult changed =
-        attrs ? lattice->join(IntegerValueRange{attrs})
-              : lattice->join(IntegerValueRange::getMaxRange(v));
+    ChangeResult changed = lattice->join(attrs);
 
     // Catch loop results with loop variant bounds and conservatively make
     // them [-inf, inf] so we don't circle around infinitely often (because
@@ -133,17 +107,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
-    // If the lattice on any operand is unitialized, bail out.
-    if (llvm::any_of(op->getOperands(), [&](Value value) {
-          return getLatticeElementFor(op, value)->getValue().isUninitialized();
-        }))
-      return;
-    SmallVector<OptionalIntRanges> argRanges(
-        llvm::map_range(op->getOperands(), [&](Value value) {
-          return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
-        }));
 
-    auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
+    auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
+      return getLatticeElementFor(op, value)->getValue();
+    });
+
+    auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
       auto arg = dyn_cast<BlockArgument>(v);
       if (!arg)
         return;
@@ -154,9 +123,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
       IntegerValueRange oldRange = lattice->getValue();
 
-      ChangeResult changed =
-          attrs ? lattice->join(IntegerValueRange{attrs})
-                : lattice->join(IntegerValueRange::getMaxRange(v));
+      ChangeResult changed = lattice->join(attrs);
 
       // Catch loop results with loop variant bounds and conservatively make
       // them [-inf, inf] so we don't circle around infinitely often (because
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index b59e5f9ec5a3e..9456c9e87a277 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -32,7 +32,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                           SetIntRangeFn setResultRange) {
   auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
   if (constAttr) {
@@ -45,11 +45,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
-    return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
-  });
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -58,11 +59,12 @@ void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
-    return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
-  });
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -71,11 +73,12 @@ void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
-    return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
-  });
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -84,18 +87,18 @@ void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferDivU)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferDivS)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -103,8 +106,9 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivUIOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferCeilDivU)(argRanges));
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferFromIntegerValueRange(inferCeilDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,8 +116,9 @@ void arith::CeilDivUIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivSIOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferCeilDivS)(argRanges));
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferFromIntegerValueRange(inferCeilDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -121,132 +126,132 @@ void arith::CeilDivSIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::FloorDivSIOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
   return setResultRange(getResult(),
-                        inferFromOptionals(inferFloorDivS)(argRanges));
+                        inferFromIntegerValueRange(inferFloorDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferRemU)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferRemU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferRemS)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferRemS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AndIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferAnd)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferAnd)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::OrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferOr)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferOr)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::XOrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferXor)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferXor)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferMaxS)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferMaxS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferMaxU)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferMaxU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferMinS)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferMinS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromOptionals(inferMinU)(argRanges));
+  setResultRange(getResult(), inferFromIntegerValueRange(inferMinU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  if (!argRanges[0])
+  if (argRanges[0].isUninitialized())
     return;
 
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
+  setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  if (!argRanges[0])
+  if (argRanges[0].isUninitialized())
     return;
 
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
+  setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                         SetIntRangeFn setResultRange) {
-  if (!argRanges[0])
+  if (argRanges[0].isUninitialized())
     return;
 
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+  setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
 }
 
 //===----------------------------------------------------------------------===//
@@ -254,8 +259,8 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  if (!argRanges[0])
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+  if (argRanges[0].isUninitialized())
     return;
 
   Type sourceType = getOperand().getType();
@@ -264,11 +269,11 @@ void arith::IndexCastOp::inferResultRanges(
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
+    setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
   else
-    setResultRange(getResult(), *argRanges[0]);
+    setResultRange(getResult(), argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
@@ -276,8 +281,8 @@ void arith::IndexCastOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastUIOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  if (!argRanges[0])
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+  if (argRanges[0].isUninitialized())
     return;
 
   Type sourceType = getOperand().getType();
@@ -286,30 +291,31 @@ void arith::IndexCastUIOp::inferResultRanges(
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
+    setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
   else
-    setResultRange(getResult(), *argRanges[0]);
+    setResultRange(getResult(), argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
   arith::CmpIPredicate arithPred = getPredicate();
   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
-  const OptionalIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const IntegerValueRange &lhs = argRanges[0], &rhs = argRanges[1];
 
-  if (!lhs || !rhs)
+  if (lhs.isUninitialized() || rhs.isUninitialized())
     return;
 
   APInt min = APInt::getZero(1);
   APInt max = APInt::getAllOnes(1);
 
-  std::optional<bool> truthValue = intrange::evaluatePred(pred, *lhs, *rhs);
+  std::optional<bool> truthValue =
+      intrange::evaluatePred(pred, lhs.getValue(), rhs.getValue());
   if (truthValue.has_value() && *truthValue)
     min = max;
   else if (truthValue.has_value() && !(*truthValue))
@@ -322,32 +328,37 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void arith::SelectOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::SelectOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                         SetIntRangeFn setResultRange) {
   std::optional<APInt> mbCondVal =
-      argRanges[0] ? argRanges[0]->getConstantValue() : std::nullopt;
+      !argRanges[0].isUninitialized()
+          ? argRanges[0].getValue().getConstantValue()
+          : std::nullopt;
+
+  const IntegerValueRange &trueCase = argRanges[1];
+  const IntegerValueRange &falseCase = argRanges[2];
 
   if (mbCondVal) {
     if (mbCondVal->isZero())
-      setResultRange(getResult(), argRanges[2]);
+      setResultRange(getResult(), falseCase);
     else
-      setResultRange(getResult(), argRanges[1]);
+      setResultRange(getResult(), trueCase);
     return;
   }
 
-  if (argRanges[1] && argRanges[2])
-    setResultRange(getResult(), argRanges[1]->rangeUnion(*argRanges[2]));
+  setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
 }
 
 //===----------------------------------------------------------------------===//
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
-  });
+  auto infer =
+      inferFromIntegerValueRange([&](ArrayRef<ConstantIntRanges> ranges) {
+        return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -356,9 +367,9 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // ShRUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals(inferShrU);
+  auto infer = inferFromIntegerValueRange(inferShrU);
   setResultRange(getResult(), infer(argRanges));
 }
 
@@ -366,8 +377,8 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // ShRSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals(inferShrS);
+  auto infer = inferFromIntegerValueRange(inferShrS);
   setResultRange(getResult(), infer(argRanges));
 }
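With this change, `arith::SelectOp::inferResultRanges` no longer requires both arms to be initialized before producing a result. Previously it set a range only when `argRanges[1] && argRanges[2]` held. It now calls `IntegerValueRange::join`, so an initialized arm still propagates when the other arm is uninitialized. The sketch below mirrors that control flow with hypothetical stand-ins (`ToyRange`, `ToyValueRange`, `join`, `inferToySelect`); it is not the MLIR implementation. The condition is modeled as a range where a constant zero means "false" and any other constant means "true", matching the `isZero()` check in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for ConstantIntRanges: a single signed interval.
struct ToyRange {
  int64_t smin, smax;
};

// std::nullopt plays the role of an uninitialized IntegerValueRange.
using ToyValueRange = std::optional<ToyRange>;

// Least upper bound with "uninitialized" as the identity element.
ToyValueRange join(const ToyValueRange &a, const ToyValueRange &b) {
  if (!a)
    return b;
  if (!b)
    return a;
  return ToyRange{std::min(a->smin, b->smin), std::max(a->smax, b->smax)};
}

// Mirrors the structure of the new select inference: a known constant
// condition picks one arm; otherwise the two arms are joined.
ToyValueRange inferToySelect(const ToyValueRange &cond,
                             const ToyValueRange &trueCase,
                             const ToyValueRange &falseCase) {
  if (cond && cond->smin == cond->smax)
    return cond->smin == 0 ? falseCase : trueCase;
  return join(trueCase, falseCase);
}
```

The second case in the usage below (one uninitialized arm) is the one the old code dropped entirely, leaving the select's result uninitialized.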
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 1342271029fa9..3676800ae0be5 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
-void ClusterDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ClusterDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
 }
 
-void ClusterIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ClusterIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                     SetIntRangeFn setResultRange) {
   uint64_t max = kMaxClusterDim;
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void BlockDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
       getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void BlockIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void BlockIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                   SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void GridDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                   SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
   if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void ThreadIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ThreadIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                    SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void LaneIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void LaneIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                  SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
 }
 
-void SubgroupIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
 }
 
-void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void GlobalIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                    SetIntRangeFn setResultRange) {
   uint64_t blockDimMax =
       getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,26 +146,29 @@ void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
 }
 
-void NumSubgroupsOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void SubgroupSizeOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
 }
 
-void LaunchOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                  SetIntRangeFn setResultRange) {
-  auto setRange = [&](const OptionalIntRanges &argRange, Value dimResult,
+  auto setRange = [&](const IntegerValueRange &argRange, Value dimResult,
                       Value idxResult) {
-    if (!argRange ||
-        argRange->umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
+    if (argRange.isUninitialized())
+      return;
+
+    const ConstantIntRanges &constRange = argRange.getValue();
+    if (constRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
       return;
 
     ConstantIntRanges dimRange =
-        argRange->intersection(getIndexRange(1, kMaxDim));
+        constRange.intersection(getIndexRange(1, kMaxDim));
     setResultRange(dimResult, dimRange);
     ConstantIntRanges idxRange =
         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index cc6709f1253da..4d92957a86f92 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -22,13 +22,13 @@ using namespace mlir::intrange;
 // Constants
 //===----------------------------------------------------------------------===//
 
-void ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                    SetIntRangeFn setResultRange) {
   const APInt &value = getValue();
   setResultRange(getResult(), ConstantIntRanges::constant(value));
 }
 
-void BoolConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                        SetIntRangeFn setResultRange) {
   bool value = getValue();
   APInt asInt(/*numBits=*/1, value);
@@ -54,187 +54,207 @@ inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
   };
 }
 
-void AddOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void AddOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
-                        CmpMode::Both);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
+                            CmpMode::Both);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void SubOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void SubOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
-                        CmpMode::Both);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
+                            CmpMode::Both);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void MulOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MulOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
-                        CmpMode::Both);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
+                            CmpMode::Both);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void DivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
-                        CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferDivU, ranges,
+                            CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void DivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                    SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                    SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void FloorDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                     SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
+      });
 
   return setResultRange(getResult(), infer(argRanges));
 }
 
-void RemSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void RemUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void MinSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void MinUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void ShlOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
-                        CmpMode::Both);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
+                            CmpMode::Both);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void AndOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void AndOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void OrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void OrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                              SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void XOrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
-  });
+  auto infer =
+      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+        return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -273,26 +293,30 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
   return ret;
 }
 
-void CastSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                 SetIntRangeFn setResultRange) {
-  Type sourceType = getOperand().getType();
-  Type destType = getResult().getType();
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        Type sourceType = getOperand().getType();
+        Type destType = getResult().getType();
 
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/true);
-  });
+        return inferIndexCast(ranges[0], sourceType, destType,
+                              /*isSigned=*/true);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
 
-void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                 SetIntRangeFn setResultRange) {
-  Type sourceType = getOperand().getType();
-  Type destType = getResult().getType();
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        Type sourceType = getOperand().getType();
+        Type destType = getResult().getType();
 
-  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
-    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/false);
-  });
+        return inferIndexCast(ranges[0], sourceType, destType,
+                              /*isSigned=*/false);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -301,33 +325,35 @@ void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // CmpOp
 //===----------------------------------------------------------------------===//
 
-void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
-    index::IndexCmpPredicate indexPred = getPred();
-    intrange::CmpPredicate pred =
-        static_cast<intrange::CmpPredicate>(indexPred);
-    const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
-
-    APInt min = APInt::getZero(1);
-    APInt max = APInt::getAllOnes(1);
-
-    std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
-    ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
-                      rhsTrunc = truncRange(rhs, indexMinWidth);
-    std::optional<bool> truthValue32 =
-        intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
-    if (truthValue64 == truthValue32) {
-      if (truthValue64.has_value() && *truthValue64)
-        min = max;
-      else if (truthValue64.has_value() && !(*truthValue64))
-        max = min;
-    }
-
-    return ConstantIntRanges::fromUnsigned(min, max);
-  });
+  auto infer =
+      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+        index::IndexCmpPredicate indexPred = getPred();
+        intrange::CmpPredicate pred =
+            static_cast<intrange::CmpPredicate>(indexPred);
+        const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
+
+        APInt min = APInt::getZero(1);
+        APInt max = APInt::getAllOnes(1);
+
+        std::optional<bool> truthValue64 =
+            intrange::evaluatePred(pred, lhs, rhs);
+
+        ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+                          rhsTrunc = truncRange(rhs, indexMinWidth);
+        std::optional<bool> truthValue32 =
+            intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+        if (truthValue64 == truthValue32) {
+          if (truthValue64.has_value() && *truthValue64)
+            min = max;
+          else if (truthValue64.has_value() && !(*truthValue64))
+            max = min;
+        }
+
+        return ConstantIntRanges::fromUnsigned(min, max);
+      });
 
   setResultRange(getResult(), infer(argRanges));
 }
@@ -336,7 +362,7 @@ void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // SizeOf, which is bounded between the two supported bitwidth (32 and 64).
 //===----------------------------------------------------------------------===//
 
-void SizeOfOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                  SetIntRangeFn setResultRange) {
   unsigned storageWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index b3f6c0ee3cc32..1891f6a1756f3 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -126,3 +126,20 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   return os << "unsigned : [" << range.umin() << ", " << range.umax()
             << "] signed : [" << range.smin() << ", " << range.smax() << "]";
 }
+
+IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+  if (width == 0)
+    return {};
+
+  APInt umin = APInt::getMinValue(width);
+  APInt umax = APInt::getMaxValue(width);
+  APInt smin = APInt::getSignedMinValue(width);
+  APInt smax = APInt::getSignedMaxValue(width);
+  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
+  range.print(os);
+  return os;
+}
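For intuition about what `IntegerValueRange::getMaxRange` computes, here is a plain-integer sketch of the same four bounds for widths up to 64 bits. `MaxBounds` and `maxBoundsForWidth` are hypothetical names used only for illustration. The patch itself uses `APInt::getMinValue`, `getMaxValue`, `getSignedMinValue` and `getSignedMaxValue`, which also cover wider types. As in the patch, a storage width of 0 (a non-integer type) yields no range.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical plain-integer mirror of the four bounds in ConstantIntRanges.
struct MaxBounds {
  uint64_t umin;
  uint64_t umax;
  int64_t smin;
  int64_t smax;
};

// Widest possible bounds for a `width`-bit integer.
std::optional<MaxBounds> maxBoundsForWidth(unsigned width) {
  // Width 0 means "not an integer-like type": there is no range at all.
  if (width == 0)
    return std::nullopt;
  // Wider types need APInt; this sketch does not model them.
  assert(width <= 64 && "sketch only handles widths up to 64");
  uint64_t umax = width == 64 ? UINT64_MAX : (uint64_t{1} << width) - 1;
  // The signed maximum is the unsigned maximum with the sign bit cleared.
  int64_t smax = static_cast<int64_t>(umax >> 1);
  int64_t smin = -smax - 1;
  return MaxBounds{0, umax, smin, smax};
}
```

For an 8-bit type this gives `[0, 255]` unsigned and `[-128, 127]` signed; for `i1` the signed bounds are `[-1, 0]`.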
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 78754680ae58d..43cca0d2c9845 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,20 +36,20 @@ using namespace mlir;
 using ConstArithFn =
     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
 
-std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>
-mlir::intrange::inferFromOptionals(intrange::InferRangeFn inferFn) {
+std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>
+mlir::intrange::inferFromIntegerValueRange(intrange::InferRangeFn inferFn) {
   return [inferFn = std::move(inferFn)](
-             ArrayRef<OptionalIntRanges> args) -> OptionalIntRanges {
+             ArrayRef<IntegerValueRange> args) -> IntegerValueRange {
     llvm::SmallVector<ConstantIntRanges> unpacked;
     unpacked.reserve(args.size());
 
-    for (const OptionalIntRanges &arg : args) {
-      if (!arg)
-        return std::nullopt;
-      unpacked.push_back(*arg);
+    for (const IntegerValueRange &arg : args) {
+      if (arg.isUninitialized())
+        return {};
+      unpacked.push_back(arg.getValue());
     }
 
-    return inferFn(unpacked);
+    return IntegerValueRange{inferFn(unpacked)};
   };
 }
 
@@ -93,7 +93,7 @@ static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
 //===----------------------------------------------------------------------===//
 
 ConstantIntRanges
-mlir::intrange::inferIndexOp(InferRangeFn inferFn,
+mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
                              ArrayRef<ConstantIntRanges> argRanges,
                              intrange::CmpMode mode) {
   ConstantIntRanges sixtyFour = inferFn(argRanges);
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 145b076c95a76..bb0687463c831 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,7 +648,7 @@ LogicalResult TestVerifiersOp::verifyRegions() {
 //===----------------------------------------------------------------------===//
 // TestWithBoundsOp
 
-void TestWithBoundsOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntRangeFn setResultRanges) {
   setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
                                                  getSmin(), getSmax()});
@@ -682,7 +682,7 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
 }
 
 void TestWithBoundsRegionOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
   Value arg = getRegion().getArgument(0);
   setResultRanges(
       arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
@@ -691,12 +691,12 @@ void TestWithBoundsRegionOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 // TestIncrementOp
 
-void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
                                         SetIntRangeFn setResultRanges) {
-  if (!argRanges[0])
+  if (argRanges[0].isUninitialized())
     return;
 
-  const ConstantIntRanges &range = *argRanges[0];
+  const ConstantIntRanges &range = argRanges[0].getValue();
   APInt one(range.umin().getBitWidth(), 1);
   setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
                                                  range.umax().uadd_sat(one),
@@ -708,11 +708,11 @@ void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
 // TestReflectBoundsOp
 
 void TestReflectBoundsOp::inferResultRanges(
-    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
-  if (!argRanges[0])
+    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
+  if (argRanges[0].isUninitialized())
     return;
 
-  const ConstantIntRanges &range = *argRanges[0];
+  const ConstantIntRanges &range = argRanges[0].getValue();
   MLIRContext *ctx = getContext();
   Builder b(ctx);
   Type sIntTy, uIntTy;

>From ef2c8121b35c254c63517a47760895e332c3fd25 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Sat, 25 May 2024 07:44:29 -0400
Subject: [PATCH 4/4] Add new IntegerRangeAnalysis interface method

This new method lets downstream implementers opt into the old behavior
with minimal changes, while giving them a straightforward path to
migrate to the more powerful interface methods.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  16 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  12 +-
 .../include/mlir/Dialect/Index/IR/IndexOps.td |   2 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  |  25 +-
 .../mlir/Interfaces/InferIntRangeInterface.td |  46 ++-
 .../Interfaces/Utils/InferIntRangeCommon.h    |   6 -
 .../DataFlow/IntegerRangeAnalysis.cpp         |   5 +-
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  | 172 ++++------
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    |  35 +-
 .../Index/IR/InferIntRangeInterfaceImpls.cpp  | 299 ++++++------------
 .../lib/Interfaces/InferIntRangeInterface.cpp |  31 ++
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  17 -
 .../Dialect/Arith/int-range-interface.mlir    |   1 +
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  31 +-
 mlir/test/lib/Dialect/Test/TestOps.td         |   9 +-
 15 files changed, 303 insertions(+), 404 deletions(-)
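As a rough illustration of the opt-in described in the commit message, the sketch below models the two styles of rule with hypothetical stand-in types: `Range` plays the role of `ConstantIntRanges` (unsigned bounds only) and `std::optional<Range>` plays the role of `IntegerValueRange`, with an empty optional meaning "uninitialized". These are not the real MLIR classes. `liftToLattice` mirrors `inferFromIntegerValueRange` from the previous patch: any uninitialized operand short-circuits the whole result, which is the old behavior. `selectRanges` shows why a select-like op benefits from seeing uninitialized operands directly: when the condition is unknown (for example an `i1` produced by `arith.cmpf`), it can still join the two branch ranges instead of giving up. The join rule and the non-overflowing add are simplifying assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-in for ConstantIntRanges (unsigned bounds only).
struct Range {
  uint64_t umin;
  uint64_t umax;
};

// Hypothetical stand-in for IntegerValueRange: empty means uninitialized.
using LatticeValue = std::optional<Range>;
using SimpleInferFn = std::function<Range(const std::vector<Range> &)>;

// Old-style adapter: unwrap every operand, or give up if any is unknown.
LatticeValue liftToLattice(const SimpleInferFn &inferFn,
                           const std::vector<LatticeValue> &args) {
  std::vector<Range> unpacked;
  unpacked.reserve(args.size());
  for (const LatticeValue &arg : args) {
    if (!arg)
      return std::nullopt;
    unpacked.push_back(*arg);
  }
  return inferFn(unpacked);
}

// Example simple rule: add of unsigned bounds, assuming no overflow.
Range addRanges(const std::vector<Range> &r) {
  return Range{r[0].umin + r[1].umin, r[0].umax + r[1].umax};
}

// Convex hull of two lattice values; uninitialized acts as the identity.
LatticeValue joinRanges(const LatticeValue &a, const LatticeValue &b) {
  if (!a)
    return b;
  if (!b)
    return a;
  return Range{std::min(a->umin, b->umin), std::max(a->umax, b->umax)};
}

// Rule that sees uninitialized operands directly (select-like op).
LatticeValue selectRanges(const LatticeValue &cond,
                          const LatticeValue &trueVal,
                          const LatticeValue &falseVal) {
  // A condition whose bounds collapse to a single value is a constant.
  if (cond && cond->umin == cond->umax)
    return cond->umin != 0 ? trueVal : falseVal;
  // Unknown or uninitialized condition: either branch may be taken.
  return joinRanges(trueVal, falseVal);
}
```

With an uninitialized condition, `selectRanges` still produces a range for its result, so values that flow through it (such as loop-carried values) keep propagating instead of staying uninitialized.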

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ead52332e8eec..46248dad3be9e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -49,7 +49,7 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 // Base class for integer binary operations.
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
-      [DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
     Results<(outs SignlessIntegerLike:$result)>;
 
@@ -107,7 +107,7 @@ class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
     Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
                            SignlessFixedWidthIntegerLike,
                            traits #
-                           [DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
+                           [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>;
 // Cast from an integer type to a floating point type.
 class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
     Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
@@ -139,7 +139,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
 
 class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
-      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+      [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
       DefaultValuedAttr<
@@ -159,7 +159,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
     [ConstantLike, Pure,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
      AllTypesMatch<["value", "result"]>,
-     DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "integer or floating point constant";
   let description = [{
     The `constant` operation produces an SSA value equal to some integer or
@@ -1327,7 +1327,7 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
 
 def Arith_IndexCastOp
   : Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
-                 [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1346,7 +1346,7 @@ def Arith_IndexCastOp
 
 def Arith_IndexCastUIOp
   : Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
-                 [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+                 [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "unsigned cast between index and integer types";
   let description = [{
     Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1400,7 +1400,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
 
 def Arith_CmpIOp
   : Arith_CompareOpOfAnyRank<"cmpi",
-                             [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+                             [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1555,7 +1555,7 @@ class ScalarConditionOrMatchingShape<list<string> names> :
 def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
     ScalarConditionOrMatchingShape<["condition", "result"]>,
-    DeclareOpInterfaceMethods<InferIntRangeInterface>,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 1da68ed2176d8..10719aae5c8b4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -52,7 +52,7 @@ def GPU_DimensionAttr : EnumAttr<GPU_Dialect, GPU_Dimension, "dim">;
 class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
     GPU_Op<mnemonic, !listconcat(traits, [
         Pure,
-        DeclareOpInterfaceMethods<InferIntRangeInterface>,
+        DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>])>,
     Arguments<(ins GPU_DimensionAttr:$dimension)>, Results<(outs Index)> {
   let assemblyFormat = "$dimension attr-dict";
@@ -144,7 +144,7 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
 }
 
 def GPU_LaneIdOp : GPU_Op<"lane_id", [
-      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let description = [{
     Returns the lane id within the subgroup (warp/wave).
 
@@ -158,7 +158,7 @@ def GPU_LaneIdOp : GPU_Op<"lane_id", [
 }
 
 def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [
-      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
     Returns the subgroup id, i.e., the index of the current subgroup within the
@@ -190,7 +190,7 @@ def GPU_GlobalIdOp : GPU_IndexOp<"global_id"> {
 
 
 def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
-      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
     Returns the number of subgroups within a workgroup.
@@ -206,7 +206,7 @@ def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
 }
 
 def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [
-      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+      Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
     Returns the number of threads within a subgroup.
@@ -687,7 +687,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
 
 def GPU_LaunchOp : GPU_Op<"launch", [
       AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
-      DeclareOpInterfaceMethods<InferIntRangeInterface>,
+      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
       RecursiveMemoryEffects]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index c6079cb8a98c8..a30ae9f739cbc 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td"
 /// Base class for Index dialect operations.
 class IndexOp<string mnemonic, list<Trait> traits = []>
     : Op<IndexDialect, mnemonic,
-      [DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
+      [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>] # traits>;
 
 //===----------------------------------------------------------------------===//
 // IndexBinaryOp
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 73013837f1227..0e107e88f5232 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -158,7 +158,30 @@ raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
 /// The type of the `setResultRanges` callback provided to ops implementing
 /// InferIntRangeInterface. It should be called once for each integer result
 /// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const IntegerValueRange &)>;
+using SetIntRangeFn =
+    llvm::function_ref<void(Value, const ConstantIntRanges &)>;
+
+/// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
+/// This is the `setResultRanges` callback for the IntegerValueRange based
+/// interface method.
+using SetIntLatticeFn =
+    llvm::function_ref<void(Value, const IntegerValueRange &)>;
+
+class InferIntRangeInterface;
+
+namespace intrange::detail {
+/// Default implementation of `inferResultRangesFromOptional`, which dispatches
+/// to `inferResultRanges`.
+void defaultInferResultRanges(InferIntRangeInterface interface,
+                              ArrayRef<IntegerValueRange> argRanges,
+                              SetIntLatticeFn setResultRanges);
+
+/// Default implementation of `inferResultRanges`, which dispatches to
+/// `inferResultRangesFromOptional`.
+void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
+                                          ArrayRef<ConstantIntRanges> argRanges,
+                                          SetIntRangeFn setResultRanges);
+} // end namespace intrange::detail
 } // end namespace mlir
 
 #include "mlir/Interfaces/InferIntRangeInterface.h.inc"
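To make the dispatch between the two hooks concrete, here is a minimal standalone model of the two defaults declared above. It does not use MLIR: `Range`, `Lattice`, `OpModel`, and `AddModel` are hypothetical stand-ins for `ConstantIntRanges`, `IntegerValueRange`, the interface, and an arith-style op, and the bodies are an assumption based on the signatures and the `.td` defaults, not the patch's actual implementation. The lattice-based default bails out when any argument is uninitialized; the `ConstantIntRanges`-based default wraps its arguments and forwards only initialized results.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-ins: `Range` models ConstantIntRanges (one signed
// interval), `Lattice` models IntegerValueRange (empty = uninitialized).
// Results are identified by an int index instead of an mlir::Value.
struct Range {
  int64_t lo, hi;
};
using Lattice = std::optional<Range>;

using SetRangeFn = std::function<void(int, const Range &)>;     // ~SetIntRangeFn
using SetLatticeFn = std::function<void(int, const Lattice &)>; // ~SetIntLatticeFn

// Models an op implementing the interface. Each hook defaults to the other,
// so a concrete op must override at least one of them.
struct OpModel {
  virtual ~OpModel() = default;

  // Default for the ConstantIntRanges hook: wrap every argument as an
  // initialized lattice value, call the lattice hook, and forward only
  // results that come back initialized.
  virtual void inferResultRanges(const std::vector<Range> &args,
                                 const SetRangeFn &set) {
    std::vector<Lattice> wrapped(args.begin(), args.end());
    inferResultRangesFromOptional(wrapped, [&](int v, const Lattice &l) {
      if (l)
        set(v, *l);
    });
  }

  // Default for the lattice hook: bail out if any argument is uninitialized,
  // otherwise unwrap the arguments and call the ConstantIntRanges hook.
  virtual void inferResultRangesFromOptional(const std::vector<Lattice> &args,
                                             const SetLatticeFn &set) {
    std::vector<Range> unwrapped;
    for (const Lattice &l : args) {
      if (!l)
        return;
      unwrapped.push_back(*l);
    }
    inferResultRanges(unwrapped,
                      [&](int v, const Range &r) { set(v, Lattice{r}); });
  }
};

// An add-like op that overrides only the ConstantIntRanges hook, as most of
// the arith ops in this patch do. Result 0 is its single result.
struct AddModel : OpModel {
  void inferResultRanges(const std::vector<Range> &args,
                         const SetRangeFn &set) override {
    set(0, Range{args[0].lo + args[1].lo, args[0].hi + args[1].hi});
  }
};
```

Because each default calls the other, an op that overrides neither would recurse indefinitely; each op touched by this patch overrides one of the two hooks. This is also why the ODS changes name `"inferResultRanges"` in `DeclareOpInterfaceMethods`: once an interface method has a default implementation, ODS only declares it on the op when it is listed explicitly.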
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index 795e67b8431bd..6ee436ce4d6c2 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -28,9 +28,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
       Infer the bounds on the results of this op given the bounds on its arguments.
       For each result value or block argument (that isn't a branch argument,
       since the dataflow analysis handles those case), the method should call
-      `setValueRange` with that `Value` as an argument. When `setValueRange`
-      is not called for some value, it will recieve a default value of the mimimum
-      and maximum values for its type (the unbounded range).
+      `setValueRange` with that `Value` as an argument. When implemented,
+      `setValueRange` should be called on all result values for the operation.
+      When operations take non-integer inputs, the
+      `inferResultRangesFromOptional` method should be implemented instead.
 
       When called on an op that also implements the RegionBranchOpInterface
       or BranchOpInterface, this method should not attempt to infer the values
@@ -39,14 +40,39 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
 
       This function will only be called when at least one result of the op is a
       scalar integer value or the op has a region.
+    }],
+    /*retTy=*/"void",
+    /*methodName=*/"inferResultRanges",
+    /*args=*/(ins "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+                  "::mlir::SetIntRangeFn":$setResultRanges),
+    /*methodBody=*/"",
+    /*defaultImplementation=*/[{
+      ::mlir::intrange::detail::defaultInferResultRangesFromOptional($_op,
+                                                                     argRanges,
+                                                                     setResultRanges);
+    }]>,
+
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the lattice representation
+      of the bounds for its arguments. For each result value or block argument
+      (that isn't a branch argument, since the dataflow analysis handles
+      those cases), the method should call `setValueRange` with that `Value`
+      as an argument. When implemented, `setValueRange` should be called on
+      all result values for the operation.
 
-      `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
-       order. Non-integer arguments will have the an unbounded range of width-0
-       APInts in their `argRanges` element.
+      This method allows for more precise implementations when operations
+      want to reason about inputs which may be undefined during the analysis.
     }],
-    "void", "inferResultRanges", (ins
-      "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
-      "::mlir::SetIntRangeFn":$setResultRanges)
-  >];
+    /*retTy=*/"void",
+    /*methodName=*/"inferResultRangesFromOptional",
+    /*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
+                  "::mlir::SetIntLatticeFn":$setResultRanges),
+    /*methodBody=*/"",
+    /*defaultImplementation=*/[{
+      ::mlir::intrange::detail::defaultInferResultRanges($_op,
+                                                         argRanges,
+                                                         setResultRanges);
+    }]>
+  ];
 }
 #endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 8746a1cfba85c..3988a8826498a 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -48,12 +48,6 @@ enum class OverflowFlags : uint32_t {
 using InferRangeWithOvfFlagsFn =
     function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
 
-/// Perform a pointwise extension of a function operating on `ConstantIntRanges`
-/// to a function operating on `IntegerValueRange` such that undefined input
-/// ranges propagate.
-InferIntegerValueRangeFn
-inferFromIntegerValueRange(intrange::InferRangeFn inferFn);
-
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b2f8b5a72d0ba..9721620807a0f 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -61,7 +61,6 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
 void IntegerRangeAnalysis::visitOperation(
     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
     ArrayRef<IntegerValueRangeLattice *> results) {
-  // If the lattice on any operand is unitialized, bail out.
   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
   if (!inferrable)
     return setAllToEntryStates(results);
@@ -99,7 +98,7 @@ void IntegerRangeAnalysis::visitOperation(
     propagateIfChanged(lattice, changed);
   };
 
-  inferrable.inferResultRanges(argRanges, joinCallback);
+  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
 }
 
 void IntegerRangeAnalysis::visitNonControlFlowArguments(
@@ -140,7 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       propagateIfChanged(lattice, changed);
     };
 
-    inferrable.inferResultRanges(argRanges, joinCallback);
+    inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
     return;
   }
 
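Both call sites now go through `inferResultRangesFromOptional` and merge each reported value via `joinCallback`, which ends in `propagateIfChanged(lattice, changed)`. The sketch below models that merge step, assuming (as the commit message implies) that the uninitialized state is the bottom of the lattice: joining it changes nothing, while joining two ranges takes their hull. `Range`, `Lattice`, `join`, and `joinInto` are hypothetical names, and a single signed interval stands in for the separate signed and unsigned bounds that `ConstantIntRanges` tracks.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-ins: one signed interval per value, and an empty
// optional for the uninitialized lattice state.
struct Range {
  int64_t lo, hi;
};
using Lattice = std::optional<Range>;

// Uninitialized is the identity of the join; two ranges join to their hull.
Lattice join(const Lattice &a, const Lattice &b) {
  if (!a)
    return b;
  if (!b)
    return a;
  return Range{std::min(a->lo, b->lo), std::max(a->hi, b->hi)};
}

// Models the callback's state update: merge `incoming` into `state` and
// report whether the state changed, which is when the solver would revisit
// users of the value.
bool joinInto(Lattice &state, const Lattice &incoming) {
  Lattice merged = join(state, incoming);
  bool changed =
      merged.has_value() != state.has_value() ||
      (merged && (merged->lo != state->lo || merged->hi != state->hi));
  state = merged;
  return changed;
}
```

Roughly speaking, a loop-carried value that starts at `[0, 0]` and receives `[1, 1]` over the back edge widens to `[0, 1]`. If the value coming from the body never becomes initialized, nothing is ever joined in and only the entry range survives, which matches the kind of stale loop range the commit message describes.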
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 9456c9e87a277..462044417b5fb 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
+#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -32,7 +33,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-void arith::ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                           SetIntRangeFn setResultRange) {
   auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
   if (constAttr) {
@@ -45,60 +46,48 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AddIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-void arith::SubIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MulIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferDivU)(argRanges));
+  setResultRange(getResult(), inferDivU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferDivS)(argRanges));
+  setResultRange(getResult(), inferDivS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -106,9 +95,8 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivUIOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferFromIntegerValueRange(inferCeilDivU)(argRanges));
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferCeilDivU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -116,9 +104,8 @@ void arith::CeilDivUIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivSIOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferFromIntegerValueRange(inferCeilDivS)(argRanges));
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferCeilDivS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -126,132 +113,122 @@ void arith::CeilDivSIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::FloorDivSIOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
-  return setResultRange(getResult(),
-                        inferFromIntegerValueRange(inferFloorDivS)(argRanges));
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  return setResultRange(getResult(), inferFloorDivS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferRemU)(argRanges));
+  setResultRange(getResult(), inferRemU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferRemS)(argRanges));
+  setResultRange(getResult(), inferRemS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AndIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferAnd)(argRanges));
+  setResultRange(getResult(), inferAnd(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::OrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferOr)(argRanges));
+  setResultRange(getResult(), inferOr(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::XOrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferXor)(argRanges));
+  setResultRange(getResult(), inferXor(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferMaxS)(argRanges));
+  setResultRange(getResult(), inferMaxS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferMaxU)(argRanges));
+  setResultRange(getResult(), inferMaxU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferMinS)(argRanges));
+  setResultRange(getResult(), inferMinS(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferFromIntegerValueRange(inferMinU)(argRanges));
+  setResultRange(getResult(), inferMinU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  if (argRanges[0].isUninitialized())
-    return;
-
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
+  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  if (argRanges[0].isUninitialized())
-    return;
-
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
+  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-void arith::TruncIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
-  if (argRanges[0].isUninitialized())
-    return;
-
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
+  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
@@ -259,19 +236,16 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
-  if (argRanges[0].isUninitialized())
-    return;
-
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
+    setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
+    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
   else
     setResultRange(getResult(), argRanges[0]);
 }
@@ -281,19 +255,16 @@ void arith::IndexCastOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastUIOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
-  if (argRanges[0].isUninitialized())
-    return;
-
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
+    setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
+    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
   else
     setResultRange(getResult(), argRanges[0]);
 }
@@ -302,20 +273,16 @@ void arith::IndexCastUIOp::inferResultRanges(
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-void arith::CmpIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
   arith::CmpIPredicate arithPred = getPredicate();
   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
-  const IntegerValueRange &lhs = argRanges[0], &rhs = argRanges[1];
-
-  if (lhs.isUninitialized() || rhs.isUninitialized())
-    return;
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
 
   APInt min = APInt::getZero(1);
   APInt max = APInt::getAllOnes(1);
 
-  std::optional<bool> truthValue =
-      intrange::evaluatePred(pred, lhs.getValue(), rhs.getValue());
+  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
   if (truthValue.has_value() && *truthValue)
     min = max;
   else if (truthValue.has_value() && !(*truthValue))
@@ -328,12 +295,12 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void arith::SelectOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
-                                        SetIntRangeFn setResultRange) {
+void arith::SelectOp::inferResultRangesFromOptional(
+    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
   std::optional<APInt> mbCondVal =
-      !argRanges[0].isUninitialized()
-          ? argRanges[0].getValue().getConstantValue()
-          : std::nullopt;
+      argRanges[0].isUninitialized()
+          ? std::nullopt
+          : argRanges[0].getValue().getConstantValue();
 
   const IntegerValueRange &trueCase = argRanges[1];
   const IntegerValueRange &falseCase = argRanges[2];
@@ -345,7 +312,6 @@ void arith::SelectOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
       setResultRange(getResult(), trueCase);
     return;
   }
-
   setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
 }
 
@@ -353,32 +319,26 @@ void arith::SelectOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShLIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([&](ArrayRef<ConstantIntRanges> ranges) {
-        return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
+                                                      getOverflowFlags())));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  auto infer = inferFromIntegerValueRange(inferShrU);
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferShrU(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  auto infer = inferFromIntegerValueRange(inferShrS);
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferShrS(argRanges));
 }
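`arith::SelectOp` is the only op in this file's diff that overrides `inferResultRangesFromOptional` instead of `inferResultRanges`, so it can still produce a result when its condition's lattice is uninitialized (for example, a condition derived from floating-point values, as with the `arith.cmpf` in the commit message). A minimal model of that logic, using hypothetical `Range`/`Lattice` stand-ins with a single signed interval per value: a condition known to be a single constant picks one side, and anything else, including an uninitialized condition, yields the join of both sides.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-ins: one signed interval per value, and an empty
// optional for the uninitialized lattice state.
struct Range {
  int64_t lo, hi;
};
using Lattice = std::optional<Range>;

// Uninitialized is the identity of the join; two ranges join to their hull.
Lattice join(const Lattice &a, const Lattice &b) {
  if (!a)
    return b;
  if (!b)
    return a;
  return Range{std::min(a->lo, b->lo), std::max(a->hi, b->hi)};
}

// Models the select inference: a constant condition selects one side, while
// an uninitialized or non-constant condition joins both sides.
Lattice inferSelect(const Lattice &cond, const Lattice &trueCase,
                    const Lattice &falseCase) {
  if (cond && cond->lo == cond->hi)
    return cond->lo != 0 ? trueCase : falseCase;
  return join(trueCase, falseCase);
}
```

With the default lattice hook, an uninitialized condition would make the op bail out and leave its result uninitialized; overriding the lattice hook lets the select still forward a sound range to its users.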
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 3676800ae0be5..69017efb9a0e6 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
-void ClusterDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
 }
 
-void ClusterIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                     SetIntRangeFn setResultRange) {
   uint64_t max = kMaxClusterDim;
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void BlockDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
       getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void BlockIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void GridDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
   if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void ThreadIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void LaneIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
 }
 
-void SubgroupIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
 }
 
-void GlobalIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t blockDimMax =
       getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,29 +146,24 @@ void GlobalIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
                  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
 }
 
-void NumSubgroupsOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void SubgroupSizeOp::inferResultRanges(ArrayRef<IntegerValueRange>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
 }
 
-void LaunchOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
-  auto setRange = [&](const IntegerValueRange &argRange, Value dimResult,
+  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
                       Value idxResult) {
-    if (argRange.isUninitialized())
+    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
       return;
-
-    const ConstantIntRanges &constRange = argRange.getValue();
-    if (constRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
-      return;
-
     ConstantIntRanges dimRange =
-        constRange.intersection(getIndexRange(1, kMaxDim));
+        argRange.intersection(getIndexRange(1, kMaxDim));
     setResultRange(dimResult, dimRange);
     ConstantIntRanges idxRange =
         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
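For readers following the `LaunchOp` change: the lambda now receives a plain `ConstantIntRanges`. It intersects that range with `[1, kMaxDim]` to get the dimension result, then derives the id range `[0, umax - 1]` from it. The sketch below models only the unsigned bounds. It uses a hypothetical `URange` struct and an assumed `kMaxDimSketch` of 2^32 - 1, so it illustrates the arithmetic rather than the MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for the unsigned half of ConstantIntRanges.
struct URange {
  uint64_t umin, umax;
};

// Assumed upper bound on a launch dimension (plays the role of kMaxDim).
constexpr uint64_t kMaxDimSketch = std::numeric_limits<uint32_t>::max();

// Intersection of two unsigned ranges; the caller must ensure they overlap.
URange intersect(URange a, URange b) {
  URange r{std::max(a.umin, b.umin), std::min(a.umax, b.umax)};
  assert(r.umin <= r.umax && "ranges do not overlap");
  return r;
}

// Given the range of a launch-size operand, compute the range of the
// corresponding dimension value and of the id that indexes into it.
void launchRanges(URange arg, URange &dim, URange &idx) {
  dim = intersect(arg, URange{1, kMaxDimSketch});
  idx = URange{0, dim.umax - 1};
}
```

An operand range of exactly `[0, 0]` would not overlap `[1, kMaxDim]`. The sketch asserts in that case instead of modelling an empty range.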
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index 4d92957a86f92..64adb6b850524 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
+#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -22,13 +23,13 @@ using namespace mlir::intrange;
 // Constants
 //===----------------------------------------------------------------------===//
 
-void ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
   const APInt &value = getValue();
   setResultRange(getResult(), ConstantIntRanges::constant(value));
 }
 
-void BoolConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
   bool value = getValue();
   APInt asInt(/*numBits=*/1, value);
@@ -48,215 +49,129 @@ void BoolConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
 // the inference function without any `OverflowFlags`.
 static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
-  return [inferWithOvfFn](
-             ArrayRef<ConstantIntRanges> argRanges) -> ConstantIntRanges {
+  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
     return inferWithOvfFn(argRanges, OverflowFlags::None);
   };
 }
 
-void AddOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
-                            CmpMode::Both);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
+                                           argRanges, CmpMode::Both));
 }
 
-void SubOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
-                            CmpMode::Both);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
+                                           argRanges, CmpMode::Both));
 }
 
-void MulOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
-                            CmpMode::Both);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
+                                           argRanges, CmpMode::Both));
 }
 
-void DivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
-                            CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
 }
 
-void DivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
 }
 
-void CeilDivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
 }
 
-void CeilDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
 }
 
-void FloorDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                     SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
-      });
-
-  return setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
 }
 
-void RemSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
 }
 
-void RemUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
 }
 
-void MaxSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
 }
 
-void MaxUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
 }
 
-void MinSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
 }
 
-void MinUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
 }
 
-void ShlOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
-                            CmpMode::Both);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
+                                           argRanges, CmpMode::Both));
 }
 
-void ShrSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
 }
 
-void ShrUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
 }
 
-void AndOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
 }
 
-void OrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
 }
 
-void XOrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
-        return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  setResultRange(getResult(),
+                 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
 }
 
 //===----------------------------------------------------------------------===//
@@ -293,76 +208,56 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
   return ret;
 }
 
-void CastSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        Type sourceType = getOperand().getType();
-        Type destType = getResult().getType();
-
-        return inferIndexCast(ranges[0], sourceType, destType,
-                              /*isSigned=*/true);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
+                                             /*isSigned=*/true));
 }
 
-void CastUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        Type sourceType = getOperand().getType();
-        Type destType = getResult().getType();
-
-        return inferIndexCast(ranges[0], sourceType, destType,
-                              /*isSigned=*/false);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
+                                             /*isSigned=*/false));
 }
 
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//
 
-void CmpOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  auto infer =
-      inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
-        index::IndexCmpPredicate indexPred = getPred();
-        intrange::CmpPredicate pred =
-            static_cast<intrange::CmpPredicate>(indexPred);
-        const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
-
-        APInt min = APInt::getZero(1);
-        APInt max = APInt::getAllOnes(1);
-
-        std::optional<bool> truthValue64 =
-            intrange::evaluatePred(pred, lhs, rhs);
-
-        ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
-                          rhsTrunc = truncRange(rhs, indexMinWidth);
-        std::optional<bool> truthValue32 =
-            intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
-        if (truthValue64 == truthValue32) {
-          if (truthValue64.has_value() && *truthValue64)
-            min = max;
-          else if (truthValue64.has_value() && !(*truthValue64))
-            max = min;
-        }
-
-        return ConstantIntRanges::fromUnsigned(min, max);
-      });
-
-  setResultRange(getResult(), infer(argRanges));
+  index::IndexCmpPredicate indexPred = getPred();
+  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  APInt min = APInt::getZero(1);
+  APInt max = APInt::getAllOnes(1);
+
+  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
+
+  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+                    rhsTrunc = truncRange(rhs, indexMinWidth);
+  std::optional<bool> truthValue32 =
+      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+  if (truthValue64 == truthValue32) {
+    if (truthValue64.has_value() && *truthValue64)
+      min = max;
+    else if (truthValue64.has_value() && !(*truthValue64))
+      max = min;
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
 }
 
 //===----------------------------------------------------------------------===//
 // SizeOf, which is bounded between the two supported bitwidth (32 and 64).
 //===----------------------------------------------------------------------===//
 
-void SizeOfOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
   unsigned storageWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
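The `CmpOp` hunk above folds a comparison to a constant only when two evaluations of the predicate give the same answer: one on the 64-bit ranges and one on the ranges truncated to `indexMinWidth` (32 bits). This matters because `index` may be lowered to either width. The following is a minimal sketch of that rule. It is restricted to the unsigned `ult` predicate and uses a hypothetical `URange` struct in place of `ConstantIntRanges`. `evalULT` and `trunc32` are illustrative helpers, assumed to behave like `intrange::evaluatePred` and `truncRange` for this case.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical unsigned-only stand-in for ConstantIntRanges.
struct URange {
  uint64_t umin, umax;
};

// Decide `lhs < rhs` (unsigned) for every pair of values in the ranges, or
// return std::nullopt if the outcome depends on the concrete values.
std::optional<bool> evalULT(URange lhs, URange rhs) {
  if (lhs.umax < rhs.umin)
    return true;
  if (lhs.umin >= rhs.umax)
    return false;
  return std::nullopt;
}

// Sound truncation of an unsigned 64-bit range to 32 bits: if the range
// wraps around modulo 2^32, fall back to the full 32-bit range.
URange trunc32(URange r) {
  const uint64_t mask = 0xFFFFFFFFull;
  if (r.umax - r.umin > mask)
    return {0, mask};
  uint64_t lo = r.umin & mask, hi = r.umax & mask;
  if (lo > hi)
    return {0, mask};
  return {lo, hi};
}

// Range of the i1 result: {1,1} (always true), {0,0} (always false), or
// {0,1} (unknown). Only fold when both widths agree on the answer.
URange inferCmpULT(URange lhs, URange rhs) {
  std::optional<bool> t64 = evalULT(lhs, rhs);
  std::optional<bool> t32 = evalULT(trunc32(lhs), trunc32(rhs));
  uint64_t min = 0, max = 1;
  if (t64 == t32 && t64.has_value()) {
    if (*t64)
      min = max;
    else
      max = min;
  }
  return {min, max};
}
```

Comparing `[0, 3]` against the constant `2^32` shows why the check matters. `3 < 2^32` holds at 64 bits, but after truncation the right-hand side becomes `0`. Because the two widths disagree, the result stays `{0, 1}`.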
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 1891f6a1756f3..d879b93586899 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -143,3 +143,34 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
   range.print(os);
   return os;
 }
+
+void mlir::intrange::detail::defaultInferResultRanges(
+    InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
+    SetIntLatticeFn setResultRanges) {
+  llvm::SmallVector<ConstantIntRanges> unpacked;
+  unpacked.reserve(argRanges.size());
+
+  for (const IntegerValueRange &range : argRanges) {
+    if (range.isUninitialized())
+      return;
+    unpacked.push_back(range.getValue());
+  }
+
+  interface.inferResultRanges(
+      unpacked,
+      [&setResultRanges](Value value, const ConstantIntRanges &range) {
+        setResultRanges(value, IntegerValueRange{range});
+      });
+}
+
+void mlir::intrange::detail::defaultInferResultRangesFromOptional(
+    InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
+    SetIntRangeFn setResultRanges) {
+  auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
+  interface.inferResultRangesFromOptional(
+      ranges,
+      [&setResultRanges](Value value, const IntegerValueRange &range) {
+        if (!range.isUninitialized())
+          setResultRanges(value, range.getValue());
+      });
+}
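The two `detail::` helpers above bridge two signatures: the lattice-level one, which takes `IntegerValueRange` values that may be uninitialized, and the plain `ConstantIntRanges` one. The first helper keeps the old short-circuit: if any operand lattice is uninitialized, it returns without setting any result. The sketch below shows that behaviour in self-contained form. It uses `std::optional<Range>` as a stand-in for `IntegerValueRange` and a hypothetical `Range` struct for `ConstantIntRanges`. As a simplification, it returns the result rather than calling a setter callback.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-ins: `Range` plays the role of ConstantIntRanges and
// `Lattice` the role of IntegerValueRange (std::nullopt == uninitialized).
struct Range {
  int64_t min, max;
};
using Lattice = std::optional<Range>;
using InferFn = std::function<Range(const std::vector<Range> &)>;

// Shaped like defaultInferResultRanges: bail out on the first
// uninitialized operand, otherwise unwrap, infer, and re-wrap the result.
Lattice inferFromLattices(const std::vector<Lattice> &args, const InferFn &fn) {
  std::vector<Range> unpacked;
  unpacked.reserve(args.size());
  for (const Lattice &arg : args) {
    if (!arg.has_value())
      return std::nullopt; // result stays uninitialized
    unpacked.push_back(*arg);
  }
  return fn(unpacked);
}
```

In the second test case, a single uninitialized operand leaves the result uninitialized. This is the same gap the commit message describes for values fed by `arith.cmpf`. An op that needs to reason about uninitialized operands would presumably override `inferResultRangesFromOptional` instead, which is the method the patch's second helper dispatches to.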
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 43cca0d2c9845..5b8d35e7bd519 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,23 +36,6 @@ using namespace mlir;
 using ConstArithFn =
     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
 
-std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>
-mlir::intrange::inferFromIntegerValueRange(intrange::InferRangeFn inferFn) {
-  return [inferFn = std::move(inferFn)](
-             ArrayRef<IntegerValueRange> args) -> IntegerValueRange {
-    llvm::SmallVector<ConstantIntRanges> unpacked;
-    unpacked.reserve(args.size());
-
-    for (const IntegerValueRange &arg : args) {
-      if (arg.isUninitialized())
-        return {};
-      unpacked.push_back(arg.getValue());
-    }
-
-    return IntegerValueRange{inferFn(unpacked)};
-  };
-}
-
 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
 /// If either computation overflows, make the result unbounded.
 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index fdeb8a2e6c935..60f0ab41afa48 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -917,3 +917,4 @@ func.func @test_cmpf_propagates(%a: f32, %b: f32) -> index {
   %2 = test.reflect_bounds %1 : index
   func.return %2 : index
 }
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index bb0687463c831..b058a8e1abbcb 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,10 +648,9 @@ LogicalResult TestVerifiersOp::verifyRegions() {
 //===----------------------------------------------------------------------===//
 // TestWithBoundsOp
 
-void TestWithBoundsOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                          SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
-                                                 getSmin(), getSmax()});
+  setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
@@ -682,37 +681,29 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
 }
 
 void TestWithBoundsRegionOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
   Value arg = getRegion().getArgument(0);
-  setResultRanges(
-      arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
+  setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
 // TestIncrementOp
 
-void TestIncrementOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
-  if (argRanges[0].isUninitialized())
-    return;
-
-  const ConstantIntRanges &range = argRanges[0].getValue();
+  const ConstantIntRanges &range = argRanges[0];
   APInt one(range.umin().getBitWidth(), 1);
-  setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
-                                                 range.umax().uadd_sat(one),
-                                                 range.smin().sadd_sat(one),
-                                                 range.smax().sadd_sat(one)});
+  setResultRanges(getResult(),
+                  {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
+                   range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
 }
 
 //===----------------------------------------------------------------------===//
 // TestReflectBoundsOp
 
 void TestReflectBoundsOp::inferResultRanges(
-    ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
-  if (argRanges[0].isUninitialized())
-    return;
-
-  const ConstantIntRanges &range = argRanges[0].getValue();
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &range = argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
   Type sIntTy, uIntTy;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 18324482153a5..9d7e0a7928ab8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2750,7 +2750,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
 def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
 
 def TestWithBoundsOp : TEST_Op<"with_bounds",
-                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
                            NoMemoryEffect]> {
   let arguments = (ins APIntAttr:$umin,
                        APIntAttr:$umax,
@@ -2762,7 +2762,7 @@ def TestWithBoundsOp : TEST_Op<"with_bounds",
 }
 
 def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
-                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
                            SingleBlock, NoTerminator]> {
   let arguments = (ins APIntAttr:$umin,
                        APIntAttr:$umax,
@@ -2774,7 +2774,7 @@ def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
 }
 
 def TestIncrementOp : TEST_Op<"increment",
-                         [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
                          NoMemoryEffect, AllTypesMatch<["value", "result"]>]> {
   let arguments = (ins InferIntRangeType:$value);
   let results = (outs InferIntRangeType:$result);
@@ -2783,7 +2783,8 @@ def TestIncrementOp : TEST_Op<"increment",
 }
 
 def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
-                         [DeclareOpInterfaceMethods<InferIntRangeInterface>, AllTypesMatch<["value", "result"]>]> {
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+                          AllTypesMatch<["value", "result"]>]> {
   let arguments = (ins InferIntRangeType:$value,
                        OptionalAttr<APIntAttr>:$umin,
                        OptionalAttr<APIntAttr>:$umax,


