[Mlir-commits] [mlir] 4553056 - [mlir][Index] Implement InferIntRangeInterface
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Jan 19 09:48:29 PST 2023
Author: Krzysztof Drewniak
Date: 2023-01-19T17:48:24Z
New Revision: 455305624884cf9237143e2ba0635fcc5ba5206a
URL: https://github.com/llvm/llvm-project/commit/455305624884cf9237143e2ba0635fcc5ba5206a
DIFF: https://github.com/llvm/llvm-project/commit/455305624884cf9237143e2ba0635fcc5ba5206a.diff
LOG: [mlir][Index] Implement InferIntRangeInterface
Implement InferIntRangeInterface for all operations in the Index dialect. The
inference implementation, unlike the one for Arith, accounts for the
fact that Index can be either 64 or 32 bits long by evaluating both
cases. Bounds are stored as if index were i64, but when inferring new
bounds, we compute both f(...) and f(trunc(...)). We then compare
trunc(f(...)) to f(trunc(...)). If they are equal in the relevant
range components, we use the 64-bit range computation, otherwise we
give the range ext(f(trunc(...))) union f(...).
Note that this can cause surprising behavior as seen in the tests,
where, for example, the order of min and max operations impacts the
behavior of the inference. The inference could perhaps be made more
precise in the future (ex. by tracking 32 and 64-bit results
separately and having them influence each other somehow) but, since
my project targets an index=i32 platform and doesn't see index-valued
values > uint32_max, I'm not too concerned about it.
Depends on https://reviews.llvm.org/D141299
Depends on https://reviews.llvm.org/D141296
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D140899
Added:
mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Interfaces/Utils/CMakeLists.txt
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
mlir/test/Dialect/Index/int-range-inference.mlir
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.h
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Arith/IR/CMakeLists.txt
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/Index/IR/CMakeLists.txt
mlir/lib/Interfaces/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
index 85a0549edd4dd..d8debfb731323 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -13,6 +13,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 76008a17364f9..8fbccc4ba94fc 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Index/IR/IndexDialect.td"
include "mlir/Dialect/Index/IR/IndexEnums.td"
include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
@@ -23,7 +24,8 @@ include "mlir/IR/OpBase.td"
/// Base class for Index dialect operations.
class IndexOp<string mnemonic, list<Trait> traits = []>
- : Op<IndexDialect, mnemonic, [Pure] # traits>;
+ : Op<IndexDialect, mnemonic,
+ [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
//===----------------------------------------------------------------------===//
// IndexBinaryOp
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
new file mode 100644
index 0000000000000..7ee059cf342ce
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -0,0 +1,126 @@
+//===- InferIntRangeCommon.h - Inference for common ops ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares implementations of range inference for operations that are
+// common to both the `arith` and `index` dialects to facilitate reuse.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
+#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace intrange {
+/// Function that performs inference on an array of `ConstantIntRanges`,
+/// abstracted away here to permit writing the function that handles both
+/// 64- and 32-bit index types.
+using InferRangeFn =
+ function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+static constexpr unsigned indexMinWidth = 32; // Narrowest index bitwidth.
+static constexpr unsigned indexMaxWidth = 64; // Widest (storage) bitwidth.
+
+enum class CmpMode : uint32_t { Both, Signed, Unsigned }; // See inferIndexOp.
+
+/// Compute `inferFn` on `argRanges`, whose bitwidth should be the index
+/// storage bitwidth. Then, compute the function on `argRanges` again after
+/// truncating the ranges to 32 bits. Finally, if the truncation of the
+/// 64-bit result is equal to the 32-bit result, use it (to preserve folder
+/// compatibility and precision); otherwise, take the union of the results.
+///
+/// The `mode` argument specifies if the unsigned, signed, or both results of
+/// the inference computation should be used when comparing the results.
+ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
+ ArrayRef<ConstantIntRanges> argRanges,
+ CmpMode mode);
+
+/// Independently zero-extend the unsigned values and sign-extend the signed
+/// values in `range` to `destWidth` bits, returning the resulting range.
+ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth);
+
+/// Use the unsigned values in `range` to zero-extend it to `destWidth`.
+ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+/// Use the signed values in `range` to sign-extend it to `destWidth`.
+ConstantIntRanges extSIRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+/// Truncate `range` to `destWidth` bits, taking care to handle cases such as
+/// the truncation of [255, 256] to i8 not being a uniform range.
+ConstantIntRanges truncRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges);
+
+/// Copy of the enum from `arith` and `index` to allow the common integer range
+/// infrastructure to not depend on either dialect.
+enum class CmpPredicate : uint64_t {
+ eq,
+ ne,
+ slt,
+ sle,
+ sgt,
+ sge,
+ ult,
+ ule,
+ ugt,
+ uge,
+};
+
+/// Returns a boolean value if `pred` is statically true or false for any
+/// possible inputs falling within `lhs` and `rhs`, and std::nullopt if the
+/// value of the predicate cannot be determined.
+Optional<bool> evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs);
+
+} // namespace intrange
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index 0de17bbfbd12a..ffbe80105911e 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArithDialect
LINK_LIBS PUBLIC
MLIRDialect
+ MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 10d6ef29756c6..971477fa94cb9 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -16,48 +17,7 @@
using namespace mlir;
using namespace mlir::arith;
-
-/// Function that evaluates the result of doing something on arithmetic
-/// constants and returns std::nullopt on overflow.
-using ConstArithFn =
- function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
-
-/// Return the maxmially wide signed or unsigned range for a given bitwidth.
-
-/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
-/// If either computation overflows, make the result unbounded.
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
- const APInt &minRight,
- const APInt &maxLeft,
- const APInt &maxRight, bool isSigned) {
- std::optional<APInt> maybeMin = op(minLeft, minRight);
- std::optional<APInt> maybeMax = op(maxLeft, maxRight);
- if (maybeMin && maybeMax)
- return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
- return ConstantIntRanges::maxRange(minLeft.getBitWidth());
-}
-
-/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
-/// ignoring unbounded values. Returns the maximal range if `op` overflows.
-static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
- ArrayRef<APInt> rhs, bool isSigned) {
- unsigned width = lhs[0].getBitWidth();
- APInt min =
- isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
- APInt max =
- isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
- for (const APInt &left : lhs) {
- for (const APInt &right : rhs) {
- std::optional<APInt> maybeThisResult = op(left, right);
- if (!maybeThisResult)
- return ConstantIntRanges::maxRange(width);
- APInt result = std::move(*maybeThisResult);
- min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
- max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
- }
- }
- return ConstantIntRanges::range(min, max, isSigned);
-}
+using namespace mlir::intrange;
//===----------------------------------------------------------------------===//
// ConstantOp
@@ -78,25 +38,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn uadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.uadd_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn sadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.sadd_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
-
- ConstantIntRanges urange = computeBoundsBy(
- uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
- ConstantIntRanges srange = computeBoundsBy(
- sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferAdd(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -105,25 +47,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn usub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.usub_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn ssub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.ssub_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstantIntRanges urange = computeBoundsBy(
- usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
- ConstantIntRanges srange = computeBoundsBy(
- ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferSub(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -132,96 +56,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn umul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.umul_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn smul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.smul_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
-
- ConstantIntRanges urange =
- minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/false);
- ConstantIntRanges srange =
- minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
- /*isSigned=*/true);
-
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferMul(argRanges));
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
-/// Fix up division results (ex. for ceiling and floor), returning an APInt
-/// if there has been no overflow
-using DivisionFixupFn = function_ref<std::optional<APInt>(
- const APInt &lhs, const APInt &rhs, const APInt &result)>;
-
-static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs,
- DivisionFixupFn fixup) {
- const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
- &rhsMax = rhs.umax();
-
- if (!rhsMin.isZero()) {
- auto udiv = [&fixup](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- return fixup(a, b, a.udiv(b));
- };
- return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/false);
- }
- // Otherwise, it's possible we might divide by 0.
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferDivUIRange(argRanges[0], argRanges[1],
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) { return result; }));
+ setResultRange(getResult(), inferDivU(argRanges));
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs,
- DivisionFixupFn fixup) {
- const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
- &rhsMax = rhs.smax();
- bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
-
- if (canDivide) {
- auto sdiv = [&fixup](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.sdiv_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : fixup(a, b, result);
- };
- return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/true);
- }
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferDivSIRange(argRanges[0], argRanges[1],
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) { return result; }));
+ setResultRange(getResult(), inferDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -230,20 +83,7 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::CeilDivUIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn ceilDivUIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.urem(rhs).isZero()) {
- bool overflowed = false;
- APInt corrected =
- result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
+ setResultRange(getResult(), inferCeilDivU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -252,20 +92,7 @@ void arith::CeilDivUIOp::inferResultRanges(
void arith::CeilDivSIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn ceilDivSIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
- bool overflowed = false;
- APInt corrected =
- result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
+ setResultRange(getResult(), inferCeilDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -274,20 +101,7 @@ void arith::CeilDivSIOp::inferResultRanges(
void arith::FloorDivSIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn floorDivSIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
- bool overflowed = false;
- APInt corrected =
- result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
+ return setResultRange(getResult(), inferFloorDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -296,29 +110,7 @@ void arith::FloorDivSIOp::inferResultRanges(
void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
-
- unsigned width = rhsMin.getBitWidth();
- APInt umin = APInt::getZero(width);
- APInt umax = APInt::getMaxValue(width);
-
- if (!rhsMin.isZero()) {
- umax = rhsMax - 1;
- // Special case: sweeping out a contiguous range in N/[modulus]
- if (rhsMin == rhsMax) {
- const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
- if ((lhsMax - lhsMin).ult(rhsMax)) {
- APInt minRem = lhsMin.urem(rhsMax);
- APInt maxRem = lhsMax.urem(rhsMax);
- if (minRem.ule(maxRem)) {
- umin = minRem;
- umax = maxRem;
- }
- }
- }
- }
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferRemU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -327,67 +119,16 @@ void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
- &rhsMax = rhs.smax();
-
- unsigned width = rhsMax.getBitWidth();
- APInt smin = APInt::getSignedMinValue(width);
- APInt smax = APInt::getSignedMaxValue(width);
- // No bounds if zero could be a divisor.
- bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
- if (canBound) {
- APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
- bool canNegativeDividend = lhsMin.isNegative();
- bool canPositiveDividend = lhsMax.isStrictlyPositive();
- APInt zero = APInt::getZero(maxDivisor.getBitWidth());
- APInt maxPositiveResult = maxDivisor - 1;
- APInt minNegativeResult = -maxPositiveResult;
- smin = canNegativeDividend ? minNegativeResult : zero;
- smax = canPositiveDividend ? maxPositiveResult : zero;
- // Special case: sweeping out a contiguous range in N/[modulus].
- if (rhsMin == rhsMax) {
- if ((lhsMax - lhsMin).ult(maxDivisor)) {
- APInt minRem = lhsMin.srem(maxDivisor);
- APInt maxRem = lhsMax.srem(maxDivisor);
- if (minRem.sle(maxRem)) {
- smin = minRem;
- smax = maxRem;
- }
- }
- }
- }
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferRemS(argRanges));
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
-/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
-/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
-/// that both bonuds have in common. This gives us a consertive approximation
-/// for what values can be passed to bitwise operations.
-static std::tuple<APInt, APInt>
-widenBitwiseBounds(const ConstantIntRanges &bound) {
- APInt leftVal = bound.umin(), rightVal = bound.umax();
- unsigned bitwidth = leftVal.getBitWidth();
- unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
- leftVal.clearLowBits(differingBits);
- rightVal.setLowBits(differingBits);
- return std::make_tuple(std::move(leftVal), std::move(rightVal));
-}
-
void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a & b;
- };
- setResultRange(getResult(),
- minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferAnd(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -396,14 +137,7 @@ void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a | b;
- };
- setResultRange(getResult(),
- minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferOr(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -412,14 +146,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a ^ b;
- };
- setResultRange(getResult(),
- minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferXor(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -428,11 +155,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
- const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferMaxS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -441,11 +164,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
- const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferMaxU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -454,11 +173,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
- const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferMinS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -467,94 +182,40 @@ void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
- const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferMinU(argRanges));
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- APInt umin = range.umin().zext(destWidth);
- APInt umax = range.umax().zext(destWidth);
- return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), extUIRange(argRanges[0], destType));
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- APInt smin = range.smin().sext(destWidth);
- APInt smax = range.smax().sext(destWidth);
- return ConstantIntRanges::fromSigned(smin, smax);
-}
-
void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), extSIRange(argRanges[0], destType));
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
- // the range of the resulting value is not contiguous ind includes 0.
- // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
- // but you can't truncate [255, 257] similarly.
- bool hasUnsignedRollover =
- range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
- APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
- : range.umin().trunc(destWidth);
- APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
- : range.umax().trunc(destWidth);
-
- // Signed post-truncation rollover will not occur when either:
- // - The high parts of the min and max, plus the sign bit, are the same
- // - The high halves + sign bit of the min and max are either all 1s or all 0s
- // and you won't create a [positive, negative] range by truncating.
- // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
- // but not [255, 257]_i16 to a range of i8s. You can also truncate
- // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
- // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
- // will truncate to 0x7e, which is greater than 0
- APInt sminHighPart = range.smin().ashr(destWidth - 1);
- APInt smaxHighPart = range.smax().ashr(destWidth - 1);
- bool hasSignedOverflow =
- (sminHighPart != smaxHighPart) &&
- !(sminHighPart.isAllOnes() &&
- (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
- !(sminHighPart.isZero() && smaxHighPart.isZero());
- APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
- : range.smin().trunc(destWidth);
- APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
- : range.smax().trunc(destWidth);
- return {umin, umax, smin, smax};
-}
-
void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ // Only the destination storage bitwidth is needed; the shared truncRange
+ // helper falls back to the full range for whichever signedness could wrap.
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
@@ -569,9 +230,9 @@ void arith::IndexCastOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extSIRange(argRanges[0], destType));
+ setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
else
setResultRange(getResult(), argRanges[0]);
}
@@ -588,9 +249,9 @@ void arith::IndexCastUIOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extUIRange(argRanges[0], destType));
+ setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
else
setResultRange(getResult(), argRanges[0]);
}
@@ -599,51 +260,19 @@ void arith::IndexCastUIOp::inferResultRanges(
// CmpIOp
//===----------------------------------------------------------------------===//
-bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs) {
- switch (pred) {
- case arith::CmpIPredicate::sle:
- case arith::CmpIPredicate::slt:
- return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
- case arith::CmpIPredicate::ule:
- case arith::CmpIPredicate::ult:
- return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
- case arith::CmpIPredicate::sge:
- case arith::CmpIPredicate::sgt:
- return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
- case arith::CmpIPredicate::uge:
- case arith::CmpIPredicate::ugt:
- return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
- case arith::CmpIPredicate::eq: {
- std::optional<APInt> lhsConst = lhs.getConstantValue();
- std::optional<APInt> rhsConst = rhs.getConstantValue();
- return lhsConst && rhsConst && lhsConst == rhsConst;
- }
- case arith::CmpIPredicate::ne: {
- // While equality requires that there is an interpration of the preceeding
- // computations that produces equal constants, whether that be signed or
- // unsigned, statically determining inequality requires that neither
- // interpretation produce potentially overlapping ranges.
- bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
- isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
- bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
- isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
- return sne && une;
- }
- }
- return false;
-}
-
void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- arith::CmpIPredicate pred = getPredicate();
+ arith::CmpIPredicate arithPred = getPredicate();
+ intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
APInt min = APInt::getZero(1);
APInt max = APInt::getAllOnesValue(1);
- if (isStaticallyTrue(pred, lhs, rhs))
+
+ Optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
+ if (truthValue.has_value() && *truthValue)
min = max;
- else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+ else if (truthValue.has_value() && !(*truthValue))
max = min;
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
@@ -673,18 +302,7 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn shl = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
- };
- ConstantIntRanges urange =
- minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/false);
- ConstantIntRanges srange =
- minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ // Shared with the Index dialect's ShlOp through the common inferShl helper.
+ setResultRange(getResult(), inferShl(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -693,15 +311,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn lshr = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
- };
- setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
- {rhs.umin(), rhs.umax()},
- /*isSigned=*/false));
+ // Shared with the Index dialect's ShrUOp through the common inferShrU helper.
+ setResultRange(getResult(), inferShrU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -710,14 +320,5 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn ashr = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
- };
-
- setResultRange(getResult(),
- minMaxBy(ashr, {lhs.smin(), lhs.smax()},
- {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
+ // Shared with the Index dialect's ShrSOp through the common inferShrS helper.
+ setResultRange(getResult(), inferShrS(argRanges));
}
diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
index 53321f1ea3f25..e820eececa483 100644
--- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIndexDialect
IndexAttrs.cpp
IndexDialect.cpp
IndexOps.cpp
+ InferIntRangeInterfaceImpls.cpp
DEPENDS
MLIRIndexOpsIncGen
@@ -10,6 +11,9 @@ add_mlir_dialect_library(MLIRIndexDialect
MLIRDialect
MLIRIR
MLIRCastInterfaces
+ # Provides the intrange:: helpers called by InferIntRangeInterfaceImpls.cpp.
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
new file mode 100644
index 0000000000000..6daa7640b017e
--- /dev/null
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -0,0 +1,252 @@
+//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+using namespace mlir::index;
+using namespace mlir::intrange;
+
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                   SetIntRangeFn setResultRange) {
+  // The op's value attribute is the only value its result can take.
+  setResultRange(getResult(), ConstantIntRanges::constant(getValue()));
+}
+
+void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  // Model the boolean as a 1-bit integer: true -> 1, false -> 0.
+  APInt bit(/*numBits=*/1, getValue() ? 1 : 0);
+  setResultRange(getResult(), ConstantIntRanges::constant(bit));
+}
+
+//===----------------------------------------------------------------------===//
+// Arithmetic operations. All of these operations will have their results
+// inferred using both the 64-bit values and truncated 32-bit values of their
+// inputs, with the result being the union of those inferences, except where
+// the truncation of the 64-bit result is equal to the 32-bit result (in which
+// case we take the 64-bit result). Each op's CmpMode selects which range
+// components (signed, unsigned, or both) are compared for that equality.
+//===----------------------------------------------------------------------===//
+
+// Add, sub and mul use CmpMode::Both: the 64-bit result is kept on its own
+// only when its truncation matches the 32-bit result in every bound.
+void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferAdd, argRanges, CmpMode::Both);
+  setResultRange(getResult(), inferred);
+}
+
+void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferSub, argRanges, CmpMode::Both);
+  setResultRange(getResult(), inferred);
+}
+
+void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferMul, argRanges, CmpMode::Both);
+  setResultRange(getResult(), inferred);
+}
+
+// Division: unsigned variants compare the unsigned bounds of the 64-bit and
+// 32-bit inferences, signed variants compare the signed bounds.
+void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferDivS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                   SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                   SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                    SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+// Remainders: the signed variant compares signed bounds, the unsigned variant
+// compares unsigned bounds.
+void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferRemS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+// Min/max: each op compares the bounds that match its signedness.
+void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferMaxS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferMinS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+// Shifts: left shift compares both signednesses; arithmetic right shift the
+// signed bounds; logical right shift the unsigned bounds.
+void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferShl, argRanges, CmpMode::Both);
+  setResultRange(getResult(), inferred);
+}
+
+void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferShrS, argRanges, CmpMode::Signed);
+  setResultRange(getResult(), inferred);
+}
+
+void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                               SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+// Bitwise ops compare only the unsigned bounds of the two inferences.
+void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                             SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferOr, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  ConstantIntRanges inferred =
+      inferIndexOp(inferXor, argRanges, CmpMode::Unsigned);
+  setResultRange(getResult(), inferred);
+}
+
+//===----------------------------------------------------------------------===//
+// Casts
+//===----------------------------------------------------------------------===//
+
+/// Converts `range` from `srcWidth` to `destWidth` bits: unchanged when the
+/// widths match, truncated when narrowing, and sign- or zero-extended
+/// (selected by `isSigned`) when widening.
+static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
+                                      unsigned srcWidth, unsigned destWidth,
+                                      bool isSigned) {
+  if (srcWidth == destWidth)
+    return range;
+  if (srcWidth > destWidth)
+    return truncRange(range, destWidth);
+  if (isSigned)
+    return extSIRange(range, destWidth);
+  return extUIRange(range, destWidth);
+}
+
+// Infers the range of a cast between `index` and a fixed-width integer.
+// `index` ranges are stored at 64 bits, but `index` may be 32 bits wide on the
+// target, so (as in inferIndexOp) the result is the union of the cast computed
+// for a 64-bit index and the cast computed for a 32-bit index.
+static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
+                                        Type sourceType, Type destType,
+                                        bool isSigned) {
+  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  // The cast as if `index` were as wide as its storage (64 bits).
+  ConstantIntRanges storageRange =
+      makeLikeDest(range, srcWidth, destWidth, isSigned);
+
+  if (sourceType.isIndex()) {
+    // Casting from `index`. Truncating to at most indexMinWidth bits gives the
+    // same bits whether the index is 32 or 64 bits wide, so only wider
+    // destinations need the 32-bit case. There, the result is the low
+    // indexMinWidth bits of the index, sign- or zero-extended per `isSigned`
+    // (e.g. index -1 cast unsigned to i64 is 0xffffffff on a 32-bit target).
+    if (destWidth <= indexMinWidth)
+      return storageRange;
+    ConstantIntRanges minWidthSrc = truncRange(range, indexMinWidth);
+    ConstantIntRanges minWidthCast =
+        makeLikeDest(minWidthSrc, indexMinWidth, destWidth, isSigned);
+    return storageRange.rangeUnion(minWidthCast);
+  }
+
+  // Casting to `index`: also cast to indexMinWidth bits and store that result
+  // at the 64-bit storage width the same way inferIndexOp stores 32-bit
+  // results.
+  ConstantIntRanges minWidthRange =
+      makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
+  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
+  return storageRange.rangeUnion(minWidthExt);
+}
+
+void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                SetIntRangeFn setResultRange) {
+  // Signed cast: widening sign-extends.
+  Type fromType = getOperand().getType();
+  Type toType = getResult().getType();
+  ConstantIntRanges castRange =
+      inferIndexCast(argRanges[0], fromType, toType, /*isSigned=*/true);
+  setResultRange(getResult(), castRange);
+}
+
+void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                SetIntRangeFn setResultRange) {
+  // Unsigned cast: widening zero-extends.
+  Type fromType = getOperand().getType();
+  Type toType = getResult().getType();
+  ConstantIntRanges castRange =
+      inferIndexCast(argRanges[0], fromType, toType, /*isSigned=*/false);
+  setResultRange(getResult(), castRange);
+}
+
+//===----------------------------------------------------------------------===//
+// CmpOp
+//===----------------------------------------------------------------------===//
+
+void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  // NOTE(review): the static_cast assumes IndexCmpPredicate and
+  // intrange::CmpPredicate share the same encoding — confirm if either changes.
+  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(getPred());
+  const ConstantIntRanges &lhs = argRanges[0];
+  const ConstantIntRanges &rhs = argRanges[1];
+
+  // Evaluate the predicate as if `index` were 64 bits and as if it were 32.
+  Optional<bool> wide = intrange::evaluatePred(pred, lhs, rhs);
+  Optional<bool> narrow =
+      intrange::evaluatePred(pred, truncRange(lhs, indexMinWidth),
+                             truncRange(rhs, indexMinWidth));
+
+  // Default to the full i1 range [0, 1] ("unknown"); fold only when both
+  // widths give the same definite answer.
+  APInt lo = APInt::getZero(1);
+  APInt hi = APInt::getAllOnesValue(1);
+  if (wide.has_value() && narrow.has_value() && *wide == *narrow) {
+    if (*wide)
+      lo = hi;
+    else
+      hi = lo;
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(lo, hi));
+}
+
+//===----------------------------------------------------------------------===//
+// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
+//===----------------------------------------------------------------------===//
+
+void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                 SetIntRangeFn setResultRange) {
+  // The result lies in [indexMinWidth, indexMaxWidth], represented at the
+  // result's storage bitwidth.
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getResult().getType());
+  setResultRange(getResult(),
+                 ConstantIntRanges::fromUnsigned(APInt(width, indexMinWidth),
+                                                 APInt(width, indexMaxWidth)));
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a7cdbb5b3a6fe..38ad0e4a2231c 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -51,3 +51,5 @@ add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
add_mlir_interface_library(VectorInterfaces)
add_mlir_interface_library(ViewLikeInterface)
+
+# Utils/ holds helper libraries built with add_mlir_library rather than
+# add_mlir_interface_library, such as MLIRInferIntRangeCommon.
+add_subdirectory(Utils)
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..ece6c8e46ffea
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Integer range inference helpers shared by the `arith` and `index` dialects.
+add_mlir_library(MLIRInferIntRangeCommon
+ InferIntRangeCommon.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
+
+ DEPENDS
+ MLIRInferIntRangeInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRInferIntRangeInterface
+ MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
new file mode 100644
index 0000000000000..c81f004ecf5f9
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -0,0 +1,663 @@
+//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementations of range inference for operations that are
+// common to both the `arith` and `index` dialects to facilitate reuse.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "llvm/Support/Debug.h"
+
+#include <iterator>
+#include <optional>
+
+using namespace mlir;
+
+#define DEBUG_TYPE "int-range-analysis"
+
+//===----------------------------------------------------------------------===//
+// General utilities
+//===----------------------------------------------------------------------===//
+
+/// Callable that applies a binary operation to two APInt constants, returning
+/// std::nullopt when the operation overflows. This is a non-owning
+/// function_ref: pass named callables, do not store temporaries in it.
+using ConstArithFn =
+    function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
+
+/// Builds the range [op(minLeft, minRight), op(maxLeft, maxRight)], read as
+/// signed or unsigned according to `isSigned`. If either endpoint computation
+/// overflows, returns the full range of the operand bitwidth.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
+                                         const APInt &minRight,
+                                         const APInt &maxLeft,
+                                         const APInt &maxRight, bool isSigned) {
+  std::optional<APInt> lowEnd = op(minLeft, minRight);
+  std::optional<APInt> highEnd = op(maxLeft, maxRight);
+  if (!lowEnd || !highEnd)
+    return ConstantIntRanges::maxRange(minLeft.getBitWidth());
+  return ConstantIntRanges::range(*lowEnd, *highEnd, isSigned);
+}
+
+/// Applies `op` to every pair (l, r) with l from `lhs` and r from `rhs` and
+/// returns the tightest range (signed or unsigned per `isSigned`) containing
+/// all results. If any application overflows, returns the full range.
+static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
+                                  ArrayRef<APInt> rhs, bool isSigned) {
+  unsigned width = lhs[0].getBitWidth();
+  auto lessThan = [isSigned](const APInt &a, const APInt &b) {
+    return isSigned ? a.slt(b) : a.ult(b);
+  };
+  // Start inverted (lo = largest, hi = smallest) so the first result sets both.
+  APInt lo =
+      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
+  APInt hi =
+      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
+  for (const APInt &left : lhs) {
+    for (const APInt &right : rhs) {
+      std::optional<APInt> value = op(left, right);
+      if (!value)
+        return ConstantIntRanges::maxRange(width);
+      if (lessThan(*value, lo))
+        lo = *value;
+      if (lessThan(hi, *value))
+        hi = *value;
+    }
+  }
+  return ConstantIntRanges::range(lo, hi, isSigned);
+}
+
+//===----------------------------------------------------------------------===//
+// Ext, trunc, index op handling
+//===----------------------------------------------------------------------===//
+
+/// Runs `inferFn` as if `index` were 64 bits wide and again on operands
+/// truncated to indexMinWidth bits. If truncating the 64-bit result reproduces
+/// the 32-bit result in the components selected by `mode`, the 64-bit result
+/// is returned (it preserves more information); otherwise the 32-bit result is
+/// widened to indexMaxWidth and unioned with the 64-bit one.
+ConstantIntRanges
+mlir::intrange::inferIndexOp(InferRangeFn inferFn,
+                             ArrayRef<ConstantIntRanges> argRanges,
+                             intrange::CmpMode mode) {
+  ConstantIntRanges sixtyFour = inferFn(argRanges);
+
+  SmallVector<ConstantIntRanges, 2> truncated;
+  for (const ConstantIntRanges &range : argRanges)
+    truncated.push_back(truncRange(range, /*destWidth=*/indexMinWidth));
+  ConstantIntRanges thirtyTwo = inferFn(truncated);
+
+  LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
+                          << " 32-bit = " << thirtyTwo << "\n");
+
+  ConstantIntRanges sixtyFourAsThirtyTwo =
+      truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
+  auto sameSigned = [&] {
+    return thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
+           thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax();
+  };
+  auto sameUnsigned = [&] {
+    return thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
+           thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax();
+  };
+
+  bool truncEqual = false;
+  if (mode == intrange::CmpMode::Both)
+    truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
+  else if (mode == intrange::CmpMode::Signed)
+    truncEqual = sameSigned();
+  else if (mode == intrange::CmpMode::Unsigned)
+    truncEqual = sameUnsigned();
+
+  // Returning the 64-bit result preserves more information.
+  if (truncEqual)
+    return sixtyFour;
+  return sixtyFour.rangeUnion(
+      extRange(thirtyTwo, /*destWidth=*/indexMaxWidth));
+}
+
+/// Widens `range` to `destWidth` bits, zero-extending the unsigned bounds and
+/// sign-extending the signed bounds.
+ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
+                                           unsigned int destWidth) {
+  return {range.umin().zext(destWidth), range.umax().zext(destWidth),
+          range.smin().sext(destWidth), range.smax().sext(destWidth)};
+}
+
+/// Zero-extends the unsigned bounds of `range` to `destWidth` bits and builds
+/// the result from them with ConstantIntRanges::fromUnsigned.
+ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
+                                             unsigned destWidth) {
+  return ConstantIntRanges::fromUnsigned(range.umin().zext(destWidth),
+                                         range.umax().zext(destWidth));
+}
+
+/// Sign-extends the signed bounds of `range` to `destWidth` bits and builds
+/// the result from them with ConstantIntRanges::fromSigned.
+ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
+                                             unsigned destWidth) {
+  return ConstantIntRanges::fromSigned(range.smin().sext(destWidth),
+                                       range.smax().sext(destWidth));
+}
+
+/// Truncates `range` to `destWidth` bits. The unsigned and signed bounds are
+/// handled independently; each falls back to the full destWidth-bit range when
+/// truncation might wrap.
+ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
+ unsigned int destWidth) {
+ // Unsigned: if umin and umax differ in the bits at or above bit destWidth,
+ // the truncated values wrap around, so any destWidth-bit value is possible.
+ // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
+ // but truncating [255, 257] yields {255, 0, 1}, which is not a range.
+ bool hasUnsignedRollover =
+ range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+ APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
+ : range.umin().trunc(destWidth);
+ APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
+ : range.umax().trunc(destWidth);
+
+ // Signed: the "high part" below is everything from the truncated value's
+ // sign bit (bit destWidth - 1) upward. Truncation is treated as exact when
+ // - both high parts are equal (smin and smax lie in one aligned block of
+ //   2^(destWidth - 1) values, so the truncated sign bit cannot change), or
+ // - smin's high part is all ones and smax's is all ones or zero (both
+ //   bounds already fit in a signed destWidth-bit integer).
+ // Otherwise the full signed range is used. For example, [256, 258]_i16 and
+ // [-256, -254]_i16 both truncate to [0, 2]_i8, but [-130, 0]_i16 does not
+ // truncate cleanly: -130_i16 (0xff7e) becomes 0x7e (126) while 0 stays 0.
+ // The check is conservative: [-258, -256]_i16 maps onto [-2, 0]_i8 but is
+ // still widened to the full signed range.
+ APInt sminHighPart = range.smin().ashr(destWidth - 1);
+ APInt smaxHighPart = range.smax().ashr(destWidth - 1);
+ bool hasSignedOverflow =
+ (sminHighPart != smaxHighPart) &&
+ !(sminHighPart.isAllOnes() &&
+ (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
+ !(sminHighPart.isZero() && smaxHighPart.isZero());
+ APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
+ : range.smin().trunc(destWidth);
+ APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
+ : range.smax().trunc(destWidth);
+ return {umin, umax, smin, smax};
+}
+
+//===----------------------------------------------------------------------===//
+// Addition
+//===----------------------------------------------------------------------===//
+
+/// Infers the range of `argRanges[0] + argRanges[1]`. Unsigned and signed
+/// bounds are computed separately from the matching operand bounds (each
+/// becomes the full range if an endpoint sum overflows) and then intersected.
+ConstantIntRanges
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // Keep the lambdas as named locals: ConstArithFn is a non-owning
+  // function_ref, so binding it to a temporary lambda would leave it dangling
+  // after the declaration.
+  auto uadd = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.uadd_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  auto sadd = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.sadd_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange = computeBoundsBy(
+      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
+  return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// Subtraction
+//===----------------------------------------------------------------------===//
+
+/// Infers the range of `argRanges[0] - argRanges[1]`. The unsigned bounds are
+/// [lhs.umin - rhs.umax, lhs.umax - rhs.umin] and the signed bounds are formed
+/// the same way; either becomes the full range if one of its endpoint
+/// subtractions overflows. The two results are intersected.
+ConstantIntRanges
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // Named locals, not temporaries bound to a (non-owning) ConstArithFn.
+  auto usub = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.usub_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  auto ssub = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.ssub_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange = computeBoundsBy(
+      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
+  return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// Multiplication
+//===----------------------------------------------------------------------===//
+
+/// Infers the range of `argRanges[0] * argRanges[1]` by multiplying every
+/// combination of operand bounds, separately in the unsigned and signed views
+/// (each is the full range if any product overflows), and intersecting them.
+ConstantIntRanges
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // Named locals, not temporaries bound to a (non-owning) ConstArithFn.
+  auto umul = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.umul_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  auto smul = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.smul_ov(b, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange =
+      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+               /*isSigned=*/true);
+  return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// DivU, CeilDivU (Unsigned division)
+//===----------------------------------------------------------------------===//
+
+/// Adjusts the truncated quotient `result` of `lhs / rhs` (e.g. for ceiling or
+/// floor division), returning std::nullopt if the adjustment overflows.
+using DivisionFixupFn = function_ref<std::optional<APInt>(
+    const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+/// Infers an unsigned quotient range from the unsigned bounds of `lhs` and
+/// `rhs`, passing each corner quotient through `fixup`. Returns the full range
+/// when the divisor may be zero.
+static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
+                                        const ConstantIntRanges &rhs,
+                                        DivisionFixupFn fixup) {
+  const APInt &divisorMin = rhs.umin();
+  // Division by zero is possible, so nothing can be said.
+  if (divisorMin.isZero())
+    return ConstantIntRanges::maxRange(divisorMin.getBitWidth());
+
+  auto udiv = [&fixup](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    return fixup(a, b, a.udiv(b));
+  };
+  return minMaxBy(udiv, {lhs.umin(), lhs.umax()}, {divisorMin, rhs.umax()},
+                  /*isSigned=*/false);
+}
+
+/// Infers the range of plain unsigned division; quotients need no adjustment.
+ConstantIntRanges
+mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
+  auto keepQuotient = [](const APInt &, const APInt &,
+                         const APInt &quotient) -> std::optional<APInt> {
+    return quotient;
+  };
+  return inferDivURange(argRanges[0], argRanges[1], keepQuotient);
+}
+
+/// Infers the range of unsigned ceiling division: each truncated quotient is
+/// incremented when the division is inexact, and an increment that overflows
+/// makes the result unbounded.
+ConstantIntRanges
+mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // A named local: DivisionFixupFn is a non-owning function_ref and would
+  // dangle if bound directly to a temporary lambda.
+  auto ceilDivUIFix = [](const APInt &dividend, const APInt &divisor,
+                         const APInt &quotient) -> std::optional<APInt> {
+    if (!dividend.urem(divisor).isZero()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.uadd_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return quotient;
+  };
+  return inferDivURange(lhs, rhs, ceilDivUIFix);
+}
+
+//===----------------------------------------------------------------------===//
+// DivS, CeilDivS, FloorDivS (Signed division)
+//===----------------------------------------------------------------------===//
+
+/// Infers a signed quotient range from the signed bounds of `lhs` and `rhs`,
+/// passing each corner quotient through `fixup`. The divisor range must
+/// exclude zero (strictly positive or strictly negative); otherwise, or if a
+/// corner division or fixup overflows, the full range is returned.
+static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
+                                        const ConstantIntRanges &rhs,
+                                        DivisionFixupFn fixup) {
+  const APInt &divisorMin = rhs.smin(), &divisorMax = rhs.smax();
+  bool divisorExcludesZero =
+      divisorMin.isStrictlyPositive() || divisorMax.isNegative();
+  if (!divisorExcludesZero)
+    return ConstantIntRanges::maxRange(divisorMin.getBitWidth());
+
+  auto sdiv = [&fixup](const APInt &a, const APInt &b) -> std::optional<APInt> {
+    bool overflowed = false;
+    APInt quotient = a.sdiv_ov(b, overflowed);
+    if (overflowed)
+      return std::optional<APInt>();
+    return fixup(a, b, quotient);
+  };
+  return minMaxBy(sdiv, {lhs.smin(), lhs.smax()}, {divisorMin, divisorMax},
+                  /*isSigned=*/true);
+}
+
+/// Infers the range of plain signed division; quotients need no adjustment.
+ConstantIntRanges
+mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
+  auto keepQuotient = [](const APInt &, const APInt &,
+                         const APInt &quotient) -> std::optional<APInt> {
+    return quotient;
+  };
+  return inferDivSRange(argRanges[0], argRanges[1], keepQuotient);
+}
+
+/// Infers the range of signed ceiling division: when the division is inexact
+/// and the operands have the same sign (positive quotient), the truncated
+/// quotient is incremented; an overflowing increment makes the result
+/// unbounded.
+ConstantIntRanges
+mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // A named local: DivisionFixupFn is a non-owning function_ref and would
+  // dangle if bound directly to a temporary lambda.
+  auto ceilDivSIFix = [](const APInt &dividend, const APInt &divisor,
+                         const APInt &quotient) -> std::optional<APInt> {
+    if (!dividend.srem(divisor).isZero() &&
+        dividend.isNonNegative() == divisor.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.sadd_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return quotient;
+  };
+  return inferDivSRange(lhs, rhs, ceilDivSIFix);
+}
+
+/// Infers the range of signed floor division: when the division is inexact
+/// and the operands have different signs (negative quotient), the truncated
+/// quotient is decremented; an overflowing decrement makes the result
+/// unbounded.
+ConstantIntRanges
+mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // A named local: DivisionFixupFn is a non-owning function_ref and would
+  // dangle if bound directly to a temporary lambda.
+  auto floorDivSIFix = [](const APInt &dividend, const APInt &divisor,
+                          const APInt &quotient) -> std::optional<APInt> {
+    if (!dividend.srem(divisor).isZero() &&
+        dividend.isNonNegative() != divisor.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.ssub_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      return overflowed ? std::optional<APInt>() : corrected;
+    }
+    return quotient;
+  };
+  return inferDivSRange(lhs, rhs, floorDivSIFix);
+}
+
+//===----------------------------------------------------------------------===//
+// Signed remainder (RemS)
+//===----------------------------------------------------------------------===//
+
+/// Signed remainder. `srem` takes the sign of the dividend and its magnitude
+/// is below |divisor|. When the divisor range excludes zero, the result lies
+/// in [-(D - 1), D - 1], where D is the largest divisor magnitude. Each side
+/// collapses to zero when the dividend cannot take that sign.
+///
+/// Special case: if the divisor is a single constant and the dividend range
+/// spans fewer than D values without the remainder wrapping around, the
+/// interval [lhs.smin() srem D, lhs.smax() srem D] is used. A divisor range
+/// that may contain zero yields the full signed range.
+ConstantIntRanges
+mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &dividendLo = lhs.smin(), &dividendHi = lhs.smax();
+  const APInt &divisorLo = rhs.smin(), &divisorHi = rhs.smax();
+  unsigned bitwidth = divisorHi.getBitWidth();
+
+  // No bounds if zero could be a divisor.
+  bool divisorPositive = divisorLo.isStrictlyPositive();
+  if (!divisorPositive && !divisorHi.isNegative())
+    return ConstantIntRanges::fromSigned(APInt::getSignedMinValue(bitwidth),
+                                         APInt::getSignedMaxValue(bitwidth));
+
+  // Largest possible |divisor|. abs() of the signed minimum wraps back to
+  // itself, which still has the right magnitude when read as unsigned.
+  APInt modulus = divisorPositive ? divisorHi : divisorLo.abs();
+
+  // Sweeping out a contiguous range in N/[modulus]: if the remainders of the
+  // two endpoints are ordered, no wrap-around happened in between.
+  if (divisorLo == divisorHi && (dividendHi - dividendLo).ult(modulus)) {
+    APInt loRem = dividendLo.srem(modulus);
+    APInt hiRem = dividendHi.srem(modulus);
+    if (loRem.sle(hiRem))
+      return ConstantIntRanges::fromSigned(loRem, hiRem);
+  }
+
+  APInt largestRem = modulus - 1;
+  APInt zero = APInt::getZero(bitwidth);
+  APInt lower = dividendLo.isNegative() ? -largestRem : zero;
+  APInt upper = dividendHi.isStrictlyPositive() ? largestRem : zero;
+  return ConstantIntRanges::fromSigned(lower, upper);
+}
+
+//===----------------------------------------------------------------------===//
+// Unsigned remainder (RemU)
+//===----------------------------------------------------------------------===//
+
+/// Unsigned remainder. A divisor range that may contain zero yields the full
+/// unsigned range. Otherwise the result is below the largest divisor.
+/// Special case: if the divisor is a single constant and the dividend range
+/// spans fewer values than it, and the endpoint remainders are ordered, the
+/// result is [lhs.umin() urem d, lhs.umax() urem d].
+ConstantIntRanges
+mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &divisorLo = rhs.umin(), &divisorHi = rhs.umax();
+  unsigned bitwidth = divisorLo.getBitWidth();
+
+  if (divisorLo.isZero())
+    return ConstantIntRanges::fromUnsigned(APInt::getZero(bitwidth),
+                                           APInt::getMaxValue(bitwidth));
+
+  // Sweeping out a contiguous range in N/[modulus].
+  if (divisorLo == divisorHi) {
+    const APInt &dividendLo = lhs.umin(), &dividendHi = lhs.umax();
+    if ((dividendHi - dividendLo).ult(divisorHi)) {
+      APInt loRem = dividendLo.urem(divisorHi);
+      APInt hiRem = dividendHi.urem(divisorHi);
+      if (loRem.ule(hiRem))
+        return ConstantIntRanges::fromUnsigned(loRem, hiRem);
+    }
+  }
+
+  return ConstantIntRanges::fromUnsigned(APInt::getZero(bitwidth),
+                                         divisorHi - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// Max and min (MaxS, MaxU, MinS, MinU)
+//===----------------------------------------------------------------------===//
+
+/// Signed maximum: each bound of the result is the signed-larger of the
+/// corresponding operand bounds.
+ConstantIntRanges
+mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &a = argRanges[0], &b = argRanges[1];
+  const APInt &lower = a.smin().slt(b.smin()) ? b.smin() : a.smin();
+  const APInt &upper = a.smax().slt(b.smax()) ? b.smax() : a.smax();
+  return ConstantIntRanges::fromSigned(lower, upper);
+}
+
+/// Unsigned maximum: each bound of the result is the unsigned-larger of the
+/// corresponding operand bounds.
+ConstantIntRanges
+mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &a = argRanges[0], &b = argRanges[1];
+  const APInt &lower = a.umin().ult(b.umin()) ? b.umin() : a.umin();
+  const APInt &upper = a.umax().ult(b.umax()) ? b.umax() : a.umax();
+  return ConstantIntRanges::fromUnsigned(lower, upper);
+}
+
+/// Signed minimum: each bound of the result is the signed-smaller of the
+/// corresponding operand bounds.
+ConstantIntRanges
+mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &a = argRanges[0], &b = argRanges[1];
+  const APInt &lower = b.smin().slt(a.smin()) ? b.smin() : a.smin();
+  const APInt &upper = b.smax().slt(a.smax()) ? b.smax() : a.smax();
+  return ConstantIntRanges::fromSigned(lower, upper);
+}
+
+/// Unsigned minimum: each bound of the result is the unsigned-smaller of the
+/// corresponding operand bounds.
+ConstantIntRanges
+mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &a = argRanges[0], &b = argRanges[1];
+  const APInt &lower = b.umin().ult(a.umin()) ? b.umin() : a.umin();
+  const APInt &upper = b.umax().ult(a.umax()) ? b.umax() : a.umax();
+  return ConstantIntRanges::fromUnsigned(lower, upper);
+}
+
+//===----------------------------------------------------------------------===//
+// Bitwise operators (And, Or, Xor)
+//===----------------------------------------------------------------------===//
+
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations.
+///
+/// Returns (prefix with all low bits cleared, prefix with all low bits set).
+/// Every value in [bound.umin(), bound.umax()] shares the common high-bit
+/// prefix. Bit for bit, such a value therefore contains the first result and
+/// is contained in the second.
+static std::tuple<APInt, APInt>
+widenBitwiseBounds(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  // Number of low bits up to and including the most significant bit at which
+  // the unsigned bounds differ; zero when the range is a single value.
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+  leftVal.clearLowBits(differingBits);
+  rightVal.setLowBits(differingBits);
+  return std::make_tuple(std::move(leftVal), std::move(rightVal));
+}
+
+/// Bitwise AND. Each operand contains every bit of its widened lower bound
+/// and no bit outside its widened upper bound. `&` preserves that ordering,
+/// so evaluating it at the corners of the widened ranges bounds the result.
+ConstantIntRanges
+mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
+  auto andOp = [](const APInt &x, const APInt &y) -> std::optional<APInt> {
+    return x & y;
+  };
+  auto [lhsLo, lhsHi] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsLo, rhsHi] = widenBitwiseBounds(argRanges[1]);
+  return minMaxBy(andOp, {lhsLo, lhsHi}, {rhsLo, rhsHi}, /*isSigned=*/false);
+}
+
+/// Bitwise OR. Each operand contains every bit of its widened lower bound
+/// and no bit outside its widened upper bound. `|` preserves that ordering,
+/// so evaluating it at the corners of the widened ranges bounds the result.
+ConstantIntRanges
+mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
+  auto orOp = [](const APInt &x, const APInt &y) -> std::optional<APInt> {
+    return x | y;
+  };
+  auto [lhsLo, lhsHi] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsLo, rhsHi] = widenBitwiseBounds(argRanges[1]);
+  return minMaxBy(orOp, {lhsLo, lhsHi}, {rhsLo, rhsHi}, /*isSigned=*/false);
+}
+
+/// Bitwise XOR. Unlike `&` and `|`, `^` does not preserve the bitwise
+/// ordering of its operands, so the corners of the widened ranges do not
+/// bound it. Example: for a in [0, 7] and b == 2 the corners give {2, 5}, yet
+/// 2 ^ 2 == 0 and 5 ^ 2 == 7.
+///
+/// Instead, every bit above the wider of the two free low-bit fields is fixed
+/// in both operands, so those result bits are the XOR of the fixed prefixes.
+/// Every bit inside that field is left unconstrained.
+ConstantIntRanges
+mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
+  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+  // zeros ^ ones is exactly the mask of free low bits of each widened operand.
+  APInt freeBits = (lhsZeros ^ lhsOnes) | (rhsZeros ^ rhsOnes);
+  APInt umin = (lhsZeros ^ rhsZeros) & ~freeBits;
+  APInt umax = umin | freeBits;
+  return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
+//===----------------------------------------------------------------------===//
+// Shifts (Shl, ShrS, ShrU)
+//===----------------------------------------------------------------------===//
+
+/// Shift left. Evaluating only the corners is sound only while no shift
+/// overflows. Example for i8: with l in [0x7f, 0x81] and r == 1 the corners
+/// give {0xfe, 0x02}, yet 0x80 << 1 == 0.
+///
+/// Corners that overflow, in the unsigned view and in the signed view
+/// separately, therefore report std::nullopt. Shift amounts >= the bit width
+/// also report std::nullopt. The unsigned and signed results are then
+/// intersected.
+ConstantIntRanges
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // The unsigned/signed overflow behavior of `l << r` matches that of
+  // `l * 2^r`, which `ushl_ov`/`sshl_ov` detect. The callables are kept in
+  // `auto` locals so they outlive the minMaxBy calls below; a non-owning
+  // callable reference initialized from a temporary lambda would dangle.
+  auto ushl = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return std::nullopt;
+    bool overflowed = false;
+    APInt result = l.ushl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+  auto sshl = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return std::nullopt;
+    bool overflowed = false;
+    APInt result = l.sshl_ov(r, overflowed);
+    return overflowed ? std::optional<APInt>() : result;
+  };
+
+  ConstantIntRanges urange =
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/true);
+  return urange.intersection(srange);
+}
+
+/// Arithmetic shift right. For a fixed amount, `ashr` is non-decreasing in
+/// the signed value of `l`. For a fixed `l`, it moves monotonically toward 0
+/// or -1 as the amount grows. The signed corners therefore bound every
+/// result. Shift amounts >= the bit width report std::nullopt.
+ConstantIntRanges
+mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Held in an `auto` local (as in the bitwise inference functions) so the
+  // callable outlives the minMaxBy call; a non-owning callable reference such
+  // as llvm::function_ref initialized from a temporary lambda would dangle.
+  auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
+    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
+  };
+
+  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+                  /*isSigned=*/true);
+}
+
+/// Logical shift right. `lshr` is non-decreasing in the unsigned value of `l`
+/// and non-increasing in the shift amount, so the unsigned corners bound every
+/// result. Shift amounts >= the bit width report std::nullopt.
+ConstantIntRanges
+mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Held in an `auto` local (as in the bitwise inference functions) so the
+  // callable outlives the minMaxBy call; a non-owning callable reference such
+  // as llvm::function_ref initialized from a temporary lambda would dangle.
+  auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
+    return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
+  };
+  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+                  /*isSigned=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// Comparisons (Cmp)
+//===----------------------------------------------------------------------===//
+
+/// Returns the logical negation of `pred`: the predicate that holds exactly
+/// when `pred` does not (e.g. `slt` <-> `sge`, `eq` <-> `ne`).
+static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
+  using P = intrange::CmpPredicate;
+  switch (pred) {
+  // Equality.
+  case P::eq:
+    return P::ne;
+  case P::ne:
+    return P::eq;
+  // Signed orderings.
+  case P::slt:
+    return P::sge;
+  case P::sge:
+    return P::slt;
+  case P::sle:
+    return P::sgt;
+  case P::sgt:
+    return P::sle;
+  // Unsigned orderings.
+  case P::ult:
+    return P::uge;
+  case P::uge:
+    return P::ult;
+  case P::ule:
+    return P::ugt;
+  case P::ugt:
+    return P::ule;
+  }
+  llvm_unreachable("unknown cmp predicate value");
+}
+
+/// Returns true when `lhs pred rhs` is known to hold for every pair of values
+/// the two ranges allow. A false result means only that this could not be
+/// shown.
+static bool isStaticallyTrue(intrange::CmpPredicate pred,
+                             const ConstantIntRanges &lhs,
+                             const ConstantIntRanges &rhs) {
+  using P = intrange::CmpPredicate;
+  switch (pred) {
+  // Ordered predicates: compare the bound of each side that is least
+  // favorable to the predicate.
+  case P::slt:
+    return lhs.smax().slt(rhs.smin());
+  case P::sle:
+    return lhs.smax().sle(rhs.smin());
+  case P::sgt:
+    return lhs.smin().sgt(rhs.smax());
+  case P::sge:
+    return lhs.smin().sge(rhs.smax());
+  case P::ult:
+    return lhs.umax().ult(rhs.umin());
+  case P::ule:
+    return lhs.umax().ule(rhs.umin());
+  case P::ugt:
+    return lhs.umin().ugt(rhs.umax());
+  case P::uge:
+    return lhs.umin().uge(rhs.umax());
+  case P::eq: {
+    // Provable only when both sides are the same single constant.
+    std::optional<APInt> lhsValue = lhs.getConstantValue();
+    std::optional<APInt> rhsValue = rhs.getConstantValue();
+    return lhsValue && rhsValue && lhsValue == rhsValue;
+  }
+  case P::ne: {
+    // While equality requires that there is an interpretation of the preceding
+    // computations that produces equal constants, whether that be signed or
+    // unsigned, statically determining inequality requires that neither
+    // interpretation produce potentially overlapping ranges.
+    bool disjointSigned = isStaticallyTrue(P::slt, lhs, rhs) ||
+                          isStaticallyTrue(P::sgt, lhs, rhs);
+    bool disjointUnsigned = isStaticallyTrue(P::ult, lhs, rhs) ||
+                            isStaticallyTrue(P::ugt, lhs, rhs);
+    return disjointSigned && disjointUnsigned;
+  }
+  }
+  return false;
+}
+
+/// Statically evaluates `lhs pred rhs` over the given ranges.
+/// Returns true if the predicate holds for all admitted values. Returns false
+/// if its negation does. Returns std::nullopt when neither can be shown.
+std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
+                                                 const ConstantIntRanges &lhs,
+                                                 const ConstantIntRanges &rhs) {
+  bool alwaysHolds = isStaticallyTrue(pred, lhs, rhs);
+  if (alwaysHolds)
+    return true;
+  bool neverHolds = isStaticallyTrue(invertPredicate(pred), lhs, rhs);
+  if (!neverHolds)
+    return std::nullopt;
+  return false;
+}
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
new file mode 100644
index 0000000000000..2784d5fd5cf70
--- /dev/null
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+
+// Most operations are covered by the `arith` tests, which use the same code.
+// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
+// code is operating as expected.
+
+// CHECK-LABEL: func @add_same_for_both
+// CHECK: %[[true:.*]] = index.bool.constant true
+// CHECK: return %[[true]]
+// %0 is at most 0xfffffffe, so %0 + 1 lies in [1, 0xffffffff]. That does not
+// wrap in either the 32-bit or the 64-bit interpretation of index, and the
+// uge comparison is expected to fold to true.
+func.func @add_same_for_both(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %calmostBig = index.constant 0xfffffffe
+ %0 = index.minu %arg0, %calmostBig
+ %1 = index.add %0, %c1
+ %2 = index.cmp uge(%1, %c1)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_unsigned_ov
+// CHECK: %[[uge:.*]] = index.cmp uge
+// CHECK: return %[[uge]]
+// %0 may equal 0xffffffff, so with a 32-bit index %0 + 1 can wrap to 0. The
+// uge comparison is therefore expected to remain unfolded.
+func.func @add_unsigned_ov(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %cu32_max = index.constant 0xffffffff
+ %0 = index.minu %arg0, %cu32_max
+ %1 = index.add %0, %c1
+ // On 32-bit, the add could wrap, so the result doesn't have to be >= 1
+ %2 = index.cmp uge(%1, %c1)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_signed_ov
+// CHECK: %[[sge:.*]] = index.cmp sge
+// CHECK: return %[[sge]]
+// %0 may equal 0x7fffffff, so with a 32-bit index %0 + 1 can wrap to a
+// negative value. The sge comparison is therefore expected to remain
+// unfolded.
+func.func @add_signed_ov(%arg0 : index) -> i1 {
+ %c0 = index.constant 0
+ %c1 = index.constant 1
+ %ci32_max = index.constant 0x7fffffff
+ %0 = index.minu %arg0, %ci32_max
+ %1 = index.add %0, %c1
+ // On 32-bit, the add could wrap, so the result doesn't have to be positive
+ %2 = index.cmp sge(%1, %c0)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_big
+// CHECK: %[[true:.*]] = index.bool.constant true
+// CHECK: return %[[true]]
+// Clamps %arg0 into [0x300000000, 0x30000ffff], a range above the 32-bit
+// limit. The uge comparison against the lower clamp is expected to fold to
+// true.
+func.func @add_big(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %cmin = index.constant 0x300000000
+ %cmax = index.constant 0x30000ffff
+ // Note: the order of the clamps matters.
+ // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff]
+ // and then [0x30...0000, 0x30...ffff].
+ // If you switch the order of the below operations, you instead first infer
+ // the range [0, 0x3...ffff]. Then, the min inference can't constrain
+ // this intermediate, since in the 32-bit case we could have, for example,
+ // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff,
+ // which means we can't do any inference.
+ %0 = index.maxu %arg0, %cmin
+ %1 = index.minu %0, %cmax
+ %2 = index.add %1, %c1
+ // NOTE(review): %2 (the add result) is unused and %3 compares %1, not %2.
+ // Confirm whether the comparison was meant to use %2.
+ %3 = index.cmp uge(%1, %cmin)
+ func.return %3 : i1
+}
More information about the Mlir-commits
mailing list