[Mlir-commits] [mlir] 6aeea70 - [mlir][dataflow] Fix for integer range analysis propagation bug (#93199)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 15:29:22 PDT 2024
Author: Spenser Bauman
Date: 2024-05-28T18:29:17-04:00
New Revision: 6aeea700df6f3f8db9e6a79be4aa593c6fcc7d18
URL: https://github.com/llvm/llvm-project/commit/6aeea700df6f3f8db9e6a79be4aa593c6fcc7d18
DIFF: https://github.com/llvm/llvm-project/commit/6aeea700df6f3f8db9e6a79be4aa593c6fcc7d18.diff
LOG: [mlir][dataflow] Fix for integer range analysis propagation bug (#93199)
Integer range analysis will not update the range of an operation when
any of the inferred input lattices are uninitialized. In the current
behavior, all lattice values for non-integer types are uninitialized.
For an operation like arith.cmpf,
```mlir
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
```
the uninitialized lattices of its floating-point operands leave the range
of its result uninitialized as well, and the same then holds for every
consumer of the arith.cmpf result. When control-flow ops are involved, the
lack of propagation results in incorrect ranges, as the back edges for
loop-carried values are not properly joined with the definitions from the
body region.
For example, an scf.while loop whose body region produces a value that
is in a dataflow relationship with some floating-point values through an
arith.cmpf operation:
```mlir
func.func @test_bad_range(%arg0: f32, %arg1: f32) -> (index, index) {
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%3 = arith.cmpf ugt, %arg0, %arg1 : f32
%1:2 = scf.while (%arg2 = %c0, %arg3 = %c0) : (index, index) -> (index, index) {
%2 = arith.cmpi ult, %arg2, %c4 : index
scf.condition(%2) %arg2, %arg3 : index, index
} do {
^bb0(%arg2: index, %arg3: index):
%4 = arith.select %3, %arg3, %arg3 : index
%5 = arith.addi %arg2, %c1 : index
scf.yield %5, %4 : index, index
}
return %1#0, %1#1 : index, index
}
```
The existing behavior results in the control condition %2 being
optimized to true, turning the while loop into an infinite loop. The
update to %arg2 through the body region is never factored into the range
calculation, because the ranges of the body ops all remain uninitialized.
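This change causes all values initialized with setToEntryState to be set
to some initialized range, even if the values are not integers.
Concretely, IntegerRangeAnalysis no longer bails out when an operand
lattice is uninitialized. Operations that do not implement
InferIntRangeInterface (such as arith.cmpf) therefore fall through to
setAllToEntryStates. InferIntRangeInterface also gains an
inferResultRangesFromOptional method, which receives possibly-uninitialized
IntegerValueRange arguments; arith.select implements it directly. Its
default implementation forwards to the existing inferResultRanges only
when every argument range is initialized.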
---------
Co-authored-by: Spenser Bauman <sabauma at fastmail>
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/include/mlir/Interfaces/InferIntRangeInterface.h
mlir/include/mlir/Interfaces/InferIntRangeInterface.td
mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Interfaces/InferIntRangeInterface.cpp
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
mlir/test/Dialect/Arith/int-range-interface.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 8bd7cf880c6af..191c023fb642c 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -24,51 +24,6 @@
namespace mlir {
namespace dataflow {
-/// This lattice value represents the integer range of an SSA value.
-class IntegerValueRange {
-public:
- /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
- /// range that is used to mark the value as unable to be analyzed further,
- /// where `t` is the type of `value`.
- static IntegerValueRange getMaxRange(Value value);
-
- /// Create an integer value range lattice value.
- IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
- : value(std::move(value)) {}
-
- /// Whether the range is uninitialized. This happens when the state hasn't
- /// been set during the analysis.
- bool isUninitialized() const { return !value.has_value(); }
-
- /// Get the known integer value range.
- const ConstantIntRanges &getValue() const {
- assert(!isUninitialized());
- return *value;
- }
-
- /// Compare two ranges.
- bool operator==(const IntegerValueRange &rhs) const {
- return value == rhs.value;
- }
-
- /// Take the union of two ranges.
- static IntegerValueRange join(const IntegerValueRange &lhs,
- const IntegerValueRange &rhs) {
- if (lhs.isUninitialized())
- return rhs;
- if (rhs.isUninitialized())
- return lhs;
- return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
- }
-
- /// Print the integer value range.
- void print(raw_ostream &os) const { os << value; }
-
-private:
- /// The known integer value range.
- std::optional<ConstantIntRanges> value;
-};
-
/// This lattice element represents the integer value range of an SSA value.
/// When this lattice is updated, it automatically updates the constant value
/// of the SSA value (if the range can be narrowed to one).
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ead52332e8eec..46248dad3be9e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -49,7 +49,7 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
// Base class for integer binary operations.
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
- [DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
Results<(outs SignlessIntegerLike:$result)>;
@@ -107,7 +107,7 @@ class Arith_IToICastOp<string mnemonic, list<Trait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
SignlessFixedWidthIntegerLike,
traits #
- [DeclareOpInterfaceMethods<InferIntRangeInterface>]>;
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>;
// Cast from an integer type to a floating point type.
class Arith_IToFCastOp<string mnemonic, list<Trait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike, traits>;
@@ -139,7 +139,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
- [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
DefaultValuedAttr<
@@ -159,7 +159,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
[ConstantLike, Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "result"]>,
- DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
@@ -1327,7 +1327,7 @@ def IndexCastTypeConstraint : TypeConstraint<Or<[
def Arith_IndexCastOp
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
- [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1346,7 +1346,7 @@ def Arith_IndexCastOp
def Arith_IndexCastUIOp
: Arith_CastOp<"index_castui", IndexCastTypeConstraint, IndexCastTypeConstraint,
- [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "unsigned cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
@@ -1400,7 +1400,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
def Arith_CmpIOp
: Arith_CompareOpOfAnyRank<"cmpi",
- [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1555,7 +1555,7 @@ class ScalarConditionOrMatchingShape<list<string> names> :
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
ScalarConditionOrMatchingShape<["condition", "result"]>,
- DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 1da68ed2176d8..10719aae5c8b4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -52,7 +52,7 @@ def GPU_DimensionAttr : EnumAttr<GPU_Dialect, GPU_Dimension, "dim">;
class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
GPU_Op<mnemonic, !listconcat(traits, [
Pure,
- DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>])>,
Arguments<(ins GPU_DimensionAttr:$dimension)>, Results<(outs Index)> {
let assemblyFormat = "$dimension attr-dict";
@@ -144,7 +144,7 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
}
def GPU_LaneIdOp : GPU_Op<"lane_id", [
- Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let description = [{
Returns the lane id within the subgroup (warp/wave).
@@ -158,7 +158,7 @@ def GPU_LaneIdOp : GPU_Op<"lane_id", [
}
def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [
- Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+ Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the subgroup id, i.e., the index of the current subgroup within the
@@ -190,7 +190,7 @@ def GPU_GlobalIdOp : GPU_IndexOp<"global_id"> {
def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
- Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+ Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the number of subgroups within a workgroup.
@@ -206,7 +206,7 @@ def GPU_NumSubgroupsOp : GPU_Op<"num_subgroups", [
}
def GPU_SubgroupSizeOp : GPU_Op<"subgroup_size", [
- Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
+ Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
Returns the number of threads within a subgroup.
@@ -687,7 +687,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
def GPU_LaunchOp : GPU_Op<"launch", [
AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
- DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
RecursiveMemoryEffects]>,
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index c6079cb8a98c8..a30ae9f739cbc 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -25,7 +25,7 @@ include "mlir/IR/OpBase.td"
/// Base class for Index dialect operations.
class IndexOp<string mnemonic, list<Trait> traits = []>
: Op<IndexDialect, mnemonic,
- [DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>] # traits>;
//===----------------------------------------------------------------------===//
// IndexBinaryOp
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 05064a72ef02e..0e107e88f5232 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -105,10 +105,83 @@ class ConstantIntRanges {
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
+public:
+ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+ /// range that is used to mark the value as unable to be analyzed further,
+ /// where `t` is the type of `value`.
+ static IntegerValueRange getMaxRange(Value value);
+
+ /// Create an integer value range lattice value.
+ IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+ /// Create an integer value range lattice value.
+ IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+ : value(std::move(value)) {}
+
+ /// Whether the range is uninitialized. This happens when the state hasn't
+ /// been set during the analysis.
+ bool isUninitialized() const { return !value.has_value(); }
+
+ /// Get the known integer value range.
+ const ConstantIntRanges &getValue() const {
+ assert(!isUninitialized());
+ return *value;
+ }
+
+ /// Compare two ranges.
+ bool operator==(const IntegerValueRange &rhs) const {
+ return value == rhs.value;
+ }
+
+ /// Compute the least upper bound of two ranges.
+ static IntegerValueRange join(const IntegerValueRange &lhs,
+ const IntegerValueRange &rhs) {
+ if (lhs.isUninitialized())
+ return rhs;
+ if (rhs.isUninitialized())
+ return lhs;
+ return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
+ }
+
+ /// Print the integer value range.
+ void print(raw_ostream &os) const { os << value; }
+
+private:
+ /// The known integer value range.
+ std::optional<ConstantIntRanges> value;
+};
+
+raw_ostream &operator<<(raw_ostream &, const IntegerValueRange &);
+
/// The type of the `setResultRanges` callback provided to ops implementing
/// InferIntRangeInterface. It should be called once for each integer result
/// value and be passed the ConstantIntRanges corresponding to that value.
-using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+using SetIntRangeFn =
+ llvm::function_ref<void(Value, const ConstantIntRanges &)>;
+
+/// Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
+/// This is the `setResultRanges` callback for the IntegerValueRange based
+/// interface method.
+using SetIntLatticeFn =
+ llvm::function_ref<void(Value, const IntegerValueRange &)>;
+
+class InferIntRangeInterface;
+
+namespace intrange::detail {
+/// Default implementation of `inferResultRanges` which dispatches to the
+/// `inferResultRangesFromOptional`.
+void defaultInferResultRanges(InferIntRangeInterface interface,
+ ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRanges);
+
+/// Default implementation of `inferResultRangesFromOptional` which dispatches
+/// to the `inferResultRanges`.
+void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
+ ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges);
+} // end namespace intrange::detail
} // end namespace mlir
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index dbdc526c6f10b..6ee436ce4d6c2 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -28,9 +28,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
Infer the bounds on the results of this op given the bounds on its arguments.
For each result value or block argument (that isn't a branch argument,
since the dataflow analysis handles those case), the method should call
- `setValueRange` with that `Value` as an argument. When `setValueRange`
- is not called for some value, it will recieve a default value of the mimimum
- and maximum values for its type (the unbounded range).
+ `setValueRange` with that `Value` as an argument. When implemented,
+ `setValueRange` should be called on all result values for the operation.
+ When operations take non-integer inputs, the
+ `inferResultRangesFromOptional` method should be implemented instead.
When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
@@ -39,14 +40,39 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
This function will only be called when at least one result of the op is a
scalar integer value or the op has a region.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"inferResultRanges",
+ /*args=*/(ins "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+ "::mlir::SetIntRangeFn":$setResultRanges),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ ::mlir::intrange::detail::defaultInferResultRangesFromOptional($_op,
+ argRanges,
+ setResultRanges);
+ }]>,
+
+ InterfaceMethod<[{
+ Infer the bounds on the results of this op given the lattice representation
+ of the bounds for its arguments. For each result value or block argument
+ (that isn't a branch argument, since the dataflow analysis handles
+ those case), the method should call `setValueRange` with that `Value`
+ as an argument. When implemented, `setValueRange` should be called on
+ all result values for the operation.
- `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
- order. Non-integer arguments will have the an unbounded range of width-0
- APInts in their `argRanges` element.
+ This method allows for more precise implementations when operations
+ want to reason about inputs which may be undefined during the analysis.
}],
- "void", "inferResultRanges", (ins
- "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
- "::mlir::SetIntRangeFn":$setResultRanges)
- >];
+ /*retTy=*/"void",
+ /*methodName=*/"inferResultRangesFromOptional",
+ /*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
+ "::mlir::SetIntLatticeFn":$setResultRanges),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ ::mlir::intrange::detail::defaultInferResultRanges($_op,
+ argRanges,
+ setResultRanges);
+ }]>
+ ];
}
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 851bb534bc7ee..3988a8826498a 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -25,7 +25,11 @@ namespace intrange {
/// abstracted away here to permit writing the function that handles both
/// 64- and 32-bit index types.
using InferRangeFn =
- function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+ std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+/// Function that performs inferrence on an array of `IntegerValueRange`.
+using InferIntegerValueRangeFn =
+ std::function<IntegerValueRange(ArrayRef<IntegerValueRange>)>;
static constexpr unsigned indexMinWidth = 32;
static constexpr unsigned indexMaxWidth = 64;
@@ -52,7 +56,7 @@ using InferRangeWithOvfFlagsFn =
///
/// The `mode` argument specifies if the unsigned, signed, or both results of
/// the inference computation should be used when comparing the results.
-ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
+ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn,
ArrayRef<ConstantIntRanges> argRanges,
CmpMode mode);
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a82c30717e275..9721620807a0f 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -36,17 +36,6 @@
using namespace mlir;
using namespace mlir::dataflow;
-IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
- unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
- if (width == 0)
- return {};
- APInt umin = APInt::getMinValue(width);
- APInt umax = APInt::getMaxValue(width);
- APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
- APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
- return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
-}
-
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
Lattice::onUpdate(solver);
@@ -72,24 +61,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
void IntegerRangeAnalysis::visitOperation(
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) {
- // If the lattice on any operand is unitialized, bail out.
- if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
- return lattice->getValue().isUninitialized();
- })) {
- return;
- }
-
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
if (!inferrable)
return setAllToEntryStates(results);
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
- return val->getValue().getValue();
- }));
+ auto argRanges = llvm::map_to_vector(
+ operands, [](const IntegerValueRangeLattice *lattice) {
+ return lattice->getValue();
+ });
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
auto result = dyn_cast<OpResult>(v);
if (!result)
return;
@@ -99,7 +81,7 @@ void IntegerRangeAnalysis::visitOperation(
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+ ChangeResult changed = lattice->join(attrs);
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -116,7 +98,7 @@ void IntegerRangeAnalysis::visitOperation(
propagateIfChanged(lattice, changed);
};
- inferrable.inferResultRanges(argRanges, joinCallback);
+ inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
}
void IntegerRangeAnalysis::visitNonControlFlowArguments(
@@ -124,17 +106,12 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
- // If the lattice on any operand is unitialized, bail out.
- if (llvm::any_of(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(op, value)->getValue().isUninitialized();
- }))
- return;
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(op, value)->getValue().getValue();
- }));
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
+ return getLatticeElementFor(op, value)->getValue();
+ });
+
+ auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
auto arg = dyn_cast<BlockArgument>(v);
if (!arg)
return;
@@ -145,7 +122,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
IntegerValueRange oldRange = lattice->getValue();
- ChangeResult changed = lattice->join(IntegerValueRange{attrs});
+ ChangeResult changed = lattice->join(attrs);
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -162,7 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
propagateIfChanged(lattice, changed);
};
- inferrable.inferResultRanges(argRanges, joinCallback);
+ inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
return;
}
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index fbe2ecab8adca..462044417b5fb 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -295,18 +295,24 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
// SelectOp
//===----------------------------------------------------------------------===//
-void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRange) {
- std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+void arith::SelectOp::inferResultRangesFromOptional(
+ ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
+ std::optional<APInt> mbCondVal =
+ argRanges[0].isUninitialized()
+ ? std::nullopt
+ : argRanges[0].getValue().getConstantValue();
+
+ const IntegerValueRange &trueCase = argRanges[1];
+ const IntegerValueRange &falseCase = argRanges[2];
if (mbCondVal) {
if (mbCondVal->isZero())
- setResultRange(getResult(), argRanges[2]);
+ setResultRange(getResult(), falseCase);
else
- setResultRange(getResult(), argRanges[1]);
+ setResultRange(getResult(), trueCase);
return;
}
- setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
+ setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index b3f6c0ee3cc32..d879b93586899 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -126,3 +126,51 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
return os << "unsigned : [" << range.umin() << ", " << range.umax()
<< "] signed : [" << range.smin() << ", " << range.smax() << "]";
}
+
+IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+ if (width == 0)
+ return {};
+
+ APInt umin = APInt::getMinValue(width);
+ APInt umax = APInt::getMaxValue(width);
+ APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+ APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+ return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
+ range.print(os);
+ return os;
+}
+
+void mlir::intrange::detail::defaultInferResultRanges(
+ InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
+ SetIntLatticeFn setResultRanges) {
+ llvm::SmallVector<ConstantIntRanges> unpacked;
+ unpacked.reserve(argRanges.size());
+
+ for (const IntegerValueRange &range : argRanges) {
+ if (range.isUninitialized())
+ return;
+ unpacked.push_back(range.getValue());
+ }
+
+ interface.inferResultRanges(
+ unpacked,
+ [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
+ setResultRanges(value, IntegerValueRange{argRanges});
+ });
+}
+
+void mlir::intrange::detail::defaultInferResultRangesFromOptional(
+ InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
+ interface.inferResultRangesFromOptional(
+ ranges,
+ [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
+ if (!argRanges.isUninitialized())
+ setResultRanges(value, argRanges.getValue());
+ });
+}
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index fe1a67d628738..5b8d35e7bd519 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -76,7 +76,7 @@ static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
//===----------------------------------------------------------------------===//
ConstantIntRanges
-mlir::intrange::inferIndexOp(InferRangeFn inferFn,
+mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
ArrayRef<ConstantIntRanges> argRanges,
intrange::CmpMode mode) {
ConstantIntRanges sixtyFour = inferFn(argRanges);
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 5b538197a0c11..60f0ab41afa48 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -899,3 +899,22 @@ func.func @test_shl_i8_nowrap() -> i8 {
%2 = test.reflect_bounds %1 : i8
return %2: i8
}
+
+/// A test case to ensure that the ranges for unsupported ops are initialized
+/// properly to maxRange, rather than left uninitialized.
+/// In this test case, the previous behavior would leave the ranges for %a and
+/// %b uninitialized, resulting in arith.cmpf's range not being updated, even
+/// though it has an integer valued result.
+
+// CHECK-LABEL: func @test_cmpf_propagates
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+func.func @test_cmpf_propagates(%a: f32, %b: f32) -> index {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ %0 = arith.cmpf ueq, %a, %b : f32
+ %1 = arith.select %0, %c1, %c2 : index
+ %2 = test.reflect_bounds %1 : index
+ func.return %2 : index
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 18324482153a5..9d7e0a7928ab8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2750,7 +2750,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
def TestWithBoundsOp : TEST_Op<"with_bounds",
- [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
NoMemoryEffect]> {
let arguments = (ins APIntAttr:$umin,
APIntAttr:$umax,
@@ -2762,7 +2762,7 @@ def TestWithBoundsOp : TEST_Op<"with_bounds",
}
def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
- [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
SingleBlock, NoTerminator]> {
let arguments = (ins APIntAttr:$umin,
APIntAttr:$umax,
@@ -2774,7 +2774,7 @@ def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
}
def TestIncrementOp : TEST_Op<"increment",
- [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
NoMemoryEffect, AllTypesMatch<["value", "result"]>]> {
let arguments = (ins InferIntRangeType:$value);
let results = (outs InferIntRangeType:$result);
@@ -2783,7 +2783,8 @@ def TestIncrementOp : TEST_Op<"increment",
}
def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
- [DeclareOpInterfaceMethods<InferIntRangeInterface>, AllTypesMatch<["value", "result"]>]> {
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
+ AllTypesMatch<["value", "result"]>]> {
let arguments = (ins InferIntRangeType:$value,
OptionalAttr<APIntAttr>:$umin,
OptionalAttr<APIntAttr>:$umax,
More information about the Mlir-commits mailing list