[Mlir-commits] [mlir] 95aff23 - Re-land "[mlir] Add integer range inference analysis""
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Jun 3 10:13:53 PDT 2022
Author: Krzysztof Drewniak
Date: 2022-06-03T17:13:48Z
New Revision: 95aff23e29214543360d893f9a61df0ebd1b65d2
URL: https://github.com/llvm/llvm-project/commit/95aff23e29214543360d893f9a61df0ebd1b65d2
DIFF: https://github.com/llvm/llvm-project/commit/95aff23e29214543360d893f9a61df0ebd1b65d2.diff
LOG: Re-land "[mlir] Add integer range inference analysis""
This reverts commit 4e5ce2056e3e85f109a074e80bdd23a10ca2bed9.
This relands commit 1350c9887dca5ba80af8e3c1e61b29d6696eb240.
Reinstates the range analysis with the build issue fixed.
Differential Revision: https://reviews.llvm.org/D126926
Added:
mlir/include/mlir/Analysis/IntRangeAnalysis.h
mlir/include/mlir/Interfaces/InferIntRangeInterface.h
mlir/include/mlir/Interfaces/InferIntRangeInterface.td
mlir/lib/Analysis/IntRangeAnalysis.cpp
mlir/lib/Interfaces/InferIntRangeInterface.cpp
mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
mlir/test/lib/Transforms/TestIntRangeInference.cpp
mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
Modified:
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Analysis/DataFlowAnalysis.cpp
mlir/lib/Interfaces/CMakeLists.txt
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
mlir/unittests/Interfaces/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
new file mode 100644
index 0000000000000..b2b604359b48b
--- /dev/null
+++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,41 @@
+//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl;
+} // end namespace detail
+
+class IntRangeAnalysis {
+public:
+ /// Analyze all operations rooted under (but not including)
+ /// `topLevelOperation`.
+ IntRangeAnalysis(Operation *topLevelOperation);
+ IntRangeAnalysis(IntRangeAnalysis &&other);
+ ~IntRangeAnalysis();
+
+ /// Get inferred range for value `v` if one exists.
+ Optional<ConstantIntRanges> getResult(Value v);
+
+private:
+ std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
+};
+} // end namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index cf075f728fdbf..918f3ea398e47 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
+add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(SideEffectInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
new file mode 100644
index 0000000000000..9a393855d05ff
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -0,0 +1,98 @@
+//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer range inference interface
+// defined in `InferIntRange.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+/// A set of arbitrary-precision integers representing bounds on a given integer
+/// value. These bounds are inclusive on both ends, so
+/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
+/// the unsigned and signed interpretations of values in order to enable more
+/// precice inference of the interplay between operations with signed and
+/// unsigned semantics.
+class ConstantIntRanges {
+public:
+ /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
+ /// Non-integer values should be bounded by APInts of bitwidth 0.
+ ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
+ const APInt &smax)
+ : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+ assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
+ umaxVal.getBitWidth() == sminVal.getBitWidth() &&
+ sminVal.getBitWidth() == smaxVal.getBitWidth() &&
+ "All bounds in the ranges must have the same bitwidth");
+ }
+
+ bool operator==(const ConstantIntRanges &other) const;
+
+ /// The minimum value of an integer when it is interpreted as unsigned.
+ const APInt &umin() const;
+
+ /// The maximum value of an integer when it is interpreted as unsigned.
+ const APInt &umax() const;
+
+ /// The minimum value of an integer when it is interpreted as signed.
+ const APInt &smin() const;
+
+ /// The maximum value of an integer when it is interpreted as signed.
+ const APInt &smax() const;
+
+ /// Return the bitwidth that should be used for integer ranges describing
+ /// `type`. For concrete integer types, this is their bitwidth, for `index`,
+ /// this is the internal storage bitwidth of `index` attributes, and for
+ /// non-integer types this is 0.
+ static unsigned getStorageBitwidth(Type type);
+
+ /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
+ /// minimum and `max` is both the signed and unsigned maximum.
+ static ConstantIntRanges range(const APInt &min, const APInt &max);
+
+ /// Create an `IntRangeAttrs` with the signed minimum and maximum equal
+ /// to `smin` and `smax`, where the unsigned bounds are constructed from the
+ /// signed ones if they correspond to a contigious range of bit patterns when
+ /// viewed as unsigned values and are left at [0, int_max()] otherwise.
+ static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
+
+ /// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal
+ /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
+ /// unless the sign bit changes between the minimum and maximum.
+ static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
+
+ /// Returns the union (computed separately for signed and unsigned bounds)
+ /// of `a` and `b`.
+ ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
+
+ /// If either the signed or unsigned interpretations of the range
+ /// indicate that the value it bounds is a constant, return that constant
+ /// value.
+ Optional<APInt> getConstantValue() const;
+
+ friend raw_ostream &operator<<(raw_ostream &os,
+ const ConstantIntRanges &range);
+
+private:
+ APInt uminVal, umaxVal, sminVal, smaxVal;
+};
+
+/// The type of the `setResultRanges` callback provided to ops implementing
+/// InferIntRangeInterface. It should be called once for each integer result
+/// value and be passed the ConstantIntRanges corresponding to that value.
+using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
new file mode 100644
index 0000000000000..57f8d693b7916
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -0,0 +1,52 @@
+//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-----------------------------------------------------===//
+//
+// Defines the interface for range analysis on scalar integers
+//
+//===-----------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
+ let description = [{
+ Allows operations to participate in range analysis for scalar integer values by
+ providing a methods that allows them to specify lower and upper bounds on their
+ result(s) given lower and upper bounds on their input(s) if known.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Infer the bounds on the results of this op given the bounds on its arguments.
+ For each result value or block argument (that isn't a branch argument,
+ since the dataflow analysis handles those case), the method should call
+ `setValueRange` with that `Value` as an argument. When `setValueRange`
+ is not called for some value, it will recieve a default value of the mimimum
+ and maximum values forits type (the unbounded range).
+
+ When called on an op that also implements the RegionBranchOpInterface
+ or BranchOpInterface, this method should not attempt to infer the values
+ of the branch results, as this will be handled by the analyses that use
+ this interface.
+
+ This function will only be called when at least one result of the op is a
+ scalar integer value or the op has a region.
+
+ `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
+ order. Non-integer arguments will have the an unbounded range of width-0
+ APInts in their `argRanges` element.
+ }],
+ "void", "inferResultRanges", (ins
+ "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+ "::mlir::SetIntRangeFn":$setResultRanges)
+ >];
+}
+#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 29314ed535931..6c45e40efa9ab 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
CallGraph.cpp
DataFlowAnalysis.cpp
DataLayoutAnalysis.cpp
+ IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@@ -16,6 +17,7 @@ add_mlir_library(MLIRAnalysis
CallGraph.cpp
DataFlowAnalysis.cpp
DataLayoutAnalysis.cpp
+ IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@@ -31,7 +33,9 @@ add_mlir_library(MLIRAnalysis
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
+ MLIRLoopLikeInterface
MLIRSideEffectInterfaces
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index 9c10595dbb00a..239d9e4060bca 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -359,11 +359,20 @@ void ForwardDataFlowSolver::visitOperation(Operation *op) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
return visitRegionBranchOperation(branch, operandLattices);
- // If we can't, conservatively mark all regions as executable.
- // TODO: Let the `visitOperation` method decide how to propagate
- // information to the block arguments.
- for (Region ®ion : op->getRegions())
- markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true);
+ for (Region ®ion : op->getRegions()) {
+ analysis.visitNonControlFlowArguments(op, RegionSuccessor(®ion),
+ operandLattices);
+ // `visitNonControlFlowArguments` is required to define all of the region
+ // argument lattices.
+ assert(llvm::none_of(
+ region.getArguments(),
+ [&](Value value) {
+ return analysis.getLatticeElement(value).isUninitialized();
+ }) &&
+ "expected `visitNonControlFlowArguments` to define all argument "
+ "lattices");
+ markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/false);
+ }
}
// If this op produces no results, it can't produce any constants.
@@ -567,12 +576,45 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
return;
+ // If the branch is a RegionBranchTerminatorOpInterface,
+ // construct the set of operand lattices as the set of non control-flow
+ // arguments of the parent and the values this op returns. This allows
+ // for the correct lattices to be passed to getSuccessorsForOperands()
+ // in cases such as scf.while.
+ ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
+ SmallVector<AbstractLatticeElement *, 0> parentLattices;
+ if (auto regionTerminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ parentLattices.reserve(regionInterface->getNumOperands());
+ for (Value parentOperand : regionInterface->getOperands()) {
+ AbstractLatticeElement *operandLattice =
+ analysis.lookupLatticeElement(parentOperand);
+ if (!operandLattice || operandLattice->isUninitialized())
+ return;
+ parentLattices.push_back(operandLattice);
+ }
+ unsigned regionNumber = parentRegion->getRegionNumber();
+ OperandRange iterArgs =
+ regionInterface.getSuccessorEntryOperands(regionNumber);
+ OperandRange terminatorArgs =
+ regionTerminator.getSuccessorOperands(regionNumber);
+ assert(iterArgs.size() == terminatorArgs.size() &&
+ "Number of iteration arguments for region should equal number of "
+ "those arguments defined by terminator");
+ if (!iterArgs.empty()) {
+ unsigned iterStart = iterArgs.getBeginOperandIndex();
+ unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
+ for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
+ parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
+ }
+ branchOpLattices = parentLattices;
+ }
// Query the set of successors of the current region using the current
// optimistic lattice state.
SmallVector<RegionSuccessor, 1> regionSuccessors;
analysis.getSuccessorsForOperands(regionInterface,
parentRegion->getRegionNumber(),
- operandLattices, regionSuccessors);
+ branchOpLattices, regionSuccessors);
if (regionSuccessors.empty())
return;
@@ -584,7 +626,7 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
// region index (if any).
return *getRegionBranchSuccessorOperands(op, regionIndex);
};
- return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
+ return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
getOperands);
}
diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
new file mode 100644
index 0000000000000..7e6d61ff89560
--- /dev/null
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,325 @@
+//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+
+namespace {
+/// A wrapper around ConstantIntRanges that provides the lattice functions
+/// expected by dataflow analysis.
+struct IntRangeLattice {
+ IntRangeLattice(const ConstantIntRanges &value) : value(value){};
+ IntRangeLattice(ConstantIntRanges &&value) : value(value){};
+
+ bool operator==(const IntRangeLattice &other) const {
+ return value == other.value;
+ }
+
+ /// wrapper around rangeUnion()
+ static IntRangeLattice join(const IntRangeLattice &a,
+ const IntRangeLattice &b) {
+ return a.value.rangeUnion(b.value);
+ }
+
+ /// Creates a range with bitwidth 0 to represent that we don't know if the
+ /// value being marked overdefined is even an integer.
+ static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
+ APInt noIntValue = APInt::getZeroWidth();
+ return ConstantIntRanges::range(noIntValue, noIntValue);
+ }
+
+ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+ /// range that is used to mark the value v as unable to be analyzed further,
+ /// where t is the type of v.
+ static IntRangeLattice getPessimisticValueState(Value v) {
+ unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
+ APInt umin = APInt::getMinValue(width);
+ APInt umax = APInt::getMaxValue(width);
+ APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+ APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+ return ConstantIntRanges{umin, umax, smin, smax};
+ }
+
+ ConstantIntRanges value;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
+ using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
+
+public:
+ /// Define bounds on the results or block arguments of the operation
+ /// based on the bounds on the arguments given in `operands`
+ ChangeResult
+ visitOperation(Operation *op,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+
+ /// Skip regions of branch ops when we can statically infer constant
+ /// values for operands to the branch op and said op tells us it's safe to do
+ /// so.
+ LogicalResult
+ getSuccessorsForOperands(BranchOpInterface branch,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+ SmallVectorImpl<Block *> &successors) final;
+
+ /// Skip regions of branch or loop ops when we can statically infer constant
+ /// values for operands to the branch op and said op tells us it's safe to do
+ /// so.
+ void
+ getSuccessorsForOperands(RegionBranchOpInterface branch,
+ Optional<unsigned> sourceIndex,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) final;
+
+ /// Call the InferIntRangeInterface implementation for region-using ops
+ /// that implement it, and infer the bounds of loop induction variables
+ /// for ops that implement LoopLikeOPInterface.
+ ChangeResult visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor ®ion,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Given the results of getConstant{Lower,Upper}Bound()
+/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
+/// that result if possible.
+static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+ Type boundType,
+ detail::IntRangeAnalysisImpl &analysis,
+ bool getUpper) {
+ unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+ if (loopBound.hasValue()) {
+ if (loopBound->is<Attribute>()) {
+ if (auto bound =
+ loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+ return bound.getValue();
+ } else if (loopBound->is<Value>()) {
+ LatticeElement<IntRangeLattice> *lattice =
+ analysis.lookupLatticeElement(loopBound->get<Value>());
+ if (lattice != nullptr)
+ return getUpper ? lattice->getValue().value.smax()
+ : lattice->getValue().value.smin();
+ }
+ }
+ return getUpper ? APInt::getSignedMaxValue(width)
+ : APInt::getSignedMinValue(width);
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
+ Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+ ChangeResult result = ChangeResult::NoChange;
+ // Ignore non-integer outputs - return early if the op has no scalar
+ // integer results
+ bool hasIntegerResult = false;
+ for (Value v : op->getResults()) {
+ if (v.getType().isIntOrIndex())
+ hasIntegerResult = true;
+ else
+ result |= markAllPessimisticFixpoint(v);
+ }
+ if (!hasIntegerResult)
+ return result;
+
+ if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+ LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+ SmallVector<ConstantIntRanges> argRanges(
+ llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
+ return val->getValue().value;
+ }));
+
+ auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+ Optional<IntRangeLattice> oldRange;
+ if (!lattice.isUninitialized())
+ oldRange = lattice.getValue();
+ result |= lattice.join(IntRangeLattice(attrs));
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often (because
+ // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+ // and often can't).
+ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedResult && oldRange.hasValue() &&
+ !(lattice.getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ result |= lattice.markPessimisticFixpoint();
+ }
+ };
+
+ inferrable.inferResultRanges(argRanges, joinCallback);
+ for (Value opResult : op->getResults()) {
+ LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
+ // setResultRange() not called, make pessimistic.
+ if (lattice.isUninitialized())
+ result |= lattice.markPessimisticFixpoint();
+ }
+ } else if (op->getNumRegions() == 0) {
+ // No regions + no result inference method -> unbounded results (ex. memory
+ // ops)
+ result |= markAllPessimisticFixpoint(op->getResults());
+ }
+ return result;
+}
+
+LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+ BranchOpInterface branch,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+ SmallVectorImpl<Block *> &successors) {
+ auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
+ Optional<APInt> maybeConstValue =
+ enumPair.value()->getValue().value.getConstantValue();
+
+ if (maybeConstValue) {
+ return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+ *maybeConstValue);
+ }
+ return {};
+ };
+ SmallVector<Attribute> inferredConsts(
+ llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+ if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
+ successors.push_back(singleSucc);
+ return success();
+ }
+ return failure();
+}
+
+void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+ RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) {
+ auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
+ Optional<APInt> maybeConstValue =
+ enumPair.value()->getValue().value.getConstantValue();
+
+ if (maybeConstValue) {
+ return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+ *maybeConstValue);
+ }
+ return {};
+ };
+ SmallVector<Attribute> inferredConsts(
+ llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+ branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor ®ion,
+ ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+ if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+ LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+ SmallVector<ConstantIntRanges> argRanges(
+ llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
+ return val->getValue().value;
+ }));
+
+ ChangeResult result = ChangeResult::NoChange;
+ auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+ Optional<IntRangeLattice> oldRange;
+ if (!lattice.isUninitialized())
+ oldRange = lattice.getValue();
+ result |= lattice.join(IntRangeLattice(attrs));
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often (because
+ // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+ // and often can't).
+ bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedValue && oldRange.hasValue() &&
+ !(lattice.getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ result |= lattice.markPessimisticFixpoint();
+ }
+ };
+
+ inferrable.inferResultRanges(argRanges, joinCallback);
+ for (Value regionArg : region.getSuccessor()->getArguments()) {
+ LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
+ // setResultRange() not called, make pessimistic.
+ if (lattice.isUninitialized())
+ result |= lattice.markPessimisticFixpoint();
+ }
+
+ return result;
+ }
+
+ // Infer bounds for loop arguments that have static bounds
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+ Optional<Value> iv = loop.getSingleInductionVar();
+ if (!iv.hasValue()) {
+ return ForwardDataFlowAnalysis<
+ IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
+ }
+ Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+ Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+ Optional<OpFoldResult> step = loop.getSingleStep();
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
+ /*getUpper=*/false);
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
+ /*getUpper=*/true);
+ // Assume positivity for uniscoverable steps by way of getUpper = true.
+ APInt stepVal =
+ getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
+
+ if (stepVal.isNegative()) {
+ std::swap(min, max);
+ } else {
+ // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
+ // because loops do not generally include their upper bound.
+ max -= 1;
+ }
+
+ LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
+ return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
+ }
+ return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
+ op, region, operands);
+}
+
+IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
+ impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
+ topLevelOperation->getContext());
+ impl->run(topLevelOperation);
+}
+
+IntRangeAnalysis::~IntRangeAnalysis() = default;
+IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
+
+Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
+ LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
+ if (result == nullptr || result->isUninitialized())
+ return llvm::None;
+ return result->getValue().value;
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 1178c40207566..2082ad41a7f27 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
CopyOpInterface.cpp
DataLayoutInterfaces.cpp
DerivedAttributeOpInterface.cpp
+ InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
SideEffectInterfaces.cpp
@@ -35,6 +36,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
+add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
new file mode 100644
index 0000000000000..777ea18456551
--- /dev/null
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
+
+using namespace mlir;
+
+bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
+ return umin().getBitWidth() == other.umin().getBitWidth() &&
+ umin() == other.umin() && umax() == other.umax() &&
+ smin() == other.smin() && smax() == other.smax();
+}
+
+const APInt &ConstantIntRanges::umin() const { return uminVal; }
+
+const APInt &ConstantIntRanges::umax() const { return umaxVal; }
+
+const APInt &ConstantIntRanges::smin() const { return sminVal; }
+
+const APInt &ConstantIntRanges::smax() const { return smaxVal; }
+
+unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+ if (type.isIndex())
+ return IndexType::kInternalStorageBitWidth;
+ if (auto integerType = type.dyn_cast<IntegerType>())
+ return integerType.getWidth();
+ // Non-integer types have their bounds stored in width 0 `APInt`s.
+ return 0;
+}
+
+ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
+ return {min, max, min, max};
+}
+
+ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
+ const APInt &smax) {
+ unsigned int width = smin.getBitWidth();
+ APInt umin, umax;
+ if (smin.isNonNegative() == smax.isNonNegative()) {
+ umin = smin.ult(smax) ? smin : smax;
+ umax = smin.ugt(smax) ? smin : smax;
+ } else {
+ umin = APInt::getMinValue(width);
+ umax = APInt::getMaxValue(width);
+ }
+ return {umin, umax, smin, smax};
+}
+
+ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
+ const APInt &umax) {
+ unsigned int width = umin.getBitWidth();
+ APInt smin, smax;
+ if (umin.isNonNegative() == umax.isNonNegative()) {
+ smin = umin.slt(umax) ? umin : umax;
+ smax = umin.sgt(umax) ? umin : umax;
+ } else {
+ smin = APInt::getSignedMinValue(width);
+ smax = APInt::getSignedMaxValue(width);
+ }
+ return {umin, umax, smin, smax};
+}
+
+ConstantIntRanges
+ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
+ // "Not an integer" poisons everything and also cannot be fed to comparison
+ // operators.
+ if (umin().getBitWidth() == 0)
+ return *this;
+ if (other.umin().getBitWidth() == 0)
+ return other;
+
+ const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+ const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+ const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+ const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+
+ return {uminUnion, umaxUnion, sminUnion, smaxUnion};
+}
+
+Optional<APInt> ConstantIntRanges::getConstantValue() const {
+ // Note: we need to exclude the trivially-equal width 0 values here.
+ if (umin() == umax() && umin().getBitWidth() != 0)
+ return umin();
+ if (smin() == smax() && smin().getBitWidth() != 0)
+ return smin();
+ return None;
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
+ return os << "unsigned : [" << range.umin() << ", " << range.umax()
+ << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+}
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
new file mode 100644
index 0000000000000..45d506d00d65d
--- /dev/null
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+
+// CHECK-LABEL: func @constant
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
+// CHECK: return %[[cst]]
+func.func @constant() -> index {
+ %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+ smin = 3 : index, smax = 3 : index}
+ func.return %0 : index
+}
+
+// CHECK-LABEL: func @increment
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
+// CHECK: return %[[cst]]
+func.func @increment() -> index {
+ %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+ %1 = test.increment %0
+ func.return %1 : index
+}
+
+// CHECK-LABEL: func @maybe_increment
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment(%arg0 : i1) -> index {
+ %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+ smin = 3 : index, smax = 3 : index}
+ %1 = scf.if %arg0 -> index {
+ scf.yield %0 : index
+ } else {
+ %2 = test.increment %0
+ scf.yield %2 : index
+ }
+ %3 = test.reflect_bounds %1
+ func.return %3 : index
+}
+
+// CHECK-LABEL: func @maybe_increment_br
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment_br(%arg0 : i1) -> index {
+ %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+ smin = 3 : index, smax = 3 : index}
+ cf.cond_br %arg0, ^bb0, ^bb1
+^bb0:
+ %1 = test.increment %0
+ cf.br ^bb2(%1 : index)
+^bb1:
+ cf.br ^bb2(%0 : index)
+^bb2(%2 : index):
+ %3 = test.reflect_bounds %2
+ func.return %3 : index
+}
+
+// CHECK-LABEL: func @for_bounds
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
+func.func @for_bounds() -> index {
+ %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+ smin = 0 : index, smax = 0 : index}
+ %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+ smin = 1 : index, smax = 1 : index}
+ %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+ smin = 2 : index, smax = 2 : index}
+
+ %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+ scf.yield %arg0 : index
+ }
+ %1 = test.reflect_bounds %0
+ func.return %1 : index
+}
+
+// CHECK-LABEL: func @no_analysis_of_loop_variants
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @no_analysis_of_loop_variants() -> index {
+ %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+ smin = 0 : index, smax = 0 : index}
+ %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+ smin = 1 : index, smax = 1 : index}
+ %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+ smin = 2 : index, smax = 2 : index}
+
+ %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+ %1 = test.increment %arg2
+ scf.yield %1 : index
+ }
+ %2 = test.reflect_bounds %0
+ func.return %2 : index
+}
+
+// CHECK-LABEL: func @region_args
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @region_args() {
+ test.with_bounds_region { umin = 3 : index, umax = 4 : index,
+ smin = 3 : index, smax = 4 : index } %arg0 {
+ %0 = test.reflect_bounds %arg0
+ }
+ func.return
+}
+
+// CHECK-LABEL: func @func_args_unbound
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @func_args_unbound(%arg0 : index) -> index {
+ %0 = test.reflect_bounds %arg0
+ func.return %0 : index
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index de0e6d2f7ad07..de495bf8fe3db 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
MLIRFunc
MLIRFuncTransforms
MLIRIR
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRLinalg
MLIRLinalgTransforms
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 24e5bc6123603..415b545a6eae0 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -14,15 +14,21 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
@@ -1396,6 +1402,67 @@ LogicalResult TestVerifiersOp::verifyRegions() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
+}
+
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ // Parse the input argument
+ OpAsmParser::Argument argInfo;
+ argInfo.type = parser.getBuilder().getIndexType();
+ if (failed(parser.parseArgument(argInfo)))
+ return failure();
+
+ // Parse the body region, and reuse the operand info as the argument info.
+ Region *body = result.addRegion();
+ return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
+}
+
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+ p.printOptionalAttrDict((*this)->getAttrs());
+ p << ' ';
+ p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
+ /*omitType=*/true);
+ p << ' ';
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+void TestWithBoundsRegionOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ Value arg = getRegion().getArgument(0);
+ setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
+}
+
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &range = argRanges[0];
+ APInt one(range.umin().getBitWidth(), 1);
+ setResultRanges(getResult(),
+ {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
+ range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
+}
+
+void TestReflectBoundsOp::inferResultRanges(
+ ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+ const ConstantIntRanges &range = argRanges[0];
+ MLIRContext *ctx = getContext();
+ Builder b(ctx);
+ setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
+ setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
+ setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
+ setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
+ setResultRanges(getResult(), range);
+}
+
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 480585000a89f..a894524311030 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -33,6 +33,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0fd0566803df6..94556c4d59a79 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -789,7 +790,7 @@ def StringAttrPrettyNameOp
def CustomResultsNameOp
: TEST_Op<"custom_result_name",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
- let arguments = (ins
+ let arguments = (ins
Variadic<AnyInteger>:$optional,
StrArrayAttr:$names
);
@@ -2885,4 +2886,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
}];
}
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+def TestWithBoundsOp : TEST_Op<"with_bounds",
+ [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ NoSideEffect]> {
+ let arguments = (ins IndexAttr:$umin,
+ IndexAttr:$umax,
+ IndexAttr:$smin,
+ IndexAttr:$smax);
+ let results = (outs Index:$fakeVal);
+
+ let assemblyFormat = "attr-dict";
+}
+
+def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
+ [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ SingleBlock, NoTerminator]> {
+ let arguments = (ins IndexAttr:$umin,
+ IndexAttr:$umax,
+ IndexAttr:$smin,
+ IndexAttr:$smax);
+ // The region has one argument of index type
+ let regions = (region SizedRegion<1>:$region);
+ let hasCustomAssemblyFormat = 1;
+}
+
+def TestIncrementOp : TEST_Op<"increment",
+ [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+ NoSideEffect]> {
+ let arguments = (ins Index:$value);
+ let results = (outs Index:$result);
+
+ let assemblyFormat = "attr-dict $value";
+}
+
+def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
+ [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+ let arguments = (ins Index:$value,
+ OptionalAttr<IndexAttr>:$umin,
+ OptionalAttr<IndexAttr>:$umax,
+ OptionalAttr<IndexAttr>:$smin,
+ OptionalAttr<IndexAttr>:$smax);
+ let results = (outs Index:$result);
+
+ let assemblyFormat = "attr-dict $value";
+}
#endif // TEST_OPS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 95c30c34f4950..00856562d8709 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
+ TestIntRangeInference.cpp
EXCLUDE_FROM_LIBMLIR
@@ -10,6 +11,8 @@ add_mlir_library(MLIRTestTransforms
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
LINK_LIBS PUBLIC
+ MLIRAnalysis
+ MLIRInferIntRangeInterface
MLIRTestDialect
MLIRTransforms
)
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
new file mode 100644
index 0000000000000..1bd2a24d3ce6c
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -0,0 +1,115 @@
+//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// TODO: This pass is needed to test integer range inference until that
+// functionality has been integrated into SCCP.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+
+/// Patterned after SCCP
+static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
+ OpBuilder &b, OperationFolder &folder,
+ Value value) {
+ Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
+ if (!maybeInferredRange)
+ return failure();
+ const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
+ Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+ if (!maybeConstValue.hasValue())
+ return failure();
+
+ Operation *maybeDefiningOp = value.getDefiningOp();
+ Dialect *valueDialect =
+ maybeDefiningOp ? maybeDefiningOp->getDialect()
+ : value.getParentRegion()->getParentOp()->getDialect();
+ Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
+ Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
+ value.getType(), value.getLoc());
+ if (!constant)
+ return failure();
+
+ value.replaceAllUsesWith(constant);
+ return success();
+}
+
+static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+ MutableArrayRef<Region> initialRegions) {
+ SmallVector<Block *> worklist;
+ auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+ for (Region ®ion : regions)
+ for (Block &block : llvm::reverse(region))
+ worklist.push_back(&block);
+ };
+
+ OpBuilder builder(context);
+ OperationFolder folder(context);
+
+ addToWorklist(initialRegions);
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+
+ for (Operation &op : llvm::make_early_inc_range(*block)) {
+ builder.setInsertionPoint(&op);
+
+ // Replace any result with constants.
+ bool replacedAll = op.getNumResults() != 0;
+ for (Value res : op.getResults())
+ replacedAll &=
+ succeeded(replaceWithConstant(analysis, builder, folder, res));
+
+ // If all of the results of the operation were replaced, try to erase
+ // the operation completely.
+ if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+ assert(op.use_empty() && "expected all uses to be replaced");
+ op.erase();
+ continue;
+ }
+
+ // Add any the regions of this operation to the worklist.
+ addToWorklist(op.getRegions());
+ }
+
+ // Replace any block arguments with constants.
+ builder.setInsertionPointToStart(block);
+ for (BlockArgument arg : block->getArguments())
+ (void)replaceWithConstant(analysis, builder, folder, arg);
+ }
+}
+
+namespace {
+struct TestIntRangeInference
+ : PassWrapper<TestIntRangeInference, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
+
+ StringRef getArgument() const final { return "test-int-range-inference"; }
+ StringRef getDescription() const final {
+ return "Test integer range inference analysis";
+ }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ IntRangeAnalysis analysis(op);
+ rewrite(analysis, op->getContext(), op->getRegions());
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestIntRangeInference() {
+ PassRegistration<TestIntRangeInference>();
+}
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e75c2758e88db..aa94294b4ea8d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -79,6 +79,7 @@ void registerTestDynamicPipelinePass();
void registerTestExpandMathPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
+void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
@@ -175,6 +176,7 @@ void registerTestPasses() {
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
+ mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index 54a6837b0ed88..c4201b135e917 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_unittest(MLIRInterfacesTests
ControlFlowInterfacesTest.cpp
DataLayoutInterfacesTest.cpp
+ InferIntRangeInterfaceTest.cpp
InferTypeOpInterfaceTest.cpp
)
@@ -10,6 +11,7 @@ target_link_libraries(MLIRInterfacesTests
MLIRDataLayoutInterfaces
MLIRDLTI
MLIRFunc
+ MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRParser
)
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
new file mode 100644
index 0000000000000..97c75b3680567
--- /dev/null
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/APInt.h"
+#include <limits>
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+TEST(IntRangeAttrs, BasicConstructors) {
+ APInt zero = APInt::getZero(64);
+ APInt two(64, 2);
+ APInt three(64, 3);
+ ConstantIntRanges boundedAbove(zero, two, zero, three);
+ EXPECT_EQ(boundedAbove.umin(), zero);
+ EXPECT_EQ(boundedAbove.umax(), two);
+ EXPECT_EQ(boundedAbove.smin(), zero);
+ EXPECT_EQ(boundedAbove.smax(), three);
+}
+
+TEST(IntRangeAttrs, FromUnsigned) {
+ APInt zero = APInt::getZero(64);
+ APInt maxInt = APInt::getSignedMaxValue(64);
+ APInt minInt = APInt::getSignedMinValue(64);
+ APInt minIntPlusOne = minInt + 1;
+
+ ConstantIntRanges canPortToSigned =
+ ConstantIntRanges::fromUnsigned(zero, maxInt);
+ EXPECT_EQ(canPortToSigned.smin(), zero);
+ EXPECT_EQ(canPortToSigned.smax(), maxInt);
+
+ ConstantIntRanges cantPortToSigned =
+ ConstantIntRanges::fromUnsigned(zero, minInt);
+ EXPECT_EQ(cantPortToSigned.smin(), minInt);
+ EXPECT_EQ(cantPortToSigned.smax(), maxInt);
+
+ ConstantIntRanges signedNegative =
+ ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne);
+ EXPECT_EQ(signedNegative.smin(), minInt);
+ EXPECT_EQ(signedNegative.smax(), minIntPlusOne);
+}
+
+TEST(IntRangeAttrs, FromSigned) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+ APInt negOne = zero - 1;
+ APInt intMax = APInt::getSignedMaxValue(64);
+ APInt intMin = APInt::getSignedMinValue(64);
+ APInt uintMax = APInt::getMaxValue(64);
+
+ ConstantIntRanges noUnsignedBound =
+ ConstantIntRanges::fromSigned(negOne, one);
+ EXPECT_EQ(noUnsignedBound.umin(), zero);
+ EXPECT_EQ(noUnsignedBound.umax(), uintMax);
+
+ ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax);
+ EXPECT_EQ(positive.umin(), one);
+ EXPECT_EQ(positive.umax(), intMax);
+
+ ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne);
+ EXPECT_EQ(negative.umin(), intMin);
+ EXPECT_EQ(negative.umax(), negOne);
+
+ ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one);
+ EXPECT_EQ(preserved.umin(), zero);
+ EXPECT_EQ(preserved.umax(), one);
+}
+
+TEST(IntRangeAttrs, Join) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+ APInt two = zero + 2;
+ APInt intMin = APInt::getSignedMinValue(64);
+ APInt intMax = APInt::getSignedMaxValue(64);
+ APInt uintMax = APInt::getMaxValue(64);
+
+ ConstantIntRanges maximal(zero, uintMax, intMin, intMax);
+ ConstantIntRanges zeroOne(zero, one, zero, one);
+
+ EXPECT_EQ(zeroOne.rangeUnion(maximal), maximal);
+ EXPECT_EQ(maximal.rangeUnion(zeroOne), maximal);
+
+ EXPECT_EQ(zeroOne.rangeUnion(zeroOne), zeroOne);
+
+ ConstantIntRanges oneTwo(one, two, one, two);
+ ConstantIntRanges zeroTwo(zero, two, zero, two);
+ EXPECT_EQ(zeroOne.rangeUnion(oneTwo), zeroTwo);
+
+ ConstantIntRanges zeroOneUnsignedOnly(zero, one, intMin, intMax);
+ ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
+ EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
+}
More information about the Mlir-commits
mailing list