[Mlir-commits] [mlir] 5af9d16 - [mlir][Index] Implement InferIntRangeInterface, re-land
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Jan 20 12:32:36 PST 2023
Author: Krzysztof Drewniak
Date: 2023-01-20T20:32:30Z
New Revision: 5af9d16dae71f2c2087ba88c5fc06893e6aecfe9
URL: https://github.com/llvm/llvm-project/commit/5af9d16dae71f2c2087ba88c5fc06893e6aecfe9
DIFF: https://github.com/llvm/llvm-project/commit/5af9d16dae71f2c2087ba88c5fc06893e6aecfe9.diff
LOG: [mlir][Index] Implement InferIntRangeInterface, re-land
Re-land D140899 to fix a missing dependency in the index dialect's
CMakeLists.txt.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D142147
Added:
mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Interfaces/Utils/CMakeLists.txt
mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
mlir/test/Dialect/Index/int-range-inference.mlir
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.h
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Arith/IR/CMakeLists.txt
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/Index/IR/CMakeLists.txt
mlir/lib/Interfaces/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
index 85a0549edd4dd..d8debfb731323 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -13,6 +13,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 76008a17364f9..8fbccc4ba94fc 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Index/IR/IndexDialect.td"
include "mlir/Dialect/Index/IR/IndexEnums.td"
include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
@@ -23,7 +24,8 @@ include "mlir/IR/OpBase.td"
/// Base class for Index dialect operations.
class IndexOp<string mnemonic, list<Trait> traits = []>
- : Op<IndexDialect, mnemonic, [Pure] # traits>;
+ : Op<IndexDialect, mnemonic,
+ [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>] # traits>;
//===----------------------------------------------------------------------===//
// IndexBinaryOp
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
new file mode 100644
index 0000000000000..7ee059cf342ce
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -0,0 +1,126 @@
+//===- InferIntRangeCommon.h - Inference for common ops ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares implementations of range inference for operations that are
+// common to both the `arith` and `index` dialects to facilitate reuse.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
+#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace intrange {
+/// Function that performs inference on an array of `ConstantIntRanges`,
+/// abstracted away here to permit writing the function that handles both
+/// 64- and 32-bit index types.
+using InferRangeFn =
+ function_ref<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>;
+
+static constexpr unsigned indexMinWidth = 32;
+static constexpr unsigned indexMaxWidth = 64;
+
+enum class CmpMode : uint32_t { Both, Signed, Unsigned };
+
+/// Compute `inferFn` on `argRanges`, whose width should be the index storage
+/// bitwidth. Then, compute the function on `argRanges` again after truncating
+/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is
+/// equal to the 32-bit result, use it (to preserve compatibility with folders
+/// and inference precision), and take the union of the results otherwise.
+///
+/// The `mode` argument specifies if the unsigned, signed, or both results of
+/// the inference computation should be used when comparing the results.
+ConstantIntRanges inferIndexOp(InferRangeFn inferFn,
+ ArrayRef<ConstantIntRanges> argRanges,
+ CmpMode mode);
+
+/// Independently zero-extend the unsigned values and sign-extend the signed
+/// values in `range` to `destWidth` bits, returning the resulting range.
+ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth);
+
+/// Use the unsigned values in `range` to zero-extend it to `destWidth`.
+ConstantIntRanges extUIRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+/// Use the signed values in `range` to sign-extend it to `destWidth`.
+ConstantIntRanges extSIRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+/// Truncate `range` to `destWidth` bits, taking care to handle cases such as
+/// the truncation of [255, 256] to i8 not being a uniform range.
+ConstantIntRanges truncRange(const ConstantIntRanges &range,
+ unsigned destWidth);
+
+ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferRemS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferRemU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMaxS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMaxU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMinS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferMinU(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferAnd(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferOr(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferXor(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShl(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShrS(ArrayRef<ConstantIntRanges> argRanges);
+
+ConstantIntRanges inferShrU(ArrayRef<ConstantIntRanges> argRanges);
+
+/// Copy of the enum from `arith` and `index` to allow the common integer range
+/// infrastructure to not depend on either dialect.
+enum class CmpPredicate : uint64_t {
+ eq,
+ ne,
+ slt,
+ sle,
+ sgt,
+ sge,
+ ult,
+ ule,
+ ugt,
+ uge,
+};
+
+/// Returns a boolean value if `pred` is statically true or false for
+/// any possible inputs falling within `lhs` and `rhs`, and std::nullopt if the
+/// value of the predicate cannot be determined.
+Optional<bool> evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs);
+
+} // namespace intrange
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H
diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index 0de17bbfbd12a..ffbe80105911e 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArithDialect
LINK_LIBS PUBLIC
MLIRDialect
+ MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 10d6ef29756c6..971477fa94cb9 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -16,48 +17,7 @@
using namespace mlir;
using namespace mlir::arith;
-
-/// Function that evaluates the result of doing something on arithmetic
-/// constants and returns std::nullopt on overflow.
-using ConstArithFn =
- function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
-
-/// Return the maxmially wide signed or unsigned range for a given bitwidth.
-
-/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
-/// If either computation overflows, make the result unbounded.
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
- const APInt &minRight,
- const APInt &maxLeft,
- const APInt &maxRight, bool isSigned) {
- std::optional<APInt> maybeMin = op(minLeft, minRight);
- std::optional<APInt> maybeMax = op(maxLeft, maxRight);
- if (maybeMin && maybeMax)
- return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
- return ConstantIntRanges::maxRange(minLeft.getBitWidth());
-}
-
-/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
-/// ignoring unbounded values. Returns the maximal range if `op` overflows.
-static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
- ArrayRef<APInt> rhs, bool isSigned) {
- unsigned width = lhs[0].getBitWidth();
- APInt min =
- isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
- APInt max =
- isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
- for (const APInt &left : lhs) {
- for (const APInt &right : rhs) {
- std::optional<APInt> maybeThisResult = op(left, right);
- if (!maybeThisResult)
- return ConstantIntRanges::maxRange(width);
- APInt result = std::move(*maybeThisResult);
- min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
- max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
- }
- }
- return ConstantIntRanges::range(min, max, isSigned);
-}
+using namespace mlir::intrange;
//===----------------------------------------------------------------------===//
// ConstantOp
@@ -78,25 +38,7 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn uadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.uadd_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn sadd = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.sadd_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
-
- ConstantIntRanges urange = computeBoundsBy(
- uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
- ConstantIntRanges srange = computeBoundsBy(
- sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferAdd(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -105,25 +47,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn usub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.usub_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn ssub = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.ssub_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstantIntRanges urange = computeBoundsBy(
- usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
- ConstantIntRanges srange = computeBoundsBy(
- ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferSub(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -132,96 +56,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn umul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.umul_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
- ConstArithFn smul = [](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.smul_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : result;
- };
-
- ConstantIntRanges urange =
- minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/false);
- ConstantIntRanges srange =
- minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
- /*isSigned=*/true);
-
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferMul(argRanges));
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
-/// Fix up division results (ex. for ceiling and floor), returning an APInt
-/// if there has been no overflow
-using DivisionFixupFn = function_ref<std::optional<APInt>(
- const APInt &lhs, const APInt &rhs, const APInt &result)>;
-
-static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs,
- DivisionFixupFn fixup) {
- const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
- &rhsMax = rhs.umax();
-
- if (!rhsMin.isZero()) {
- auto udiv = [&fixup](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- return fixup(a, b, a.udiv(b));
- };
- return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/false);
- }
- // Otherwise, it's possible we might divide by 0.
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferDivUIRange(argRanges[0], argRanges[1],
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) { return result; }));
+ setResultRange(getResult(), inferDivU(argRanges));
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs,
- DivisionFixupFn fixup) {
- const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
- &rhsMax = rhs.smax();
- bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
-
- if (canDivide) {
- auto sdiv = [&fixup](const APInt &a,
- const APInt &b) -> std::optional<APInt> {
- bool overflowed = false;
- APInt result = a.sdiv_ov(b, overflowed);
- return overflowed ? std::optional<APInt>() : fixup(a, b, result);
- };
- return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/true);
- }
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
-}
-
void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(),
- inferDivSIRange(argRanges[0], argRanges[1],
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) { return result; }));
+ setResultRange(getResult(), inferDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -230,20 +83,7 @@ void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::CeilDivUIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn ceilDivUIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.urem(rhs).isZero()) {
- bool overflowed = false;
- APInt corrected =
- result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
+ setResultRange(getResult(), inferCeilDivU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -252,20 +92,7 @@ void arith::CeilDivUIOp::inferResultRanges(
void arith::CeilDivSIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn ceilDivSIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
- bool overflowed = false;
- APInt corrected =
- result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
+ setResultRange(getResult(), inferCeilDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -274,20 +101,7 @@ void arith::CeilDivSIOp::inferResultRanges(
void arith::FloorDivSIOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- DivisionFixupFn floorDivSIFix =
- [](const APInt &lhs, const APInt &rhs,
- const APInt &result) -> std::optional<APInt> {
- if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
- bool overflowed = false;
- APInt corrected =
- result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
- return overflowed ? std::optional<APInt>() : corrected;
- }
- return result;
- };
- setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
+ return setResultRange(getResult(), inferFloorDivS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -296,29 +110,7 @@ void arith::FloorDivSIOp::inferResultRanges(
void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
-
- unsigned width = rhsMin.getBitWidth();
- APInt umin = APInt::getZero(width);
- APInt umax = APInt::getMaxValue(width);
-
- if (!rhsMin.isZero()) {
- umax = rhsMax - 1;
- // Special case: sweeping out a contiguous range in N/[modulus]
- if (rhsMin == rhsMax) {
- const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
- if ((lhsMax - lhsMin).ult(rhsMax)) {
- APInt minRem = lhsMin.urem(rhsMax);
- APInt maxRem = lhsMax.urem(rhsMax);
- if (minRem.ule(maxRem)) {
- umin = minRem;
- umax = maxRem;
- }
- }
- }
- }
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferRemU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -327,67 +119,16 @@ void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
- &rhsMax = rhs.smax();
-
- unsigned width = rhsMax.getBitWidth();
- APInt smin = APInt::getSignedMinValue(width);
- APInt smax = APInt::getSignedMaxValue(width);
- // No bounds if zero could be a divisor.
- bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
- if (canBound) {
- APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
- bool canNegativeDividend = lhsMin.isNegative();
- bool canPositiveDividend = lhsMax.isStrictlyPositive();
- APInt zero = APInt::getZero(maxDivisor.getBitWidth());
- APInt maxPositiveResult = maxDivisor - 1;
- APInt minNegativeResult = -maxPositiveResult;
- smin = canNegativeDividend ? minNegativeResult : zero;
- smax = canPositiveDividend ? maxPositiveResult : zero;
- // Special case: sweeping out a contiguous range in N/[modulus].
- if (rhsMin == rhsMax) {
- if ((lhsMax - lhsMin).ult(maxDivisor)) {
- APInt minRem = lhsMin.srem(maxDivisor);
- APInt maxRem = lhsMax.srem(maxDivisor);
- if (minRem.sle(maxRem)) {
- smin = minRem;
- smax = maxRem;
- }
- }
- }
- }
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferRemS(argRanges));
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
-/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
-/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
-/// that both bonuds have in common. This gives us a consertive approximation
-/// for what values can be passed to bitwise operations.
-static std::tuple<APInt, APInt>
-widenBitwiseBounds(const ConstantIntRanges &bound) {
- APInt leftVal = bound.umin(), rightVal = bound.umax();
- unsigned bitwidth = leftVal.getBitWidth();
- unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
- leftVal.clearLowBits(differingBits);
- rightVal.setLowBits(differingBits);
- return std::make_tuple(std::move(leftVal), std::move(rightVal));
-}
-
void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a & b;
- };
- setResultRange(getResult(),
- minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferAnd(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -396,14 +137,7 @@ void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a | b;
- };
- setResultRange(getResult(),
- minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferOr(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -412,14 +146,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a ^ b;
- };
- setResultRange(getResult(),
- minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferXor(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -428,11 +155,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
- const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferMaxS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -441,11 +164,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
- const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferMaxU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -454,11 +173,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
- const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
- setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+ setResultRange(getResult(), inferMinS(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -467,94 +182,40 @@ void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
- const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
- setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+ setResultRange(getResult(), inferMinU(argRanges));
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges extUIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- APInt umin = range.umin().zext(destWidth);
- APInt umax = range.umax().zext(destWidth);
- return ConstantIntRanges::fromUnsigned(umin, umax);
-}
-
void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), extUIRange(argRanges[0], destType));
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- APInt smin = range.smin().sext(destWidth);
- APInt smax = range.smax().sext(destWidth);
- return ConstantIntRanges::fromSigned(smin, smax);
-}
-
void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), extSIRange(argRanges[0], destType));
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
-static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
- Type destType) {
- unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
- // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
- // the range of the resulting value is not contiguous ind includes 0.
- // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
- // but you can't truncate [255, 257] similarly.
- bool hasUnsignedRollover =
- range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
- APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
- : range.umin().trunc(destWidth);
- APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
- : range.umax().trunc(destWidth);
-
- // Signed post-truncation rollover will not occur when either:
- // - The high parts of the min and max, plus the sign bit, are the same
- // - The high halves + sign bit of the min and max are either all 1s or all 0s
- // and you won't create a [positive, negative] range by truncating.
- // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
- // but not [255, 257]_i16 to a range of i8s. You can also truncate
- // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
- // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
- // will truncate to 0x7e, which is greater than 0
- APInt sminHighPart = range.smin().ashr(destWidth - 1);
- APInt smaxHighPart = range.smax().ashr(destWidth - 1);
- bool hasSignedOverflow =
- (sminHighPart != smaxHighPart) &&
- !(sminHighPart.isAllOnes() &&
- (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
- !(sminHighPart.isZero() && smaxHighPart.isZero());
- APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
- : range.smin().trunc(destWidth);
- APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
- : range.smax().trunc(destWidth);
- return {umin, umax, smin, smax};
-}
-
void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- Type destType = getResult().getType();
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ unsigned destWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
}
//===----------------------------------------------------------------------===//
@@ -569,9 +230,9 @@ void arith::IndexCastOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extSIRange(argRanges[0], destType));
+ setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
else
setResultRange(getResult(), argRanges[0]);
}
@@ -588,9 +249,9 @@ void arith::IndexCastUIOp::inferResultRanges(
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
if (srcWidth < destWidth)
- setResultRange(getResult(), extUIRange(argRanges[0], destType));
+ setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
else if (srcWidth > destWidth)
- setResultRange(getResult(), truncIRange(argRanges[0], destType));
+ setResultRange(getResult(), truncRange(argRanges[0], destWidth));
else
setResultRange(getResult(), argRanges[0]);
}
@@ -599,51 +260,19 @@ void arith::IndexCastUIOp::inferResultRanges(
// CmpIOp
//===----------------------------------------------------------------------===//
-bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs) {
- switch (pred) {
- case arith::CmpIPredicate::sle:
- case arith::CmpIPredicate::slt:
- return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
- case arith::CmpIPredicate::ule:
- case arith::CmpIPredicate::ult:
- return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
- case arith::CmpIPredicate::sge:
- case arith::CmpIPredicate::sgt:
- return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
- case arith::CmpIPredicate::uge:
- case arith::CmpIPredicate::ugt:
- return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
- case arith::CmpIPredicate::eq: {
- std::optional<APInt> lhsConst = lhs.getConstantValue();
- std::optional<APInt> rhsConst = rhs.getConstantValue();
- return lhsConst && rhsConst && lhsConst == rhsConst;
- }
- case arith::CmpIPredicate::ne: {
- // While equality requires that there is an interpration of the preceeding
- // computations that produces equal constants, whether that be signed or
- // unsigned, statically determining inequality requires that neither
- // interpretation produce potentially overlapping ranges.
- bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
- isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
- bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
- isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
- return sne && une;
- }
- }
- return false;
-}
-
void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- arith::CmpIPredicate pred = getPredicate();
+ arith::CmpIPredicate arithPred = getPredicate();
+ intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
APInt min = APInt::getZero(1);
APInt max = APInt::getAllOnesValue(1);
- if (isStaticallyTrue(pred, lhs, rhs))
+
+ Optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
+ if (truthValue.has_value() && *truthValue)
min = max;
- else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+ else if (truthValue.has_value() && !(*truthValue))
max = min;
setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
@@ -673,18 +302,7 @@ void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
- ConstArithFn shl = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
- };
- ConstantIntRanges urange =
- minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/false);
- ConstantIntRanges srange =
- minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
- /*isSigned=*/true);
- setResultRange(getResult(), urange.intersection(srange));
+ setResultRange(getResult(), inferShl(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -693,15 +311,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn lshr = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
- };
- setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
- {rhs.umin(), rhs.umax()},
- /*isSigned=*/false));
+ setResultRange(getResult(), inferShrU(argRanges));
}
//===----------------------------------------------------------------------===//
@@ -710,14 +320,5 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
-
- ConstArithFn ashr = [](const APInt &l,
- const APInt &r) -> std::optional<APInt> {
- return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
- };
-
- setResultRange(getResult(),
- minMaxBy(ashr, {lhs.smin(), lhs.smax()},
- {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
+ setResultRange(getResult(), inferShrS(argRanges));
}
diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
index 53321f1ea3f25..fce47d2ecc531 100644
--- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIndexDialect
IndexAttrs.cpp
IndexDialect.cpp
IndexOps.cpp
+ InferIntRangeInterfaceImpls.cpp
DEPENDS
MLIRIndexOpsIncGen
@@ -10,6 +11,8 @@ add_mlir_dialect_library(MLIRIndexDialect
MLIRDialect
MLIRIR
MLIRCastInterfaces
+ MLIRInferIntRangeCommon
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
new file mode 100644
index 0000000000000..6daa7640b017e
--- /dev/null
+++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
@@ -0,0 +1,252 @@
+//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+using namespace mlir::index;
+using namespace mlir::intrange;
+
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ const APInt &value = getValue();
+ setResultRange(getResult(), ConstantIntRanges::constant(value));
+}
+
+void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ bool value = getValue();
+ APInt asInt(/*numBits=*/1, value);
+ setResultRange(getResult(), ConstantIntRanges::constant(asInt));
+}
+
+//===----------------------------------------------------------------------===//
+// Arithmetic operations. All of these operations will have their results
+// inferred using both the 64-bit values and truncated 32-bit values of their
+// inputs, with the results being the union of those inferences, except where
+// the truncation of the 64-bit result is equal to the 32-bit result (at which
+// time we take the 64-bit result).
+//===----------------------------------------------------------------------===//
+
+void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
+}
+
+void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
+}
+
+void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
+}
+
+void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
+}
+
+void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferDivS, argRanges, CmpMode::Signed));
+}
+
+void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
+}
+
+void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed));
+}
+
+/// Infers the result range of FloorDivSOp by delegating to `inferIndexOp`
+/// with `inferFloorDivS`; `CmpMode::Signed` makes the 32-bit/64-bit agreement
+/// check look at the signed bounds only.
+void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                    SetIntRangeFn setResultRange) {
+  // No `return` here: the function is void, and this matches the other
+  // inferResultRanges implementations in this file.
+  setResultRange(getResult(),
+                 inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
+}
+
+void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
+}
+
+void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
+}
+
+void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
+}
+
+void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
+}
+
+void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
+}
+
+void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
+}
+
+void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
+}
+
+void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
+}
+
+void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
+}
+
+void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
+}
+
+void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
+}
+
+void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(),
+ inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
+}
+
+//===----------------------------------------------------------------------===//
+// Casts
+//===----------------------------------------------------------------------===//
+
+static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
+ unsigned srcWidth, unsigned destWidth,
+ bool isSigned) {
+ if (srcWidth < destWidth)
+ return isSigned ? extSIRange(range, destWidth)
+ : extUIRange(range, destWidth);
+ if (srcWidth > destWidth)
+ return truncRange(range, destWidth);
+ return range;
+}
+
+// When casting to `index`, we will take the union of the possible fixed-width
+// casts.
+static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
+ Type sourceType, Type destType,
+ bool isSigned) {
+ unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+ if (sourceType.isIndex())
+ return makeLikeDest(range, srcWidth, destWidth, isSigned);
+ // We are casting to index, so use the union of the 32-bit and 64-bit casts
+ ConstantIntRanges storageRange =
+ makeLikeDest(range, srcWidth, destWidth, isSigned);
+ ConstantIntRanges minWidthRange =
+ makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
+ ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
+ ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
+ return ret;
+}
+
+void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ Type sourceType = getOperand().getType();
+ Type destType = getResult().getType();
+ setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
+ /*isSigned=*/true));
+}
+
+void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ Type sourceType = getOperand().getType();
+ Type destType = getResult().getType();
+ setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
+ /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the 1-bit result range of CmpOp. The predicate is evaluated twice:
+/// once on the operand ranges as given and once on those ranges truncated to
+/// `indexMinWidth` bits. The result is narrowed to a constant only when both
+/// evaluations reach the same definite answer; otherwise it stays [0, 1].
+void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                              SetIntRangeFn setResultRange) {
+  index::IndexCmpPredicate indexPred = getPred();
+  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Start from the unconstrained boolean range [0, 1].
+  APInt min = APInt::getZero(1);
+  APInt max = APInt::getAllOnes(1);
+
+  Optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
+
+  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
+                    rhsTrunc = truncRange(rhs, indexMinWidth);
+  Optional<bool> truthValue32 =
+      intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
+
+  // Only trust a verdict that holds for both the untruncated and the
+  // truncated interpretation of the operands.
+  if (truthValue64 == truthValue32) {
+    if (truthValue64.has_value() && *truthValue64)
+      min = max; // Statically true: range is [1, 1].
+    else if (truthValue64.has_value() && !(*truthValue64))
+      max = min; // Statically false: range is [0, 0].
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+}
+
+//===----------------------------------------------------------------------===//
+// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
+//===----------------------------------------------------------------------===//
+
+void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRange) {
+ unsigned storageWidth =
+ ConstantIntRanges::getStorageBitwidth(getResult().getType());
+ APInt min(/*numBits=*/storageWidth, indexMinWidth);
+ APInt max(/*numBits=*/storageWidth, indexMaxWidth);
+ setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a7cdbb5b3a6fe..38ad0e4a2231c 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -51,3 +51,5 @@ add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
add_mlir_interface_library(VectorInterfaces)
add_mlir_interface_library(ViewLikeInterface)
+
+add_subdirectory(Utils)
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..ece6c8e46ffea
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_library(MLIRInferIntRangeCommon
+ InferIntRangeCommon.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
+
+ DEPENDS
+ MLIRInferIntRangeInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRInferIntRangeInterface
+ MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
new file mode 100644
index 0000000000000..c81f004ecf5f9
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -0,0 +1,663 @@
+//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementations of range inference for operations that are
+// common to both the `arith` and `index` dialects to facilitate reuse.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include "llvm/Support/Debug.h"
+
+#include <iterator>
+#include <optional>
+
+using namespace mlir;
+
+#define DEBUG_TYPE "int-range-analysis"
+
+//===----------------------------------------------------------------------===//
+// General utilities
+//===----------------------------------------------------------------------===//
+
+/// Function that evaluates the result of doing something on arithmetic
+/// constants and returns std::nullopt on overflow.
+using ConstArithFn =
+ function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
+
+/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible.
+/// If either computation overflows, make the result unbounded.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
+ const APInt &minRight,
+ const APInt &maxLeft,
+ const APInt &maxRight, bool isSigned) {
+ std::optional<APInt> maybeMin = op(minLeft, minRight);
+ std::optional<APInt> maybeMax = op(maxLeft, maxRight);
+ if (maybeMin && maybeMax)
+ return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
+ return ConstantIntRanges::maxRange(minLeft.getBitWidth());
+}
+
+/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`.
+/// Returns the maximal range if any application of `op` overflows.
+static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
+ ArrayRef<APInt> rhs, bool isSigned) {
+ unsigned width = lhs[0].getBitWidth();
+ APInt min =
+ isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
+ APInt max =
+ isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
+ for (const APInt &left : lhs) {
+ for (const APInt &right : rhs) {
+ std::optional<APInt> maybeThisResult = op(left, right);
+ if (!maybeThisResult)
+ return ConstantIntRanges::maxRange(width);
+ APInt result = std::move(*maybeThisResult);
+ min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
+ max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
+ }
+ }
+ return ConstantIntRanges::range(min, max, isSigned);
+}
+
+//===----------------------------------------------------------------------===//
+// Ext, trunc, index op handling
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferIndexOp(InferRangeFn inferFn,
+ ArrayRef<ConstantIntRanges> argRanges,
+ intrange::CmpMode mode) {
+ ConstantIntRanges sixtyFour = inferFn(argRanges);
+ SmallVector<ConstantIntRanges, 2> truncated;
+ llvm::transform(argRanges, std::back_inserter(truncated),
+ [](const ConstantIntRanges &range) {
+ return truncRange(range, /*destWidth=*/indexMinWidth);
+ });
+ ConstantIntRanges thirtyTwo = inferFn(truncated);
+ ConstantIntRanges thirtyTwoAsSixtyFour =
+ extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
+ ConstantIntRanges sixtyFourAsThirtyTwo =
+ truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
+
+ LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
+ << " 32-bit = " << thirtyTwo << "\n");
+ bool truncEqual = false;
+ switch (mode) {
+ case intrange::CmpMode::Both:
+ truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
+ break;
+ case intrange::CmpMode::Signed:
+ truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
+ thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
+ break;
+ case intrange::CmpMode::Unsigned:
+ truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
+ thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
+ break;
+ }
+ if (truncEqual)
+ // Returning the 64-bit result preserves more information.
+ return sixtyFour;
+ ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
+ return merged;
+}
+
+ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
+ unsigned int destWidth) {
+ APInt umin = range.umin().zext(destWidth);
+ APInt umax = range.umax().zext(destWidth);
+ APInt smin = range.smin().sext(destWidth);
+ APInt smax = range.smax().sext(destWidth);
+ return {umin, umax, smin, smax};
+}
+
+ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
+ unsigned destWidth) {
+ APInt umin = range.umin().zext(destWidth);
+ APInt umax = range.umax().zext(destWidth);
+ return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
+ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
+ unsigned destWidth) {
+ APInt smin = range.smin().sext(destWidth);
+ APInt smax = range.smax().sext(destWidth);
+ return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
+ unsigned int destWidth) {
+ // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
+ // the range of the resulting value is not contiguous and includes 0.
+ // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
+ // but you can't truncate [255, 257] similarly.
+ bool hasUnsignedRollover =
+ range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+ APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
+ : range.umin().trunc(destWidth);
+ APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
+ : range.umax().trunc(destWidth);
+
+ // Signed post-truncation rollover will not occur when either:
+ // - The high parts of the min and max, plus the sign bit, are the same
+ // - The high halves + sign bit of the min and max are either all 1s or all 0s
+ // and you won't create a [positive, negative] range by truncating.
+ // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
+ // but not [255, 257]_i16 to a range of i8s. You can also truncate
+ // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
+ // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
+ // will truncate to 0x7e, which is greater than 0
+ APInt sminHighPart = range.smin().ashr(destWidth - 1);
+ APInt smaxHighPart = range.smax().ashr(destWidth - 1);
+ bool hasSignedOverflow =
+ (sminHighPart != smaxHighPart) &&
+ !(sminHighPart.isAllOnes() &&
+ (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
+ !(sminHighPart.isZero() && smaxHighPart.isZero());
+ APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
+ : range.smin().trunc(destWidth);
+ APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
+ : range.smax().trunc(destWidth);
+ return {umin, umax, smin, smax};
+}
+
+//===----------------------------------------------------------------------===//
+// Addition
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ ConstArithFn uadd = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.uadd_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+ ConstArithFn sadd = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.sadd_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+
+ ConstantIntRanges urange = computeBoundsBy(
+ uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
+ ConstantIntRanges srange = computeBoundsBy(
+ sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
+ return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// Subtraction
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ ConstArithFn usub = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.usub_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+ ConstArithFn ssub = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.ssub_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+ ConstantIntRanges urange = computeBoundsBy(
+ usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
+ ConstantIntRanges srange = computeBoundsBy(
+ ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
+ return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// Multiplication
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ ConstArithFn umul = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.umul_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+ ConstArithFn smul = [](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.smul_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : result;
+ };
+
+ ConstantIntRanges urange =
+ minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+ /*isSigned=*/false);
+ ConstantIntRanges srange =
+ minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+ /*isSigned=*/true);
+ return urange.intersection(srange);
+}
+
+//===----------------------------------------------------------------------===//
+// DivU, CeilDivU (Unsigned division)
+//===----------------------------------------------------------------------===//
+
+/// Fix up division results (ex. for ceiling and floor), returning an APInt
+/// if there has been no overflow
+using DivisionFixupFn = function_ref<std::optional<APInt>(
+ const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs,
+ DivisionFixupFn fixup) {
+ const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
+ &rhsMax = rhs.umax();
+
+ if (!rhsMin.isZero()) {
+ auto udiv = [&fixup](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ return fixup(a, b, a.udiv(b));
+ };
+ return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+ /*isSigned=*/false);
+ }
+ // Otherwise, it's possible we might divide by 0.
+ return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+ConstantIntRanges
+mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) {
+ return inferDivURange(argRanges[0], argRanges[1],
+ [](const APInt &lhs, const APInt &rhs,
+ const APInt &result) { return result; });
+}
+
+ConstantIntRanges
+mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ DivisionFixupFn ceilDivUIFix =
+ [](const APInt &lhs, const APInt &rhs,
+ const APInt &result) -> std::optional<APInt> {
+ if (!lhs.urem(rhs).isZero()) {
+ bool overflowed = false;
+ APInt corrected =
+ result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+ return overflowed ? std::optional<APInt>() : corrected;
+ }
+ return result;
+ };
+ return inferDivURange(lhs, rhs, ceilDivUIFix);
+}
+
+//===----------------------------------------------------------------------===//
+// DivS, CeilDivS, FloorDivS (Signed division)
+//===----------------------------------------------------------------------===//
+
+static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs,
+ DivisionFixupFn fixup) {
+ const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+ &rhsMax = rhs.smax();
+ bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
+
+ if (canDivide) {
+ auto sdiv = [&fixup](const APInt &a,
+ const APInt &b) -> std::optional<APInt> {
+ bool overflowed = false;
+ APInt result = a.sdiv_ov(b, overflowed);
+ return overflowed ? std::optional<APInt>() : fixup(a, b, result);
+ };
+ return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+ /*isSigned=*/true);
+ }
+ return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+ConstantIntRanges
+mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) {
+ return inferDivSRange(argRanges[0], argRanges[1],
+ [](const APInt &lhs, const APInt &rhs,
+ const APInt &result) { return result; });
+}
+
+ConstantIntRanges
+mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ DivisionFixupFn ceilDivSIFix =
+ [](const APInt &lhs, const APInt &rhs,
+ const APInt &result) -> std::optional<APInt> {
+ if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
+ bool overflowed = false;
+ APInt corrected =
+ result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+ return overflowed ? std::optional<APInt>() : corrected;
+ }
+ return result;
+ };
+ return inferDivSRange(lhs, rhs, ceilDivSIFix);
+}
+
+ConstantIntRanges
+mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ DivisionFixupFn floorDivSIFix =
+ [](const APInt &lhs, const APInt &rhs,
+ const APInt &result) -> std::optional<APInt> {
+ if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
+ bool overflowed = false;
+ APInt corrected =
+ result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
+ return overflowed ? std::optional<APInt>() : corrected;
+ }
+ return result;
+ };
+ return inferDivSRange(lhs, rhs, floorDivSIFix);
+}
+
+//===----------------------------------------------------------------------===//
+// Signed remainder (RemS)
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+ &rhsMax = rhs.smax();
+
+ unsigned width = rhsMax.getBitWidth();
+ APInt smin = APInt::getSignedMinValue(width);
+ APInt smax = APInt::getSignedMaxValue(width);
+ // No bounds if zero could be a divisor.
+ bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
+ if (canBound) {
+ APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
+ bool canNegativeDividend = lhsMin.isNegative();
+ bool canPositiveDividend = lhsMax.isStrictlyPositive();
+ APInt zero = APInt::getZero(maxDivisor.getBitWidth());
+ APInt maxPositiveResult = maxDivisor - 1;
+ APInt minNegativeResult = -maxPositiveResult;
+ smin = canNegativeDividend ? minNegativeResult : zero;
+ smax = canPositiveDividend ? maxPositiveResult : zero;
+ // Special case: sweeping out a contiguous range in N/[modulus].
+ if (rhsMin == rhsMax) {
+ if ((lhsMax - lhsMin).ult(maxDivisor)) {
+ APInt minRem = lhsMin.srem(maxDivisor);
+ APInt maxRem = lhsMax.srem(maxDivisor);
+ if (minRem.sle(maxRem)) {
+ smin = minRem;
+ smax = maxRem;
+ }
+ }
+ }
+ }
+ return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+//===----------------------------------------------------------------------===//
+// Unsigned remainder (RemU)
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+
+ unsigned width = rhsMin.getBitWidth();
+ APInt umin = APInt::getZero(width);
+ APInt umax = APInt::getMaxValue(width);
+
+ if (!rhsMin.isZero()) {
+ umax = rhsMax - 1;
+ // Special case: sweeping out a contiguous range in N/[modulus]
+ if (rhsMin == rhsMax) {
+ const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
+ if ((lhsMax - lhsMin).ult(rhsMax)) {
+ APInt minRem = lhsMin.urem(rhsMax);
+ APInt maxRem = lhsMax.urem(rhsMax);
+ if (minRem.ule(maxRem)) {
+ umin = minRem;
+ umax = maxRem;
+ }
+ }
+ }
+ }
+ return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
+//===----------------------------------------------------------------------===//
+// Max and min (MaxS, MaxU, MinS, MinU)
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
+ const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
+ return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+ConstantIntRanges
+mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
+ const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
+ return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
+ConstantIntRanges
+mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
+ const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
+ return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+ConstantIntRanges
+mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
+ const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
+ return ConstantIntRanges::fromUnsigned(umin, umax);
+}
+
+//===----------------------------------------------------------------------===//
+// Bitwise operators (And, Or, Xor)
+//===----------------------------------------------------------------------===//
+
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations.
+static std::tuple<APInt, APInt>
+widenBitwiseBounds(const ConstantIntRanges &bound) {
+ APInt leftVal = bound.umin(), rightVal = bound.umax();
+ unsigned bitwidth = leftVal.getBitWidth();
+ unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+ leftVal.clearLowBits(differingBits);
+ rightVal.setLowBits(differingBits);
+ return std::make_tuple(std::move(leftVal), std::move(rightVal));
+}
+
+ConstantIntRanges
+mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) {
+ auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+ auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+ auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+ return a & b;
+ };
+ return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+ /*isSigned=*/false);
+}
+
+ConstantIntRanges
+mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
+ auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+ auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+ auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+ return a | b;
+ };
+ return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+ /*isSigned=*/false);
+}
+
+ConstantIntRanges
+mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
+ auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
+ auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
+ auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
+ return a ^ b;
+ };
+ return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+ /*isSigned=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// Shifts (Shl, ShrS, ShrU)
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ ConstArithFn shl = [](const APInt &l,
+ const APInt &r) -> std::optional<APInt> {
+ return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.shl(r);
+ };
+ ConstantIntRanges urange =
+ minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+ /*isSigned=*/false);
+ ConstantIntRanges srange =
+ minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+ /*isSigned=*/true);
+ return urange.intersection(srange);
+}
+
+ConstantIntRanges
+mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ ConstArithFn ashr = [](const APInt &l,
+ const APInt &r) -> std::optional<APInt> {
+ return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
+ };
+
+ return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+ /*isSigned=*/true);
+}
+
+ConstantIntRanges
+mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) {
+ const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+ ConstArithFn lshr = [](const APInt &l,
+ const APInt &r) -> std::optional<APInt> {
+ return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
+ };
+ return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+ /*isSigned=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// Comparisons (Cmp)
+//===----------------------------------------------------------------------===//
+
+static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) {
+ switch (pred) {
+ case intrange::CmpPredicate::eq:
+ return intrange::CmpPredicate::ne;
+ case intrange::CmpPredicate::ne:
+ return intrange::CmpPredicate::eq;
+ case intrange::CmpPredicate::slt:
+ return intrange::CmpPredicate::sge;
+ case intrange::CmpPredicate::sle:
+ return intrange::CmpPredicate::sgt;
+ case intrange::CmpPredicate::sgt:
+ return intrange::CmpPredicate::sle;
+ case intrange::CmpPredicate::sge:
+ return intrange::CmpPredicate::slt;
+ case intrange::CmpPredicate::ult:
+ return intrange::CmpPredicate::uge;
+ case intrange::CmpPredicate::ule:
+ return intrange::CmpPredicate::ugt;
+ case intrange::CmpPredicate::ugt:
+ return intrange::CmpPredicate::ule;
+ case intrange::CmpPredicate::uge:
+ return intrange::CmpPredicate::ult;
+ }
+ llvm_unreachable("unknown cmp predicate value");
+}
+
+static bool isStaticallyTrue(intrange::CmpPredicate pred,
+ const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs) {
+ switch (pred) {
+ case intrange::CmpPredicate::sle:
+ return lhs.smax().sle(rhs.smin());
+ case intrange::CmpPredicate::slt:
+ return lhs.smax().slt(rhs.smin());
+ case intrange::CmpPredicate::ule:
+ return lhs.umax().ule(rhs.umin());
+ case intrange::CmpPredicate::ult:
+ return lhs.umax().ult(rhs.umin());
+ case intrange::CmpPredicate::sge:
+ return lhs.smin().sge(rhs.smax());
+ case intrange::CmpPredicate::sgt:
+ return lhs.smin().sgt(rhs.smax());
+ case intrange::CmpPredicate::uge:
+ return lhs.umin().uge(rhs.umax());
+ case intrange::CmpPredicate::ugt:
+ return lhs.umin().ugt(rhs.umax());
+ case intrange::CmpPredicate::eq: {
+ std::optional<APInt> lhsConst = lhs.getConstantValue();
+ std::optional<APInt> rhsConst = rhs.getConstantValue();
+ return lhsConst && rhsConst && lhsConst == rhsConst;
+ }
+ case intrange::CmpPredicate::ne: {
+ // While equality requires that there is an interpretation of the preceding
+ // computations that produces equal constants, whether that be signed or
+ // unsigned, statically determining inequality requires that neither
+ // interpretation produce potentially overlapping ranges.
+ bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
+ isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs);
+ bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
+ isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs);
+ return sne && une;
+ }
+ }
+ return false;
+}
+
+std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
+ const ConstantIntRanges &lhs,
+ const ConstantIntRanges &rhs) {
+ if (isStaticallyTrue(pred, lhs, rhs))
+ return true;
+ if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+ return false;
+ return std::nullopt;
+}
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
new file mode 100644
index 0000000000000..2784d5fd5cf70
--- /dev/null
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+
+// Most operations are covered by the `arith` tests, which use the same code.
+// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
+// code is operating as expected.
+
+// CHECK-LABEL: func @add_same_for_both
+// CHECK: %[[true:.*]] = index.bool.constant true
+// CHECK: return %[[true]]
+func.func @add_same_for_both(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %calmostBig = index.constant 0xfffffffe
+ %0 = index.minu %arg0, %calmostBig
+ %1 = index.add %0, %c1
+ %2 = index.cmp uge(%1, %c1)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_unsigned_ov
+// CHECK: %[[uge:.*]] = index.cmp uge
+// CHECK: return %[[uge]]
+func.func @add_unsigned_ov(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %cu32_max = index.constant 0xffffffff
+ %0 = index.minu %arg0, %cu32_max
+ %1 = index.add %0, %c1
+ // On 32-bit, the add could wrap, so the result doesn't have to be >= 1
+ %2 = index.cmp uge(%1, %c1)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_signed_ov
+// CHECK: %[[sge:.*]] = index.cmp sge
+// CHECK: return %[[sge]]
+func.func @add_signed_ov(%arg0 : index) -> i1 {
+ %c0 = index.constant 0
+ %c1 = index.constant 1
+ %ci32_max = index.constant 0x7fffffff
+ %0 = index.minu %arg0, %ci32_max
+ %1 = index.add %0, %c1
+ // On 32-bit, the add could wrap, so the result doesn't have to be positive
+ %2 = index.cmp sge(%1, %c0)
+ func.return %2 : i1
+}
+
+// CHECK-LABEL: func @add_big
+// CHECK: %[[true:.*]] = index.bool.constant true
+// CHECK: return %[[true]]
+func.func @add_big(%arg0 : index) -> i1 {
+ %c1 = index.constant 1
+ %cmin = index.constant 0x300000000
+ %cmax = index.constant 0x30000ffff
+ // Note: the order of the clamps matters.
+ // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff]
+ // and then [0x30...0000, 0x30...ffff]
+ // If you switch the order of the below operations, you instead first infer
+// the range [0,0x3...ffff]. Then, the min inference can't constrain
+ // this intermediate, since in the 32-bit case we could have, for example
+ // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff
+ // which means we can't do any inference.
+ %0 = index.maxu %arg0, %cmin
+ %1 = index.minu %0, %cmax
+ %2 = index.add %1, %c1
+ %3 = index.cmp uge(%1, %cmin)
+ func.return %3 : i1
+}
More information about the Mlir-commits
mailing list