[Mlir-commits] [mlir] [mlir][dataflow] Fix for integer range analysis propagation bug (PR #93199)
Spenser Bauman
llvmlistbot at llvm.org
Fri May 24 09:28:17 PDT 2024
https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/93199
>From 59da46ce5e4233043769949af641252df51ba894 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Thu, 23 May 2024 08:18:55 -0400
Subject: [PATCH 1/3] [mlir][dataflow] Fix for integer range analysis
propagation bug
Integer range analysis will not update the range of an operation when
any of the inferred input lattices is uninitialized. In the current
behavior, all lattice values for non-integer types are uninitialized.
For operations like arith.cmpf
```mlir
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
```
that will result in the range of the output also being uninitialized,
and so on for any consumer of the arith.cmpf result. When control-flow
ops are involved, the lack of propagation results in incorrect ranges,
as the back edges for loop-carried values are not properly joined with
the definitions from the body region.
For example, consider an scf.while loop whose body region produces a value
that is in a dataflow relationship with some floating-point values through
an arith.cmpf operation:
```mlir
func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) {
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
%1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) {
%2 = arith.cmpi ult, %arg2, %c4 : index
scf.condition(%2) %arg2, %arg3 : index, index
} do {
^bb0(%arg2: index, %arg3: index):
%4 = arith.select %3, %arg3, %arg3 : index
%5 = arith.addi %arg2, %c1 : index
scf.yield %5, %4 : index, index
}
return %1#0, %1#1 : index, index
}
```
The existing behavior results in the control condition %2 being
optimized to true, turning the while loop into an infinite loop. The
update to %arg2 through the body region is never factored into the range
calculation, as the ranges for the body ops all test as uninitialized.
This change causes all values initialized with setToEntryState to
be set to some initialized range, even if the values are not integers.
---
.../Analysis/DataFlow/IntegerRangeAnalysis.cpp | 2 --
.../Dialect/Arith/int-range-interface.mlir | 18 ++++++++++++++++++
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a82c30717e275..b69b2e0416209 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -38,8 +38,6 @@ using namespace mlir::dataflow;
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
- if (width == 0)
- return {};
APInt umin = APInt::getMinValue(width);
APInt umax = APInt::getMaxValue(width);
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 5b538197a0c11..fdeb8a2e6c935 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -899,3 +899,21 @@ func.func @test_shl_i8_nowrap() -> i8 {
%2 = test.reflect_bounds %1 : i8
return %2: i8
}
+
+/// A test case to ensure that the ranges for unsupported ops are initialized
+/// properly to maxRange, rather than left uninitialized.
+/// In this test case, the previous behavior would leave the ranges for %a and
+/// %b uninitialized, resulting in arith.cmpf's range not being updated, even
+/// though it has an integer valued result.
+
+// CHECK-LABEL: func @test_cmpf_propagates
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @test_cmpf_propagates(%a: f32, %b: f32) -> index {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ %0 = arith.cmpf ueq, %a, %b : f32
+ %1 = arith.select %0, %c1, %c2 : index
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
>From 38188164f25054cbde9b8c074a65f8bdd1c30b43 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Thu, 23 May 2024 15:21:25 -0400
Subject: [PATCH 2/3] Rework integer range analysis interfaces
Modify the integer range analysis interfaces to handle uninitialized
values by allowing the inferred input ranges to be optional.
---
.../Analysis/DataFlow/IntegerRangeAnalysis.h | 2 +-
.../mlir/Interfaces/InferIntRangeInterface.h | 3 +-
.../mlir/Interfaces/InferIntRangeInterface.td | 2 +-
.../Interfaces/Utils/InferIntRangeCommon.h | 7 +-
.../DataFlow/IntegerRangeAnalysis.cpp | 45 +--
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 167 ++++++-----
.../GPU/IR/InferIntRangeInterfaceImpls.cpp | 32 ++-
.../Index/IR/InferIntRangeInterfaceImpls.cpp | 265 ++++++++++++------
.../Interfaces/Utils/InferIntRangeCommon.cpp | 17 ++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 31 +-
10 files changed, 366 insertions(+), 205 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 8bd7cf880c6af..fb07013041c0e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -33,7 +33,7 @@ class IntegerValueRange {
static IntegerValueRange getMaxRange(Value value);
/// Create an integer value range lattice value.
- IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+ IntegerValueRange(OptionalIntRanges value = std::nullopt)
: value(std::move(value)) {}
/// Whether the range is uninitialized. This happens when the state hasn't
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 05064a72ef02e..3d499b420eadd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,10 +105,11 @@ class ConstantIntRanges {
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
+using OptionalIntRanges = std::optional<ConstantIntRanges>;
/// The type of the `setResultRanges` callback provided to ops implementing
/// InferIntRangeInterface. It should be called once for each integer result
/// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
} // end namespace mlir
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index dbdc526c6f10b..f8e2c98d87cdb 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
APInts in their `argRanges` element.
}],
"void", "inferResultRanges", (ins
- "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+ "::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
"::mlir::SetIntRangeFn":$setResultRanges)
>];
}
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 851bb534bc7ee..9e3b04535dcab 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -25,7 +25,10 @@ namespace intrange {
/// abstracted away here to permit writing the function that handles both
/// 64- and 32-bit index types.
using InferRangeFn =
- function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+ std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+using OptionalRangeFn =
+ std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
static constexpr unsigned indexMinWidth = 32;
static constexpr unsigned indexMaxWidth = 64;
@@ -44,6 +47,8 @@ enum class OverflowFlags : uint32_t {
using InferRangeWithOvfFlagsFn =
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
+OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
+
/// Compute `inferFn` on `ranges`, whose size should be the index storage
/// bitwidth. Then, compute the function on `argRanges` again after truncating
/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b69b2e0416209..622d875a63ace 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,8 +36,26 @@
using namespace mlir;
using namespace mlir::dataflow;
+namespace {
+
+OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
+ if (range.isUninitialized())
+ return std::nullopt;
+ return range.getValue();
+}
+
+OptionalIntRanges
+getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
+ return getOptionalRange(lattice->getValue());
+}
+
+} // end namespace
+
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+ if (width == 0)
+ return {};
+
APInt umin = APInt::getMinValue(width);
APInt umax = APInt::getMaxValue(width);
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
@@ -71,23 +89,14 @@ void IntegerRangeAnalysis::visitOperation(
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) {
// If the lattice on any operand is unitialized, bail out.
- if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
- return lattice->getValue().isUninitialized();
- })) {
- return;
- }
-
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
if (!inferrable)
return setAllToEntryStates(results);
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
- return val->getValue().getValue();
- }));
+ auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
auto result = dyn_cast<OpResult>(v);
if (!result)
return;
@@ -97,7 +106,9 @@ void IntegerRangeAnalysis::visitOperation(
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+ ChangeResult changed =
+ attrs ? lattice->join(IntegerValueRange{attrs})
+ : lattice->join(IntegerValueRange::getMaxRange(v));
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -127,12 +138,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return getLatticeElementFor(op, value)->getValue().isUninitialized();
}))
return;
- SmallVector<ConstantIntRanges> argRanges(
+ SmallVector<OptionalIntRanges> argRanges(
llvm::map_range(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(op, value)->getValue().getValue();
+ return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
}));
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
auto arg = dyn_cast<BlockArgument>(v);
if (!arg)
return;
@@ -143,7 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+ ChangeResult changed =
+ attrs ? lattice->join(IntegerValueRange{attrs})
+ : lattice->join(IntegerValueRange::getMaxRange(v));
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index fbe2ecab8adca..b59e5f9ec5a3e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
-#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "int-range-analysis"
@@ -33,7 +32,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
// ConstantOp
//===----------------------------------------------------------------------===//
-void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
if (constAttr) {
@@ -46,48 +45,57 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// AddIOp
//===----------------------------------------------------------------------===//
-void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
-void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//
-void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
-void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferDivU(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferDivU)(argRanges));
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
-void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferDivS(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -95,8 +103,8 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
//===----------------------------------------------------------------------===//
void arith::CeilDivUIOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferCeilDivU(argRanges));
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferFromOptionals(inferCeilDivU)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -104,8 +112,8 @@ void arith::CeilDivUIOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::CeilDivSIOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferCeilDivS(argRanges));
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferFromOptionals(inferCeilDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -113,122 +121,132 @@ void arith::CeilDivSIOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::FloorDivSIOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- return setResultRange(getResult(), inferFloorDivS(argRanges));
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ return setResultRange(getResult(),
+ inferFromOptionals(inferFloorDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//
-void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferRemU(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferRemU)(argRanges));
}
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
-void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferRemS(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferRemS)(argRanges));
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
-void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferAnd(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferAnd)(argRanges));
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
-void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferOr(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferOr)(argRanges));
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
-void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferXor(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferXor)(argRanges));
}
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
-void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMaxS(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferMaxS)(argRanges));
}
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
-void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMaxU(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferMaxU)(argRanges));
}
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
-void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMinS(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferMinS)(argRanges));
}
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
-void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMinU(argRanges));
+ setResultRange(getResult(), inferFromOptionals(inferMinU)(argRanges));
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
-void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
+ if (!argRanges[0])
+ return;
+
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+ setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
-void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
+ if (!argRanges[0])
+ return;
+
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+ setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
-void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
+ if (!argRanges[0])
+ return;
+
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
@@ -236,18 +254,21 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
//===----------------------------------------------------------------------===//
void arith::IndexCastOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ if (!argRanges[0])
+ return;
+
Type sourceType = getOperand().getType();
Type destType = getResult().getType();
unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
+ setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
else
- setResultRange(getResult(), argRanges[0]);
+ setResultRange(getResult(), *argRanges[0]);
}
//===----------------------------------------------------------------------===//
@@ -255,34 +276,40 @@ void arith::IndexCastOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::IndexCastUIOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ if (!argRanges[0])
+ return;
+
Type sourceType = getOperand().getType();
Type destType = getResult().getType();
unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
+ setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncRange(argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
else
- setResultRange(getResult(), argRanges[0]);
+ setResultRange(getResult(), *argRanges[0]);
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
-void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
arith::CmpIPredicate arithPred = getPredicate();
intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ const OptionalIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ if (!lhs || !rhs)
+ return;
APInt min = APInt::getZero(1);
APInt max = APInt::getAllOnes(1);
- std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
+ std::optional<bool> truthValue = intrange::evaluatePred(pred, *lhs, *rhs);
if (truthValue.has_value() && *truthValue)
min = max;
else if (truthValue.has_value() && !(*truthValue))
@@ -295,9 +322,10 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// SelectOp
//===----------------------------------------------------------------------===//
-void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::SelectOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+ std::optional<APInt> mbCondVal =
+ argRanges[0] ? argRanges[0]->getConstantValue() : std::nullopt;
if (mbCondVal) {
if (mbCondVal->isZero())
@@ -306,33 +334,40 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRange(getResult(), argRanges[1]);
return;
}
- setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
+
+ if (argRanges[1] && argRanges[2])
+ setResultRange(getResult(), argRanges[1]->rangeUnion(*argRanges[2]));
}
//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//
-void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//
-void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferShrU(argRanges));
+ auto infer = inferFromOptionals(inferShrU);
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//
-void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferShrS(argRanges));
+ auto infer = inferFromOptionals(inferShrS);
+ setResultRange(getResult(), infer(argRanges));
}
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 69017efb9a0e6..1342271029fa9 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
return std::nullopt;
}
-void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
}
-void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ClusterIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
uint64_t max = kMaxClusterDim;
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal =
getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void BlockIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void ThreadIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void LaneIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
}
-void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
}
-void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
uint64_t blockDimMax =
getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,24 +146,26 @@ void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
}
-void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
}
-void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
+ auto setRange = [&](const OptionalIntRanges &argRange, Value dimResult,
Value idxResult) {
- if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
+ if (!argRange ||
+ argRange->umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
return;
+
ConstantIntRanges dimRange =
- argRange.intersection(getIndexRange(1, kMaxDim));
+ argRange->intersection(getIndexRange(1, kMaxDim));
setResultRange(dimResult, dimRange);
ConstantIntRanges idxRange =
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index 64adb6b850524..cc6709f1253da 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -10,7 +10,6 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
-#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "int-range-analysis"
@@ -23,13 +22,13 @@ using namespace mlir::intrange;
// Constants
//===----------------------------------------------------------------------===//
-void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
const APInt &value = getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
}
-void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
bool value = getValue();
APInt asInt(/*numBits=*/1, value);
@@ -49,129 +48,195 @@ void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// the inference function without any `OverflowFlags`.
static std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
- return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
+ return [inferWithOvfFn](
+ ArrayRef<ConstantIntRanges> argRanges) -> ConstantIntRanges {
return inferWithOvfFn(argRanges, OverflowFlags::None);
};
}
-void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AddOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
- argRanges, CmpMode::Both));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
+ CmpMode::Both);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SubOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
- argRanges, CmpMode::Both));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
+ CmpMode::Both);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MulOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
- argRanges, CmpMode::Both));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
+ CmpMode::Both);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ // Unsigned division: infer with inferDivU, comparing the unsigned results.
+ return inferIndexOp(inferDivU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- return setResultRange(
- getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
+ });
+
+ return setResultRange(getResult(), infer(argRanges));
}
-void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
- argRanges, CmpMode::Both));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
+ CmpMode::Both);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void AndOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void OrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -208,56 +273,70 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
return ret;
}
-void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
Type sourceType = getOperand().getType();
Type destType = getResult().getType();
- setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
- /*isSigned=*/true));
+
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/true);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
-void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
Type sourceType = getOperand().getType();
Type destType = getResult().getType();
- setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
- /*isSigned=*/false));
+
+ auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/false);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
-void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- index::IndexCmpPredicate indexPred = getPred();
- intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- APInt min = APInt::getZero(1);
- APInt max = APInt::getAllOnes(1);
-
- std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
- ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
- rhsTrunc = truncRange(rhs, indexMinWidth);
- std::optional<bool> truthValue32 =
- intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
- if (truthValue64 == truthValue32) {
- if (truthValue64.has_value() && *truthValue64)
- min = max;
- else if (truthValue64.has_value() && !(*truthValue64))
- max = min;
- }
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+ auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
+ index::IndexCmpPredicate indexPred = getPred();
+ intrange::CmpPredicate pred =
+ static_cast<intrange::CmpPredicate>(indexPred);
+ const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
+
+ APInt min = APInt::getZero(1);
+ APInt max = APInt::getAllOnes(1);
+
+ std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
+
+ ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+ rhsTrunc = truncRange(rhs, indexMinWidth);
+ std::optional<bool> truthValue32 =
+ intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+ if (truthValue64 == truthValue32) {
+ if (truthValue64.has_value() && *truthValue64)
+ min = max;
+ else if (truthValue64.has_value() && !(*truthValue64))
+ max = min;
+ }
+
+ return ConstantIntRanges::fromUnsigned(min, max);
+ });
+
+ setResultRange(getResult(), infer(argRanges));
}
//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidth (32 and 64).
//===----------------------------------------------------------------------===//
-void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRange) {
unsigned storageWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index fe1a67d628738..78754680ae58d 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,6 +36,23 @@ using namespace mlir;
using ConstArithFn =
function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
+std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>
+mlir::intrange::inferFromOptionals(intrange::InferRangeFn inferFn) {
+ return [inferFn = std::move(inferFn)](
+ ArrayRef<OptionalIntRanges> args) -> OptionalIntRanges {
+ llvm::SmallVector<ConstantIntRanges> unpacked;
+ unpacked.reserve(args.size());
+
+ for (const OptionalIntRanges &arg : args) {
+ if (!arg)
+ return std::nullopt;
+ unpacked.push_back(*arg);
+ }
+
+ return inferFn(unpacked);
+ };
+}
+
/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
/// If either computation overflows, make the result unbounded.
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b058a8e1abbcb..145b076c95a76 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,9 +648,10 @@ LogicalResult TestVerifiersOp::verifyRegions() {
//===----------------------------------------------------------------------===//
// TestWithBoundsOp
-void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
+ setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
+ getSmin(), getSmax()});
}
//===----------------------------------------------------------------------===//
@@ -681,29 +682,37 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
}
void TestWithBoundsRegionOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
Value arg = getRegion().getArgument(0);
- setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
+ setResultRanges(
+ arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
}
//===----------------------------------------------------------------------===//
// TestIncrementOp
-void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
+ if (!argRanges[0])
+ return;
+
+ const ConstantIntRanges &range = *argRanges[0];
APInt one(range.umin().getBitWidth(), 1);
- setResultRanges(getResult(),
- {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
- range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
+ setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
+ range.umax().uadd_sat(one),
+ range.smin().sadd_sat(one),
+ range.smax().sadd_sat(one)});
}
//===----------------------------------------------------------------------===//
// TestReflectBoundsOp
void TestReflectBoundsOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
+ ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ if (!argRanges[0])
+ return;
+
+ const ConstantIntRanges &range = *argRanges[0];
MLIRContext *ctx = getContext();
Builder b(ctx);
Type sIntTy, uIntTy;
>From 5e23a7fec9bf3f9eeb4b86270c40600cb3d41324 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Fri, 24 May 2024 08:32:14 -0400
Subject: [PATCH 3/3] Convert uses of OptionalIntRange to IntegerValueRange
IntegerValueRange already exists and encodes the exact information that
we want to represent with OptionalIntRange. This makes the APIs clearer
than passing an std::optional everywhere.
---
.../Analysis/DataFlow/IntegerRangeAnalysis.h | 45 ---
.../mlir/Interfaces/InferIntRangeInterface.h | 53 +++-
.../mlir/Interfaces/InferIntRangeInterface.td | 2 +-
.../Interfaces/Utils/InferIntRangeCommon.h | 13 +-
.../DataFlow/IntegerRangeAnalysis.cpp | 57 +---
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 167 ++++++-----
.../GPU/IR/InferIntRangeInterfaceImpls.cpp | 35 ++-
.../Index/IR/InferIntRangeInterfaceImpls.cpp | 278 ++++++++++--------
.../lib/Interfaces/InferIntRangeInterface.cpp | 17 ++
.../Interfaces/Utils/InferIntRangeCommon.cpp | 18 +-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 16 +-
11 files changed, 367 insertions(+), 334 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index fb07013041c0e..191c023fb642c 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -24,51 +24,6 @@
namespace mlir {
namespace dataflow {
-/// This lattice value represents the integer range of an SSA value.
-class IntegerValueRange {
-public:
- /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
- /// range that is used to mark the value as unable to be analyzed further,
- /// where `t` is the type of `value`.
- static IntegerValueRange getMaxRange(Value value);
-
- /// Create an integer value range lattice value.
- IntegerValueRange(OptionalIntRanges value = std::nullopt)
- : value(std::move(value)) {}
-
- /// Whether the range is uninitialized. This happens when the state hasn't
- /// been set during the analysis.
- bool isUninitialized() const { return !value.has_value(); }
-
- /// Get the known integer value range.
- const ConstantIntRanges &getValue() const {
- assert(!isUninitialized());
- return *value;
- }
-
- /// Compare two ranges.
- bool operator==(const IntegerValueRange &rhs) const {
- return value == rhs.value;
- }
-
- /// Take the union of two ranges.
- static IntegerValueRange join(const IntegerValueRange &lhs,
- const IntegerValueRange &rhs) {
- if (lhs.isUninitialized())
- return rhs;
- if (rhs.isUninitialized())
- return lhs;
- return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
- }
-
- /// Print the integer value range.
- void print(raw_ostream &os) const { os << value; }
-
-private:
- /// The known integer value range.
- std::optional<ConstantIntRanges> value;
-};
-
/// This lattice element represents the integer value range of an SSA value.
/// When this lattice is updated, it automatically updates the constant value
/// of the SSA value (if the range can be narrowed to one).
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 3d499b420eadd..73013837f1227 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,11 +105,60 @@ class ConstantIntRanges {
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
-using OptionalIntRanges = std::optional<ConstantIntRanges>;
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
+public:
+ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+ /// range that is used to mark the value as unable to be analyzed further,
+ /// where `t` is the type of `value`.
+ static IntegerValueRange getMaxRange(Value value);
+
+ /// Create an integer value range lattice value.
+ IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+ /// Create an integer value range lattice value.
+ IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+ : value(std::move(value)) {}
+
+ /// Whether the range is uninitialized. This happens when the state hasn't
+ /// been set during the analysis.
+ bool isUninitialized() const { return !value.has_value(); }
+
+ /// Get the known integer value range.
+ const ConstantIntRanges &getValue() const {
+ assert(!isUninitialized());
+ return *value;
+ }
+
+ /// Compare two ranges.
+ bool operator==(const IntegerValueRange &rhs) const {
+ return value == rhs.value;
+ }
+
+ /// Compute the least upper bound of two ranges.
+ static IntegerValueRange join(const IntegerValueRange &lhs,
+ const IntegerValueRange &rhs) {
+ if (lhs.isUninitialized())
+ return rhs;
+ if (rhs.isUninitialized())
+ return lhs;
+ return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
+ }
+
+ /// Print the integer value range.
+ void print(raw_ostream &os) const { os << value; }
+
+private:
+ /// The known integer value range.
+ std::optional<ConstantIntRanges> value;
+};
+
+raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
+
/// The type of the `setResultRanges` callback provided to ops implementing
/// InferIntRangeInterface. It should be called once for each integer result
/// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const OptionalIntRanges &)>;
+using SetIntRangeFn = function_ref<void(Value, const IntegerValueRange &)>;
} // end namespace mlir
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index f8e2c98d87cdb..795e67b8431bd 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -45,7 +45,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
APInts in their `argRanges` element.
}],
"void", "inferResultRanges", (ins
- "::llvm::ArrayRef<::std::optional<::mlir::ConstantIntRanges>>":$argRanges,
+ "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
"::mlir::SetIntRangeFn":$setResultRanges)
>];
}
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 9e3b04535dcab..8746a1cfba85c 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -27,8 +27,9 @@ namespace intrange {
using InferRangeFn =
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
-using OptionalRangeFn =
- std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>;
+/// Function that performs inference on an array of `IntegerValueRange`.
+using InferIntegerValueRangeFn =
+ std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
static constexpr unsigned indexMinWidth = 32;
static constexpr unsigned indexMaxWidth = 64;
@@ -47,7 +48,11 @@ enum class OverflowFlags : uint32_t {
using InferRangeWithOvfFlagsFn =
function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>, OverflowFlags)>;
-OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
+/// Perform a pointwise extension of a function operating on `ConstantIntRanges`
+/// to a function operating on `IntegerValueRange` such that undefined input
+/// ranges propagate.
+InferIntegerValueRangeFn
+inferFromIntegerValueRange(intrange::InferRangeFn inferFn);
/// Compute `inferFn` on `ranges`, whose size should be the index storage
/// bitwidth. Then, compute the function on `argRanges` again after truncating
@@ -57,7 +62,7 @@ OptionalRangeFn inferFromOptionals(intrange::InferRangeFn inferFn);
///
/// The `mode` argument specifies if the unsigned, signed, or both results of
/// the inference computation should be used when comparing the results.
-ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
+ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
ArrayRef<ConstantIntRanges> argRanges,
CmpMode mode);
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 622d875a63ace..b2f8b5a72d0ba 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,33 +36,6 @@
using namespace mlir;
using namespace mlir::dataflow;
-namespace {
-
-OptionalIntRanges getOptionalRange(const IntegerValueRange &range) {
- if (range.isUninitialized())
- return std::nullopt;
- return range.getValue();
-}
-
-OptionalIntRanges
-getOptionalRangeFromLattice(const IntegerValueRangeLattice *lattice) {
- return getOptionalRange(lattice->getValue());
-}
-
-} // end namespace
-
-IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
- unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
- if (width == 0)
- return {};
-
- APInt umin = APInt::getMinValue(width);
- APInt umax = APInt::getMaxValue(width);
- APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
- APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
- return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
-}
-
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
Lattice::onUpdate(solver);
@@ -94,9 +67,12 @@ void IntegerRangeAnalysis::visitOperation(
return setAllToEntryStates(results);
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
- auto argRanges = llvm::map_to_vector(operands, getOptionalRangeFromLattice);
+ auto argRanges = llvm::map_to_vector(
+ operands, [](const IntegerValueRangeLattice *lattice) {
+ return lattice->getValue();
+ });
- auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
+ auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
auto result = dyn_cast<OpResult>(v);
if (!result)
return;
@@ -106,9 +82,7 @@ void IntegerRangeAnalysis::visitOperation(
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed =
- attrs ? lattice->join(IntegerValueRange{attrs})
- : lattice->join(IntegerValueRange::getMaxRange(v));
+ ChangeResult changed = lattice->join(attrs);
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -133,17 +107,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
- // If the lattice on any operand is unitialized, bail out.
- if (llvm::any_of(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(op, value)->getValue().isUninitialized();
- }))
- return;
- SmallVector<OptionalIntRanges> argRanges(
- llvm::map_range(op->getOperands(), [&](Value value) {
- return getOptionalRangeFromLattice(getLatticeElementFor(op, value));
- }));
- auto joinCallback = [&](Value v, const OptionalIntRanges &attrs) {
+ auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
+ return getLatticeElementFor(op, value)->getValue();
+ });
+
+ auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
auto arg = dyn_cast<BlockArgument>(v);
if (!arg)
return;
@@ -154,9 +123,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed =
- attrs ? lattice->join(IntegerValueRange{attrs})
- : lattice->join(IntegerValueRange::getMaxRange(v));
+ ChangeResult changed = lattice->join(attrs);
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index b59e5f9ec5a3e..9456c9e87a277 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -32,7 +32,7 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
// ConstantOp
//===----------------------------------------------------------------------===//
-void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
if (constAttr) {
@@ -45,11 +45,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// AddIOp
//===----------------------------------------------------------------------===//
-void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::AddIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
- return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
- });
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferAdd(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -58,11 +59,12 @@ void arith::AddIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// SubIOp
//===----------------------------------------------------------------------===//
-void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::SubIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
- return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
- });
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferSub(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -71,11 +73,12 @@ void arith::SubIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// MulIOp
//===----------------------------------------------------------------------===//
-void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MulIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
- return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
- });
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ return inferMul(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -84,18 +87,18 @@ void arith::MulIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// DivUIOp
//===----------------------------------------------------------------------===//
-void arith::DivUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::DivUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferDivU)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferDivU)(argRanges));
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
-void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::DivSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferDivS)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -103,8 +106,9 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
//===----------------------------------------------------------------------===//
void arith::CeilDivUIOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferCeilDivU)(argRanges));
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferFromIntegerValueRange(inferCeilDivU)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -112,8 +116,9 @@ void arith::CeilDivUIOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::CeilDivSIOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferCeilDivS)(argRanges));
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferFromIntegerValueRange(inferCeilDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -121,132 +126,132 @@ void arith::CeilDivSIOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::FloorDivSIOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
return setResultRange(getResult(),
- inferFromOptionals(inferFloorDivS)(argRanges));
+ inferFromIntegerValueRange(inferFloorDivS)(argRanges));
}
//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//
-void arith::RemUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::RemUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferRemU)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferRemU)(argRanges));
}
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
-void arith::RemSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::RemSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferRemS)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferRemS)(argRanges));
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
-void arith::AndIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::AndIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferAnd)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferAnd)(argRanges));
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
-void arith::OrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::OrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferOr)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferOr)(argRanges));
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
-void arith::XOrIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::XOrIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferXor)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferXor)(argRanges));
}
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
-void arith::MaxSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MaxSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferMaxS)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferMaxS)(argRanges));
}
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
-void arith::MaxUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MaxUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferMaxU)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferMaxU)(argRanges));
}
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
-void arith::MinSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MinSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferMinS)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferMinS)(argRanges));
}
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
-void arith::MinUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::MinUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferFromOptionals(inferMinU)(argRanges));
+ setResultRange(getResult(), inferFromIntegerValueRange(inferMinU)(argRanges));
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
-void arith::ExtUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ExtUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- if (!argRanges[0])
+ if (argRanges[0].isUninitialized())
return;
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
-void arith::ExtSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ExtSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- if (!argRanges[0])
+ if (argRanges[0].isUninitialized())
return;
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
-void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::TruncIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- if (!argRanges[0])
+ if (argRanges[0].isUninitialized())
return;
unsigned destWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
- setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
}
//===----------------------------------------------------------------------===//
@@ -254,8 +259,8 @@ void arith::TruncIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
//===----------------------------------------------------------------------===//
void arith::IndexCastOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
- if (!argRanges[0])
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+ if (argRanges[0].isUninitialized())
return;
Type sourceType = getOperand().getType();
@@ -264,11 +269,11 @@ void arith::IndexCastOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extSIRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), extSIRange(argRanges[0].getValue(), destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
else
- setResultRange(getResult(), *argRanges[0]);
+ setResultRange(getResult(), argRanges[0]);
}
//===----------------------------------------------------------------------===//
@@ -276,8 +281,8 @@ void arith::IndexCastOp::inferResultRanges(
//===----------------------------------------------------------------------===//
void arith::IndexCastUIOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRange) {
- if (!argRanges[0])
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRange) {
+ if (argRanges[0].isUninitialized())
return;
Type sourceType = getOperand().getType();
@@ -286,30 +291,31 @@ void arith::IndexCastUIOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extUIRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), extUIRange(argRanges[0].getValue(), destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncRange(*argRanges[0], destWidth));
+ setResultRange(getResult(), truncRange(argRanges[0].getValue(), destWidth));
else
- setResultRange(getResult(), *argRanges[0]);
+ setResultRange(getResult(), argRanges[0]);
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
-void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::CmpIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
arith::CmpIPredicate arithPred = getPredicate();
intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
- const OptionalIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ const IntegerValueRange &lhs = argRanges[0], &rhs = argRanges[1];
- if (!lhs || !rhs)
+ if (lhs.isUninitialized() || rhs.isUninitialized())
return;
APInt min = APInt::getZero(1);
APInt max = APInt::getAllOnes(1);
- std::optional<bool> truthValue = intrange::evaluatePred(pred, *lhs, *rhs);
+ std::optional<bool> truthValue =
+ intrange::evaluatePred(pred, lhs.getValue(), rhs.getValue());
if (truthValue.has_value() && *truthValue)
min = max;
else if (truthValue.has_value() && !(*truthValue))
@@ -322,32 +328,37 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// SelectOp
//===----------------------------------------------------------------------===//
-void arith::SelectOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::SelectOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
std::optional<APInt> mbCondVal =
- argRanges[0] ? argRanges[0]->getConstantValue() : std::nullopt;
+ !argRanges[0].isUninitialized()
+ ? argRanges[0].getValue().getConstantValue()
+ : std::nullopt;
+
+ const IntegerValueRange &trueCase = argRanges[1];
+ const IntegerValueRange &falseCase = argRanges[2];
if (mbCondVal) {
if (mbCondVal->isZero())
- setResultRange(getResult(), argRanges[2]);
+ setResultRange(getResult(), falseCase);
else
- setResultRange(getResult(), argRanges[1]);
+ setResultRange(getResult(), trueCase);
return;
}
- if (argRanges[1] && argRanges[2])
- setResultRange(getResult(), argRanges[1]->rangeUnion(*argRanges[2]));
+ setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
}
//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//
-void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShLIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
- });
+ auto infer =
+ inferFromIntegerValueRange([&](ArrayRef<ConstantIntRanges> ranges) {
+ return inferShl(ranges, convertArithOverflowFlags(getOverflowFlags()));
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -356,9 +367,9 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// ShRUIOp
//===----------------------------------------------------------------------===//
-void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShRUIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals(inferShrU);
+ auto infer = inferFromIntegerValueRange(inferShrU);
setResultRange(getResult(), infer(argRanges));
}
@@ -366,8 +377,8 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// ShRSIOp
//===----------------------------------------------------------------------===//
-void arith::ShRSIOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void arith::ShRSIOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals(inferShrS);
+ auto infer = inferFromIntegerValueRange(inferShrS);
setResultRange(getResult(), infer(argRanges));
}
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 1342271029fa9..3676800ae0be5 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -84,18 +84,18 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
return std::nullopt;
}
-void ClusterDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ClusterDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
}
-void ClusterIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ClusterIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
uint64_t max = kMaxClusterDim;
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void BlockDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal =
getKnownLaunchDim(*this, LaunchDims::Block);
@@ -105,13 +105,13 @@ void BlockDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void BlockIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void BlockIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void GridDimOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
if (knownVal)
@@ -120,23 +120,23 @@ void GridDimOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void ThreadIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void ThreadIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
}
-void LaneIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void LaneIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
}
-void SubgroupIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void SubgroupIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
}
-void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void GlobalIdOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
uint64_t blockDimMax =
getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
@@ -146,26 +146,29 @@ void GlobalIdOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
}
-void NumSubgroupsOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void NumSubgroupsOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxDim));
}
-void SubgroupSizeOp::inferResultRanges(ArrayRef<OptionalIntRanges>,
+void SubgroupSizeOp::inferResultRanges(ArrayRef<IntegerValueRange>,
SetIntRangeFn setResultRange) {
setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
}
-void LaunchOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void LaunchOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto setRange = [&](const OptionalIntRanges &argRange, Value dimResult,
+ auto setRange = [&](const IntegerValueRange &argRange, Value dimResult,
Value idxResult) {
- if (!argRange ||
- argRange->umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
+ if (argRange.isUninitialized())
+ return;
+
+ const ConstantIntRanges &constRange = argRange.getValue();
+ if (constRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
return;
ConstantIntRanges dimRange =
- argRange->intersection(getIndexRange(1, kMaxDim));
+ constRange.intersection(getIndexRange(1, kMaxDim));
setResultRange(dimResult, dimRange);
ConstantIntRanges idxRange =
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
index cc6709f1253da..4d92957a86f92 100644
--- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -22,13 +22,13 @@ using namespace mlir::intrange;
// Constants
//===----------------------------------------------------------------------===//
-void ConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
const APInt &value = getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
}
-void BoolConstantOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void BoolConstantOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
bool value = getValue();
APInt asInt(/*numBits=*/1, value);
@@ -54,187 +54,207 @@ inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn) {
};
}
-void AddOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void AddOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
- CmpMode::Both);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferAdd), ranges,
+ CmpMode::Both);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void SubOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void SubOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
- CmpMode::Both);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
+ CmpMode::Both);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void MulOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MulOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
- CmpMode::Both);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferMul), ranges,
+ CmpMode::Both);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void DivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void DivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferWithoutOverflowFlags(inferSub), ranges,
- CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ // Unsigned division: use the division inference, not subtraction.
+ return inferIndexOp(inferDivU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void DivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void DivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferDivS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void CeilDivUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CeilDivUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferCeilDivU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void CeilDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CeilDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferCeilDivS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void FloorDivSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void FloorDivSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferFloorDivS, ranges, CmpMode::Signed);
+ });
return setResultRange(getResult(), infer(argRanges));
}
-void RemSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void RemSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferRemS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void RemUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void RemUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferRemU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void MaxSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MaxSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMaxS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void MaxUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MaxUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMaxU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void MinSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MinSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMinS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void MinUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void MinUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferMinU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void ShlOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShlOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
- CmpMode::Both);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferWithoutOverflowFlags(inferShl), ranges,
+ CmpMode::Both);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void ShrSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShrSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferShrS, ranges, CmpMode::Signed);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void ShrUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void ShrUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferShrU, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void AndOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void AndOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferAnd, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void OrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void OrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferOr, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void XOrOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void XOrOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
- });
+ auto infer =
+ inferFromIntegerValueRange([](ArrayRef<ConstantIntRanges> ranges) {
+ return inferIndexOp(inferXor, ranges, CmpMode::Unsigned);
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -273,26 +293,30 @@ static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
return ret;
}
-void CastSOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CastSOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- Type sourceType = getOperand().getType();
- Type destType = getResult().getType();
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ Type sourceType = getOperand().getType();
+ Type destType = getResult().getType();
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/true);
- });
+ return inferIndexCast(ranges[0], sourceType, destType,
+ /*isSigned=*/true);
+ });
setResultRange(getResult(), infer(argRanges));
}
-void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CastUOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- Type sourceType = getOperand().getType();
- Type destType = getResult().getType();
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ Type sourceType = getOperand().getType();
+ Type destType = getResult().getType();
- auto infer = inferFromOptionals([&](ArrayRef<ConstantIntRanges> ranges) {
- return inferIndexCast(ranges[0], sourceType, destType, /*isSigned=*/false);
- });
+ return inferIndexCast(ranges[0], sourceType, destType,
+ /*isSigned=*/false);
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -301,33 +325,35 @@ void CastUOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// CmpOp
//===----------------------------------------------------------------------===//
-void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void CmpOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
- auto infer = inferFromOptionals([this](ArrayRef<ConstantIntRanges> ranges) {
- index::IndexCmpPredicate indexPred = getPred();
- intrange::CmpPredicate pred =
- static_cast<intrange::CmpPredicate>(indexPred);
- const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
-
- APInt min = APInt::getZero(1);
- APInt max = APInt::getAllOnes(1);
-
- std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
-
- ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
- rhsTrunc = truncRange(rhs, indexMinWidth);
- std::optional<bool> truthValue32 =
- intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
-
- if (truthValue64 == truthValue32) {
- if (truthValue64.has_value() && *truthValue64)
- min = max;
- else if (truthValue64.has_value() && !(*truthValue64))
- max = min;
- }
-
- return ConstantIntRanges::fromUnsigned(min, max);
- });
+ auto infer =
+ inferFromIntegerValueRange([this](ArrayRef<ConstantIntRanges> ranges) {
+ index::IndexCmpPredicate indexPred = getPred();
+ intrange::CmpPredicate pred =
+ static_cast<intrange::CmpPredicate>(indexPred);
+ const ConstantIntRanges &lhs = ranges[0], &rhs = ranges[1];
+
+ APInt min = APInt::getZero(1);
+ APInt max = APInt::getAllOnes(1);
+
+ std::optional<bool> truthValue64 =
+ intrange::evaluatePred(pred, lhs, rhs);
+
+ ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+ rhsTrunc = truncRange(rhs, indexMinWidth);
+ std::optional<bool> truthValue32 =
+ intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+ if (truthValue64 == truthValue32) {
+ if (truthValue64.has_value() && *truthValue64)
+ min = max;
+ else if (truthValue64.has_value() && !(*truthValue64))
+ max = min;
+ }
+
+ return ConstantIntRanges::fromUnsigned(min, max);
+ });
setResultRange(getResult(), infer(argRanges));
}
@@ -336,7 +362,7 @@ void CmpOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// SizeOf, which is bounded between the two supported bitwidth (32 and 64).
//===----------------------------------------------------------------------===//
-void SizeOfOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void SizeOfOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRange) {
unsigned storageWidth =
ConstantIntRanges::getStorageBitwidth(getResult().getType());
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index b3f6c0ee3cc32..1891f6a1756f3 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -126,3 +126,20 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
return os << "unsigned : [" << range.umin() << ", " << range.umax()
<< "] signed : [" << range.smin() << ", " << range.smax() << "]";
}
+
+/// Returns the widest range for the storage width of `value`'s type, or a
+/// default-constructed IntegerValueRange when that width is reported as 0.
+IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+  if (width == 0)
+    return {};
+  // width != 0 past the guard, so the signed bounds need no zero-width case.
+  return IntegerValueRange{ConstantIntRanges{
+      APInt::getMinValue(width), APInt::getMaxValue(width),
+      APInt::getSignedMinValue(width), APInt::getSignedMaxValue(width)}};
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
+ range.print(os);
+ return os;
+}
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 78754680ae58d..43cca0d2c9845 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -36,20 +36,20 @@ using namespace mlir;
using ConstArithFn =
function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
-std::function<OptionalIntRanges(ArrayRef<OptionalIntRanges>)>
-mlir::intrange::inferFromOptionals(intrange::InferRangeFn inferFn) {
+std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>
+mlir::intrange::inferFromIntegerValueRange(intrange::InferRangeFn inferFn) {
return [inferFn = std::move(inferFn)](
- ArrayRef<OptionalIntRanges> args) -> OptionalIntRanges {
+ ArrayRef<IntegerValueRange> args) -> IntegerValueRange {
llvm::SmallVector<ConstantIntRanges> unpacked;
unpacked.reserve(args.size());
- for (const OptionalIntRanges &arg : args) {
- if (!arg)
- return std::nullopt;
- unpacked.push_back(*arg);
+ for (const IntegerValueRange &arg : args) {
+ if (arg.isUninitialized())
+ return {};
+ unpacked.push_back(arg.getValue());
}
- return inferFn(unpacked);
+ return IntegerValueRange{inferFn(unpacked)};
};
}
@@ -93,7 +93,7 @@ static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferIndexOp(InferRangeFn inferFn,
+mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
ArrayRef<ConstantIntRanges> argRanges,
intrange::CmpMode mode) {
ConstantIntRanges sixtyFour = inferFn(argRanges);
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 145b076c95a76..bb0687463c831 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -648,7 +648,7 @@ LogicalResult TestVerifiersOp::verifyRegions() {
//===----------------------------------------------------------------------===//
// TestWithBoundsOp
-void TestWithBoundsOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void TestWithBoundsOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), ConstantIntRanges{getUmin(), getUmax(),
getSmin(), getSmax()});
@@ -682,7 +682,7 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
}
void TestWithBoundsRegionOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
Value arg = getRegion().getArgument(0);
setResultRanges(
arg, ConstantIntRanges{getUmin(), getUmax(), getSmin(), getSmax()});
@@ -691,12 +691,12 @@ void TestWithBoundsRegionOp::inferResultRanges(
//===----------------------------------------------------------------------===//
// TestIncrementOp
-void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
+void TestIncrementOp::inferResultRanges(ArrayRef<IntegerValueRange> argRanges,
SetIntRangeFn setResultRanges) {
- if (!argRanges[0])
+ if (argRanges[0].isUninitialized())
return;
- const ConstantIntRanges &range = *argRanges[0];
+ const ConstantIntRanges &range = argRanges[0].getValue();
APInt one(range.umin().getBitWidth(), 1);
setResultRanges(getResult(), ConstantIntRanges{range.umin().uadd_sat(one),
range.umax().uadd_sat(one),
@@ -708,11 +708,11 @@ void TestIncrementOp::inferResultRanges(ArrayRef<OptionalIntRanges> argRanges,
// TestReflectBoundsOp
void TestReflectBoundsOp::inferResultRanges(
- ArrayRef<OptionalIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- if (!argRanges[0])
+ ArrayRef<IntegerValueRange> argRanges, SetIntRangeFn setResultRanges) {
+ if (argRanges[0].isUninitialized())
return;
- const ConstantIntRanges &range = *argRanges[0];
+ const ConstantIntRanges &range = argRanges[0].getValue();
MLIRContext *ctx = getContext();
Builder b(ctx);
Type sIntTy, uIntTy;
More information about the Mlir-commits
mailing list