[Mlir-commits] [mlir] [mlir][dataflow] Fix for integer range analysis propagation bug (PR #93199)

Spenser Bauman llvmlistbot at llvm.org
Fri May 24 04:11:06 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/93199

>From 59da46ce5e4233043769949af641252df51ba894 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Thu, 23 May 2024 08:18:55 -0400
Subject: [PATCH 1/2] [mlir][dataflow] Fix for integer range analysis
 propagation bug

Integer range analysis will not update the range of an operation when
any of the inferred input lattices are uninitialized. In the current
behavior, all lattice values for non-integer types are uninitialized.

For operations like arith.cmpf

```mlir
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
```

this results in the range of the output also being uninitialized, and
likewise for any consumer of the arith.cmpf result. When control-flow
ops are involved, the lack of propagation results in incorrect ranges,
as the back edges for loop carried values are not properly joined with
the definitions from the body region.

For example, consider an scf.while loop whose body region produces a
value that is in a dataflow relationship with some floating-point values
through an arith.cmpf operation:

```mlir
func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) {
  %c4 = arith.constant 4 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index

  %3 = arith.cmpf ugt, %arg0, %arg1 : f32

  %1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) {
    %2 = arith.cmpi ult, %arg2, %c4 : index
    scf.condition(%2) %arg2, %arg3 : index, index
  } do {
  ^bb0(%arg2: index, %arg3: index):
    %4 = arith.select %3, %arg3, %arg3 : index
    %5 = arith.addi %arg2, %c1 : index
    scf.yield %5, %4 : index, index
  }

  return %1#0, %1#1 : index, index
}
```

The existing behavior results in the control condition %2 being
optimized to true, turning the while loop into an infinite loop. The
update to %arg2 through the body region is never factored into the range
calculation, as the ranges for the body ops all test as uninitialized.

This change causes all values initialized with setToEntryState to
be set to some initialized range, even if the values are not integers.
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.cpp |  2 --
 .../Dialect/Arith/int-range-interface.mlir     | 18 ++++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a82c30717e275..b69b2e0416209 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -38,8 +38,6 @@ using namespace mlir::dataflow;
 
 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
-  if (width == 0)
-    return {};
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 5b538197a0c11..fdeb8a2e6c935 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -899,3 +899,21 @@ func.func @test_shl_i8_nowrap() -> i8 {
   %2 = test.reflect_bounds %1 : i8
   return %2: i8
 }
+
+/// A test case to ensure that the ranges for unsupported ops are initialized
+/// properly to maxRange, rather than left uninitialized.
+/// In this test case, the previous behavior would leave the ranges for %a and
+/// %b uninitialized, resulting in arith.cmpf's range not being updated, even
+/// though it has an integer valued result.
+
+// CHECK-LABEL: func @test_cmpf_propagates
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @test_cmpf_propagates(%a: f32, %b: f32) -> index {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+
+  %0 = arith.cmpf ueq, %a, %b : f32
+  %1 = arith.select %0, %c1, %c2 : index
+  %2 = test.reflect_bounds %1 : index
+  func.return %2 : index
+}

>From 38188164f25054cbde9b8c074a65f8bdd1c30b43 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Thu, 23 May 2024 15:21:25 -0400
Subject: [PATCH 2/2] Rework integer range analysis interfaces

Modify the integer range analysis interfaces to handle uninitialized
values by allowing the inferred input ranges to be optional.
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.h  |   2 +-
 .../mlir/Interfaces/InferIntRangeInterface.h  |   3 +-
 .../mlir/Interfaces/InferIntRangeInterface.td |   2 +-
 .../Interfaces/Utils/InferIntRangeCommon.h    |   7 +-
 .../DataFlow/IntegerRangeAnalysis.cpp         |  45 +--
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  | 167 ++++++-----
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    |  32 ++-
 .../Index/IR/InferIntRangeInterfaceImpls.cpp  | 265 ++++++++++++------
 .../Interfaces/Utils/InferIntRangeCommon.cpp  |  17 ++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  31 +-
 10 files changed, 366 insertions(+), 205 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 8bd7cf880c6af..fb07013041c0e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -33,7 +33,7 @@ class IntegerValueRange {
   static IntegerValueRange getMaxRange(Value value);
 
   /// Create an integer value range lattice value.
-  IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+  IntegerValueRange(OptionalIntRanges value = std::nullopt)
       : value(std::move(value)) {}
 
   /// Whether the range is uninitialized. This happens when the state hasn't
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 05064a72ef02e..3d499b420eadd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,10 +105,11 @@ class ConstantIntRanges {
 
 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
 
+using OptionalIntRanges = std::optional<ConstantIntRanges>;
 /// The type of the `setResultRanges` callback provided to ops implementing
 /// InferIntRangeInterface. It should be called once for each integer result
 /// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
 } // end namespace mlir
 
 #include "mlir/Interfaces/InferIntRangeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index dbdc526c6f10b..f8e2c98d87cdb 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
        APInts in their `argRanges` element.
     }],
     "void", "inferResultRanges", (ins
-      "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+      "::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
       "::mlir::SetIntRangeFn":$setResultRanges)
   >];
 }
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 851bb534bc7ee..9e3b04535dcab 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -25,7 +25,10 @@ namespace intrange {
 /// abstracted away here to permit writing the function that handles both
 /// 64- and 32-bit index types.
 using InferRangeFn =
-    function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+    std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+using OptionalRangeFn =
+    std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
 
 static constexpr unsigned indexMinWidth = 32;
 static constexpr unsigned indexMaxWidth = 64;
@@ -44,6 +47,8 @@ enum class OverflowFlags : uint32_t {
 using InferRangeWithOvfFlagsFn =
     function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
 
+OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
+
 /// Compute `inferFn` on `ranges`, whose size should be the index storage
 /// bitwidth. Then, compute the function on `argRanges` again after truncating
 /// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b69b2e0416209..622d875a63ace 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,8 +36,26 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
+namespace {
+
+OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
+  if (range.isUninitialized())
+    return std::nullopt;
+  return range.getValue();
+}
+
+OptionalIntRanges
+getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
+  return getOptionalRange(lattice->getValue());
+}
+
+} // end namespace
+
 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+  if (width == 0)
+    return {};
+
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
   APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
@@ -71,23 +89,14 @@ void IntegerRangeAnalysis::visitOperation(
     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
     ArrayRef<IntegerValueRangeLattice *> results) {
   // If the lattice on any operand is unitialized, bail out.
-  if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
-        return lattice->getValue().isUninitialized();
-      })) {
-    return;
-  }
-
   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
   if (!inferrable)
     return setAllToEntryStates(results);
 
   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
-  SmallVector<ConstantIntRanges> argRanges(
-      llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
-        return val->getValue().getValue();
-      }));
+  auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
 
-  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+  auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
     auto result = dyn_cast<OpResult>(v);
     if (!result)
       return;
@@ -97,7 +106,9 @@ void IntegerRangeAnalysis::visitOperation(
     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
     IntegerValueRange oldRange = lattice->getValue();
 
-    ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+    ChangeResult changed =
+        attrs ? lattice->join(IntegerValueRange{attrs})
+              : lattice->join(IntegerValueRange::getMaxRange(v));
 
     // Catch loop results with loop variant bounds and conservatively make
     // them [-inf, inf] so we don't circle around infinitely often (because
@@ -127,12 +138,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
           return getLatticeElementFor(op, value)->getValue().isUninitialized();
         }))
       return;
-    SmallVector<ConstantIntRanges> argRanges(
+    SmallVector<OptionalIntRanges> argRanges(
         llvm::map_range(op->getOperands(), [&](Value value) {
-          return getLatticeElementFor(op, value)->getValue().getValue();
+          return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
         }));
 
-    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+    auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
       auto arg = dyn_cast<BlockArgument>(v);
       if (!arg)
         return;
@@ -143,7 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
       IntegerValueRange oldRange = lattice->getValue();
 
-      ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+      ChangeResult changed =
+          attrs ? lattice->join(IntegerValueRange{attrs})
+                : lattice->join(IntegerValueRange::getMaxRange(v));
 
       // Catch loop results with loop variant bounds and conservatively make
       // them [-inf, inf] so we don't circle around infinitely often (because
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index fbe2ecab8adca..b59e5f9ec5a3e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
-#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -33,7 +32,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                           SetIntRangeFn setResultRange) {
   auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
   if (constAttr) {
@@ -46,48 +45,57 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // AddIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
 
-void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MulIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // DivSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferDivS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -95,8 +103,8 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivUIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivU(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferFromOptionals(inferCeilDivU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -104,8 +112,8 @@ void arith::CeilDivUIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::CeilDivSIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferCeilDivS(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), inferFromOptionals(inferCeilDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,122 +121,132 @@ void arith::CeilDivSIOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::FloorDivSIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
-  return setResultRange(getResult(), inferFloorDivS(argRanges));
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  return setResultRange(getResult(),
+                        inferFromOptionals(inferFloorDivS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferRemU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // RemSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferRemS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferRemS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
 
-void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAnd(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferAnd)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // OrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferOr(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferOr)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
 
-void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferXor(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferXor)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMaxS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMaxU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMaxU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinS(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMinS)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMinU(argRanges));
+  setResultRange(getResult(), inferFromOptionals(inferMinU)(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+  setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+  setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   unsigned destWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
-  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+  setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
 }
 
 //===----------------------------------------------------------------------===//
@@ -236,18 +254,21 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
   else
-    setResultRange(getResult(), argRanges[0]);
+    setResultRange(getResult(), *argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,34 +276,40 @@ void arith::IndexCastOp::inferResultRanges(
 //===----------------------------------------------------------------------===//
 
 void arith::IndexCastUIOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  if (!argRanges[0])
+    return;
+
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
 
   if (srcWidth < destWidth)
-    setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+    setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
   else if (srcWidth > destWidth)
-    setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+    setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
   else
-    setResultRange(getResult(), argRanges[0]);
+    setResultRange(getResult(), *argRanges[0]);
 }
 
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
   arith::CmpIPredicate arithPred = getPredicate();
   intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const OptionalIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  if (!lhs || !rhs)
+    return;
 
   APInt min = APInt::getZero(1);
   APInt max = APInt::getAllOnes(1);
 
-  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
+  std::optional<bool> truthValue = intrange::evaluatePred(pred, *lhs, *rhs);
   if (truthValue.has_value() && *truthValue)
     min = max;
   else if (truthValue.has_value() && !(*truthValue))
@@ -295,9 +322,10 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SelectOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRange) {
-  std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+  std::optional<APInt> mbCondVal =
+      argRanges[0] ? argRanges[0]->getConstantValue() : std::nullopt;
 
   if (mbCondVal) {
     if (mbCondVal->isZero())
@@ -306,33 +334,40 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
       setResultRange(getResult(), argRanges[1]);
     return;
   }
-  setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
+
+  if (argRanges[1] && argRanges[2])
+    setResultRange(getResult(), argRanges[1]->rangeUnion(*argRanges[2]));
 }
 
 //===----------------------------------------------------------------------===//
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRUIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrU(argRanges));
+  auto infer = inferFromOptionals(inferShrU);
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // ShRSIOp
 //===----------------------------------------------------------------------===//
 
-void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferShrS(argRanges));
+  auto infer = inferFromOptionals(inferShrS);
+  setResultRange(getResult(), infer(argRanges));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 69017efb9a0e6..1342271029fa9 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
-void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
 }
 
-void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                     SetIntRangeFn setResultRange) {
   uint64_t max = kMaxClusterDim;
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
       getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                   SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                   SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
   if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
     setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ThreadIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
   setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
-void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void LaneIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                  SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
 }
 
-void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                      SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
 }
 
-void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                    SetIntRangeFn setResultRange) {
   uint64_t blockDimMax =
       getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,24 +146,26 @@ void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
 }
 
-void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
-void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
                                        SetIntRangeFn setResultRange) {
   setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
 }
 
-void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
-  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
+  auto setRange = [&](const OptionalIntRanges &argRange, Value dimResult,
                       Value idxResult) {
-    if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
+    if (!argRange ||
+        argRange->umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
       return;
+
     ConstantIntRanges dimRange =
-        argRange.intersection(getIndexRange(1, kMaxDim));
+        argRange->intersection(getIndexRange(1, kMaxDim));
     setResultRange(dimResult, dimRange);
     ConstantIntRanges idxRange =
         getIndexRange(0, dimRange.umax().getZExtValue() - 1);
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index 64adb6b850524..cc6709f1253da 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
-#include "llvm/Support/Debug.h"
 #include <optional>
 
 #define DEBUG_TYPE "int-range-analysis"
@@ -23,13 +22,13 @@ using namespace mlir::intrange;
 // Constants
 //===----------------------------------------------------------------------===//
 
-void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
   const APInt &value = getValue();
   setResultRange(getResult(), ConstantIntRanges::constant(value));
 }
 
-void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
   bool value = getValue();
   APInt asInt(/*numBits=*/1, value);
@@ -49,129 +48,195 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 // the inference function without any `OverflowFlags`.
 static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
 inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
-  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+  return [inferWithOvfFn](
+             ArrayRef<ConstantIntRanges> argRanges) -> ConstantIntRanges {
     return inferWithOvfFn(argRanges, OverflowFlags::None);
   };
 }
 
-void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AddOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SubOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MulOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    // Range of an unsigned division: inferDivU compared in unsigned mode.
+    return inferIndexOp(inferDivU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                     SetIntRangeFn setResultRange) {
-  return setResultRange(
-      getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
+  });
+
+  return setResultRange(getResult(), infer(argRanges));
 }
 
-void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
-                                           argRanges, CmpMode::Both));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
+                        CmpMode::Both);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AndOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void OrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  setResultRange(getResult(),
-                 inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
@@ -208,56 +273,70 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
   return ret;
 }
 
-void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/true));
+
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/true);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
-void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
   Type sourceType = getOperand().getType();
   Type destType = getResult().getType();
-  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
-                                             /*isSigned=*/false));
+
+  auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+    return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/false);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//
 
-void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
-  index::IndexCmpPredicate indexPred = getPred();
-  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
-  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
-  APInt min = APInt::getZero(1);
-  APInt max = APInt::getAllOnes(1);
-
-  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
-  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
-                    rhsTrunc = truncRange(rhs, indexMinWidth);
-  std::optional<bool> truthValue32 =
-      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
-  if (truthValue64 == truthValue32) {
-    if (truthValue64.has_value() && *truthValue64)
-      min = max;
-    else if (truthValue64.has_value() && !(*truthValue64))
-      max = min;
-  }
-  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+  auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+    index::IndexCmpPredicate indexPred = getPred();
+    intrange::CmpPredicate pred =
+        static_cast<intrange::CmpPredicate>(indexPred);
+    const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
+
+    APInt min = APInt::getZero(1);
+    APInt max = APInt::getAllOnes(1);
+
+    std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
+
+    ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+                      rhsTrunc = truncRange(rhs, indexMinWidth);
+    std::optional<bool> truthValue32 =
+        intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+    if (truthValue64 == truthValue32) {
+      if (truthValue64.has_value() && *truthValue64)
+        min = max;
+      else if (truthValue64.has_value() && !(*truthValue64))
+        max = min;
+    }
+
+    return ConstantIntRanges::fromUnsigned(min, max);
+  });
+
+  setResultRange(getResult(), infer(argRanges));
 }
 
 //===----------------------------------------------------------------------===//
 // SizeOf, which is bounded between the two supported bitwidth (32 and 64).
 //===----------------------------------------------------------------------===//
 
-void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                  SetIntRangeFn setResultRange) {
   unsigned storageWidth =
       ConstantIntRanges::getStorageBitwidth(getResult().getType());
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index fe1a67d628738..78754680ae58d 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,6 +36,23 @@ using namespace mlir;
 using ConstArithFn =
     function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
 
+std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>
+mlir::intrange::inferFromOptionals(intrange::InferRangeFn inferFn) {
+  return [inferFn = std::move(inferFn)](
+             ArrayRef<OptionalIntRanges> args) -> OptionalIntRanges {
+    llvm::SmallVector<ConstantIntRanges> unpacked;
+    unpacked.reserve(args.size());
+
+    for (const OptionalIntRanges &arg : args) {
+      if (!arg)
+        return std::nullopt;
+      unpacked.push_back(*arg);
+    }
+
+    return inferFn(unpacked);
+  };
+}
+
 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
 /// If either computation overflows, make the result unbounded.
 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b058a8e1abbcb..145b076c95a76 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,9 +648,10 @@ LogicalResult TestVerifiersOp::verifyRegions() {
 //===----------------------------------------------------------------------===//
 // TestWithBoundsOp
 
-void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                          SetIntRangeFn setResultRanges) {
-  setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
+  setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
+                                                 getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
@@ -681,29 +682,37 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
 }
 
 void TestWithBoundsRegionOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
   Value arg = getRegion().getArgument(0);
-  setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
+  setResultRanges(
+      arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
 }
 
 //===----------------------------------------------------------------------===//
 // TestIncrementOp
 
-void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
-  const ConstantIntRanges &range = argRanges[0];
+  if (!argRanges[0])
+    return;
+
+  const ConstantIntRanges &range = *argRanges[0];
   APInt one(range.umin().getBitWidth(), 1);
-  setResultRanges(getResult(),
-                  {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
-                   range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
+  setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
+                                                 range.umax().uadd_sat(one),
+                                                 range.smin().sadd_sat(one),
+                                                 range.smax().sadd_sat(one)});
 }
 
 //===----------------------------------------------------------------------===//
 // TestReflectBoundsOp
 
 void TestReflectBoundsOp::inferResultRanges(
-    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
-  const ConstantIntRanges &range = argRanges[0];
+    ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  if (!argRanges[0])
+    return;
+
+  const ConstantIntRanges &range = *argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
   Type sIntTy, uIntTy;



More information about the Mlir-commits mailing list