[Mlir-commits] [mlir] 1350c98 - [mlir] Add integer range inference analysis

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Jun 2 13:24:17 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-06-02T20:24:11Z
New Revision: 1350c9887dca5ba80af8e3c1e61b29d6696eb240

URL: https://github.com/llvm/llvm-project/commit/1350c9887dca5ba80af8e3c1e61b29d6696eb240
DIFF: https://github.com/llvm/llvm-project/commit/1350c9887dca5ba80af8e3c1e61b29d6696eb240.diff

LOG: [mlir] Add integer range inference analysis

This commit defines a dataflow analysis for integer ranges, which
uses a newly-added InferIntRangeInterface to compute the lower and
upper bounds on the results of an operation from the bounds on the
arguments. The range inference is a flow-insensitive dataflow analysis
that can be used to simplify code, such as by statically identifying
bounds checks that cannot fail in order to eliminate them.

The InferIntRangeInterface has one method, inferResultRanges(), which
takes a vector of inferred ranges for each argument to an op
implementing the interface and a callback allowing the implementation
to define the ranges for each result. These ranges are stored as
ConstantIntRanges, which hold the lower and upper bounds for a
value. Bounds are tracked separately for the signed and unsigned
interpretations of a value, which ensures that the impact of
arithmetic overflows is correctly tracked during the analysis.

The commit also adds a -test-int-range-inference pass to test the
analysis until it is integrated into SCCP or otherwise exposed.

Finally, this commit fixes some bugs relating to the handling of
region iteration arguments and terminators in the data flow analysis
framework.

Depends on D124020

Depends on D124021

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D124023

Added: 
    mlir/include/mlir/Analysis/IntRangeAnalysis.h
    mlir/include/mlir/Interfaces/InferIntRangeInterface.h
    mlir/include/mlir/Interfaces/InferIntRangeInterface.td
    mlir/lib/Analysis/IntRangeAnalysis.cpp
    mlir/lib/Interfaces/InferIntRangeInterface.cpp
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
    mlir/test/lib/Transforms/TestIntRangeInference.cpp
    mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp

Modified: 
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    mlir/unittests/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
new file mode 100644
index 0000000000000..b2b604359b48b
--- /dev/null
+++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,41 @@
+//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace detail {
+/// Opaque implementation type, defined in IntRangeAnalysis.cpp so that this
+/// header does not need to expose the dataflow framework.
+class IntRangeAnalysisImpl;
+} // namespace detail
+
+/// Runs integer range inference over a piece of IR and exposes the range
+/// computed for each value.
+class IntRangeAnalysis {
+public:
+  /// Runs the analysis over every operation nested under
+  /// `topLevelOperation`; `topLevelOperation` itself is not analyzed.
+  IntRangeAnalysis(Operation *topLevelOperation);
+
+  /// Transfers ownership of the analysis state; copying is not supported.
+  IntRangeAnalysis(IntRangeAnalysis &&other);
+
+  /// Defined out of line, where `IntRangeAnalysisImpl` is a complete type.
+  ~IntRangeAnalysis();
+
+  /// Returns the range inferred for `v`, or `None` when the analysis holds no
+  /// range for it.
+  Optional<ConstantIntRanges> getResult(Value v);
+
+private:
+  /// Owns the dataflow state computed during construction.
+  std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
+};
+} // namespace mlir
+
+#endif

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index cf075f728fdbf..918f3ea398e47 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
+add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)

diff  --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
new file mode 100644
index 0000000000000..9a393855d05ff
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -0,0 +1,98 @@
+//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer range inference interface
+// defined in `InferIntRangeInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+/// A set of arbitrary-precision integers representing bounds on a given integer
+/// value. These bounds are inclusive on both ends, so
+/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
+/// the unsigned and signed interpretations of values in order to enable more
+/// precise inference of the interplay between operations with signed and
+/// unsigned semantics.
+class ConstantIntRanges {
+public:
+  /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
+  /// Non-integer values should be bounded by APInts of bitwidth 0.
+  /// All four bounds must share one bitwidth (checked by an assertion).
+  ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
+                    const APInt &smax)
+      : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+    assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
+           umaxVal.getBitWidth() == sminVal.getBitWidth() &&
+           sminVal.getBitWidth() == smaxVal.getBitWidth() &&
+           "All bounds in the ranges must have the same bitwidth");
+  }
+
+  bool operator==(const ConstantIntRanges &other) const;
+
+  /// Inequality, defined via `operator==` so the two can never disagree.
+  bool operator!=(const ConstantIntRanges &other) const {
+    return !(*this == other);
+  }
+
+  /// The minimum value of an integer when it is interpreted as unsigned.
+  const APInt &umin() const;
+
+  /// The maximum value of an integer when it is interpreted as unsigned.
+  const APInt &umax() const;
+
+  /// The minimum value of an integer when it is interpreted as signed.
+  const APInt &smin() const;
+
+  /// The maximum value of an integer when it is interpreted as signed.
+  const APInt &smax() const;
+
+  /// Return the bitwidth that should be used for integer ranges describing
+  /// `type`. For concrete integer types, this is their bitwidth, for `index`,
+  /// this is the internal storage bitwidth of `index` attributes, and for
+  /// non-integer types this is 0.
+  static unsigned getStorageBitwidth(Type type);
+
+  /// Create a `ConstantIntRanges` where `min` is both the signed and unsigned
+  /// minimum and `max` is both the signed and unsigned maximum.
+  static ConstantIntRanges range(const APInt &min, const APInt &max);
+
+  /// Create a `ConstantIntRanges` with the signed minimum and maximum equal
+  /// to `smin` and `smax`, where the unsigned bounds are constructed from the
+  /// signed ones if they correspond to a contiguous range of bit patterns when
+  /// viewed as unsigned values and are left at [0, uint_max()] otherwise.
+  static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
+
+  /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal
+  /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
+  /// unless the sign bit changes between the minimum and maximum.
+  static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
+
+  /// Returns the union (computed separately for signed and unsigned bounds)
+  /// of `*this` and `other`.
+  ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
+
+  /// If either the signed or unsigned interpretations of the range
+  /// indicate that the value it bounds is a constant, return that constant
+  /// value.
+  Optional<APInt> getConstantValue() const;
+
+  friend raw_ostream &operator<<(raw_ostream &os,
+                                 const ConstantIntRanges &range);
+
+private:
+  APInt uminVal, umaxVal, sminVal, smaxVal;
+};
+
+/// The type of the `setResultRanges` callback provided to ops implementing
+/// InferIntRangeInterface. It should be called once for each integer result
+/// value and be passed the ConstantIntRanges corresponding to that value.
+using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H

diff  --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
new file mode 100644
index 0000000000000..57f8d693b7916
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -0,0 +1,52 @@
+//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-----------------------------------------------------===//
+//
+// Defines the interface for range analysis on scalar integers
+//
+//===-----------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
+  let description = [{
+    Allows operations to participate in range analysis for scalar integer values by
+    providing a method that allows them to specify lower and upper bounds on their
+    result(s) given lower and upper bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the bounds on its arguments.
+      For each result value or block argument (that isn't a branch argument,
+      since the dataflow analysis handles those cases), the method should call
+      `setResultRanges` with that `Value` as an argument. When `setResultRanges`
+      is not called for some value, it will receive a default value of the minimum
+      and maximum values for its type (the unbounded range).
+
+      When called on an op that also implements the RegionBranchOpInterface
+      or BranchOpInterface, this method should not attempt to infer the values
+      of the branch results, as this will be handled by the analyses that use
+      this interface.
+
+      This function will only be called when at least one result of the op is a
+      scalar integer value or the op has a region.
+
+      `argRanges` contains one `ConstantIntRanges` for each argument to the op in
+       ODS order. Non-integer arguments will have an unbounded range of width-0
+       APInts in their `argRanges` element.
+    }],
+    "void", "inferResultRanges", (ins
+      "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+      "::mlir::SetIntRangeFn":$setResultRanges)
+  >];
+}
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 29314ed535931..bc5017d065aec 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -16,6 +17,7 @@ add_mlir_library(MLIRAnalysis
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -31,6 +33,7 @@ add_mlir_library(MLIRAnalysis
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index 9c10595dbb00a..239d9e4060bca 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -359,11 +359,20 @@ void ForwardDataFlowSolver::visitOperation(Operation *op) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
       return visitRegionBranchOperation(branch, operandLattices);
 
-    // If we can't, conservatively mark all regions as executable.
-    // TODO: Let the `visitOperation` method decide how to propagate
-    // information to the block arguments.
-    for (Region &region : op->getRegions())
-      markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+    for (Region &region : op->getRegions()) {
+      analysis.visitNonControlFlowArguments(op, RegionSuccessor(&region),
+                                            operandLattices);
+      // `visitNonControlFlowArguments` is required to define all of the region
+      // argument lattices.
+      assert(llvm::none_of(
+                 region.getArguments(),
+                 [&](Value value) {
+                   return analysis.getLatticeElement(value).isUninitialized();
+                 }) &&
+             "expected `visitNonControlFlowArguments` to define all argument "
+             "lattices");
+      markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/false);
+    }
   }
 
   // If this op produces no results, it can't produce any constants.
@@ -567,12 +576,45 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
       return;
 
+    // If the branch is a RegionBranchTerminatorOpInterface,
+    // construct the set of operand lattices as the set of non control-flow
+    // arguments of the parent and the values this op returns. This allows
+    // for the correct lattices to be passed to getSuccessorsForOperands()
+    // in cases such as scf.while.
+    ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
+    SmallVector<AbstractLatticeElement *, 0> parentLattices;
+    if (auto regionTerminator =
+            dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+      parentLattices.reserve(regionInterface->getNumOperands());
+      for (Value parentOperand : regionInterface->getOperands()) {
+        AbstractLatticeElement *operandLattice =
+            analysis.lookupLatticeElement(parentOperand);
+        if (!operandLattice || operandLattice->isUninitialized())
+          return;
+        parentLattices.push_back(operandLattice);
+      }
+      unsigned regionNumber = parentRegion->getRegionNumber();
+      OperandRange iterArgs =
+          regionInterface.getSuccessorEntryOperands(regionNumber);
+      OperandRange terminatorArgs =
+          regionTerminator.getSuccessorOperands(regionNumber);
+      assert(iterArgs.size() == terminatorArgs.size() &&
+             "Number of iteration arguments for region should equal number of "
+             "those arguments defined by terminator");
+      if (!iterArgs.empty()) {
+        unsigned iterStart = iterArgs.getBeginOperandIndex();
+        unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
+        for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
+          parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
+      }
+      branchOpLattices = parentLattices;
+    }
     // Query the set of successors of the current region using the current
     // optimistic lattice state.
     SmallVector<RegionSuccessor, 1> regionSuccessors;
     analysis.getSuccessorsForOperands(regionInterface,
                                       parentRegion->getRegionNumber(),
-                                      operandLattices, regionSuccessors);
+                                      branchOpLattices, regionSuccessors);
     if (regionSuccessors.empty())
       return;
 
@@ -584,7 +626,7 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
         // region index (if any).
         return *getRegionBranchSuccessorOperands(op, regionIndex);
       };
-      return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
+      return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
                                    getOperands);
     }
 

diff  --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
new file mode 100644
index 0000000000000..7e6d61ff89560
--- /dev/null
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,325 @@
+//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+
+namespace {
+/// A wrapper around ConstantIntRanges that provides the lattice functions
+/// expected by dataflow analysis.
+struct IntRangeLattice {
+  IntRangeLattice(const ConstantIntRanges &value) : value(value) {}
+  // `value` is a named rvalue reference and therefore an lvalue; without
+  // std::move this overload would copy the four APInts instead of moving them.
+  IntRangeLattice(ConstantIntRanges &&value) : value(std::move(value)) {}
+
+  bool operator==(const IntRangeLattice &other) const {
+    return value == other.value;
+  }
+
+  /// Lattice join: wrapper around rangeUnion().
+  static IntRangeLattice join(const IntRangeLattice &a,
+                              const IntRangeLattice &b) {
+    return a.value.rangeUnion(b.value);
+  }
+
+  /// Creates a range with bitwidth 0 to represent that we don't know if the
+  /// value being marked overdefined is even an integer.
+  static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
+    APInt noIntValue = APInt::getZeroWidth();
+    return ConstantIntRanges::range(noIntValue, noIntValue);
+  }
+
+  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+  /// that is used to mark the value v as unable to be analyzed further,
+  /// where t is the type of v.
+  static IntRangeLattice getPessimisticValueState(Value v) {
+    unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
+    APInt umin = APInt::getMinValue(width);
+    APInt umax = APInt::getMaxValue(width);
+    // Width-0 (non-integer) values have no sign bit, so they reuse the
+    // unsigned bounds for the signed interpretation.
+    APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+    APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+    return ConstantIntRanges{umin, umax, smin, smax};
+  }
+
+  ConstantIntRanges value;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+/// Forward dataflow analysis that associates an `IntRangeLattice` (a wrapped
+/// `ConstantIntRanges`) with the values it visits.
+class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
+public:
+  // Inherit the base constructors. Their accessibility follows the base
+  // class, not the access section this declaration appears in.
+  using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
+
+  /// Computes ranges for the results (or block arguments) of `op` from the
+  /// operand ranges in `operands`.
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+
+  /// Computes ranges for region arguments not covered by control flow. This
+  /// uses InferIntRangeInterface when `op` implements it, and otherwise the
+  /// bounds of a LoopLikeOpInterface induction variable.
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+
+  /// Lets `branch` pick its successor when some operands are known
+  /// constants, so unreachable blocks can be skipped.
+  LogicalResult
+  getSuccessorsForOperands(BranchOpInterface branch,
+                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+                           SmallVectorImpl<Block *> &successors) final;
+
+  /// Region-control-flow counterpart of the overload above: lets `branch`
+  /// compute its successor regions from operands known to be constant.
+  void
+  getSuccessorsForOperands(RegionBranchOpInterface branch,
+                           Optional<unsigned> sourceIndex,
+                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+                           SmallVectorImpl<RegionSuccessor> &successors) final;
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Given the result of getSingle{LowerBound,UpperBound,Step}() on a
+/// LoopLikeOpInterface, return a signed bound for that quantity:
+///  - a constant IntegerAttr yields its value;
+///  - a `Value` whose lattice holds a range yields that range's signed max
+///    (when `getUpper`) or signed min (otherwise);
+///  - anything else yields the signed max/min of `boundType`, i.e. no
+///    information.
+static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+                                  Type boundType,
+                                  detail::IntRangeAnalysisImpl &analysis,
+                                  bool getUpper) {
+  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+  if (loopBound.hasValue()) {
+    if (loopBound->is<Attribute>()) {
+      if (auto bound =
+              loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+        return bound.getValue();
+    } else if (loopBound->is<Value>()) {
+      LatticeElement<IntRangeLattice> *lattice =
+          analysis.lookupLatticeElement(loopBound->get<Value>());
+      // A lattice can exist without holding a value yet. Like the other
+      // readers in this file, only call getValue() on an initialized lattice;
+      // otherwise fall through to the conservative default.
+      if (lattice != nullptr && !lattice->isUninitialized())
+        return getUpper ? lattice->getValue().value.smax()
+                        : lattice->getValue().value.smin();
+    }
+  }
+  return getUpper ? APInt::getSignedMaxValue(width)
+                  : APInt::getSignedMinValue(width);
+}
+
+/// Transfer function for ordinary operations. Non-integer results are pinned
+/// to the pessimistic state; integer results are taken from
+/// InferIntRangeInterface when available. Region-free ops without the
+/// interface get unbounded integer results.
+ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+  ChangeResult changed = ChangeResult::NoChange;
+
+  // Results that are not scalar integers are outside this analysis.
+  bool anyIntResult = false;
+  for (Value res : op->getResults()) {
+    if (!res.getType().isIntOrIndex()) {
+      changed |= markAllPessimisticFixpoint(res);
+      continue;
+    }
+    anyIntResult = true;
+  }
+  if (!anyIntResult)
+    return changed;
+
+  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+  if (!inferrable) {
+    // Nothing is known about region-free ops without the interface (for
+    // example memory operations), so their results are unbounded.
+    if (op->getNumRegions() == 0)
+      changed |= markAllPessimisticFixpoint(op->getResults());
+    return changed;
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Inferring ranges for ";
+    inferrable->print(llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  SmallVector<ConstantIntRanges> argRanges;
+  argRanges.reserve(operands.size());
+  for (LatticeElement<IntRangeLattice> *operand : operands)
+    argRanges.push_back(operand->getValue().value);
+
+  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+    LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+    Optional<IntRangeLattice> previous;
+    if (!lattice.isUninitialized())
+      previous = lattice.getValue();
+    changed |= lattice.join(IntRangeLattice(attrs));
+
+    // The framework does not reason about trip counts. A value that feeds a
+    // terminator and whose range keeps changing is therefore widened to the
+    // unbounded range instead of being iterated on indefinitely.
+    if (!previous.hasValue() || lattice.getValue() == *previous)
+      return;
+    bool feedsTerminator = llvm::any_of(v.getUsers(), [](Operation *user) {
+      return user->hasTrait<OpTrait::IsTerminator>();
+    });
+    if (feedsTerminator) {
+      LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+      changed |= lattice.markPessimisticFixpoint();
+    }
+  };
+
+  inferrable.inferResultRanges(argRanges, joinCallback);
+
+  // Results the interface never assigned a range to become unbounded.
+  for (Value res : op->getResults()) {
+    LatticeElement<IntRangeLattice> &lattice = getLatticeElement(res);
+    if (lattice.isUninitialized())
+      changed |= lattice.markPessimisticFixpoint();
+  }
+  return changed;
+}
+
+/// Asks `branch` to select a single successor block from the operands whose
+/// ranges collapse to a constant. Fails when it cannot decide.
+LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    BranchOpInterface branch,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+    SmallVectorImpl<Block *> &successors) {
+  // Constant operands become IntegerAttrs of the operand's type; the rest are
+  // passed as null attributes.
+  SmallVector<Attribute> inferredConsts;
+  inferredConsts.reserve(operands.size());
+  for (unsigned idx = 0, e = operands.size(); idx < e; ++idx) {
+    Optional<APInt> constValue =
+        operands[idx]->getValue().value.getConstantValue();
+    if (constValue)
+      inferredConsts.push_back(
+          IntegerAttr::get(branch->getOperand(idx).getType(), *constValue));
+    else
+      inferredConsts.push_back(Attribute());
+  }
+
+  Block *singleSucc = branch.getSuccessorForOperands(inferredConsts);
+  if (!singleSucc)
+    return failure();
+  successors.push_back(singleSucc);
+  return success();
+}
+
+/// Lets `branch` compute its successor regions from `sourceIndex`. Operands
+/// whose ranges collapse to a single value are presented as constants.
+void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  // Each operand maps to an IntegerAttr of its type when constant and to a
+  // null attribute otherwise.
+  SmallVector<Attribute> inferredConsts;
+  inferredConsts.reserve(operands.size());
+  unsigned operandIdx = 0;
+  for (LatticeElement<IntRangeLattice> *operand : operands) {
+    Attribute known;
+    if (Optional<APInt> constValue =
+            operand->getValue().value.getConstantValue())
+      known = IntegerAttr::get(branch->getOperand(operandIdx).getType(),
+                               *constValue);
+    inferredConsts.push_back(known);
+    ++operandIdx;
+  }
+  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+/// Assigns ranges to region arguments that control flow does not define.
+/// Ops implementing InferIntRangeInterface define them through the interface;
+/// any argument it leaves unset becomes unbounded. For LoopLikeOpInterface
+/// ops with a single induction variable, the variable's range is derived from
+/// the loop bounds and step. Everything else is deferred to the base class.
+ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &region,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<ConstantIntRanges> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
+          return val->getValue().value;
+        }));
+
+    ChangeResult result = ChangeResult::NoChange;
+    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+      Optional<IntRangeLattice> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      result |= lattice.join(IntRangeLattice(attrs));
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedValue && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        result |= lattice.markPessimisticFixpoint();
+      }
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+
+    // Arguments of the successor region that setResultRanges() did not define
+    // become unbounded. A RegionSuccessor may carry no region (presumably
+    // when it denotes the parent op). In that case there are no block
+    // arguments to fill in, and the null region must not be dereferenced.
+    if (Region *successorRegion = region.getSuccessor()) {
+      for (Value regionArg : successorRegion->getArguments()) {
+        LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
+        if (lattice.isUninitialized())
+          result |= lattice.markPessimisticFixpoint();
+      }
+    }
+
+    return result;
+  }
+
+  // Infer bounds for the induction variable of loops from their bounds/step.
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!iv.hasValue()) {
+      return ForwardDataFlowAnalysis<
+          IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
+    }
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    Optional<OpFoldResult> step = loop.getSingleStep();
+    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
+                                     /*getUpper=*/false);
+    APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
+                                     /*getUpper=*/true);
+    // Assume positivity for undiscoverable steps by way of getUpper = true.
+    APInt stepVal =
+        getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
+
+    if (stepVal.isNegative()) {
+      // With a negative step, the upper bound is used as the minimum and the
+      // lower bound as the maximum.
+      std::swap(min, max);
+    } else {
+      // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
+      // because loops do not generally include their upper bound.
+      max -= 1;
+    }
+
+    LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
+    return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
+  }
+  return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
+      op, region, operands);
+}
+
+/// Builds the implementation in the member initializer and runs the analysis
+/// over the IR nested under `topLevelOperation` right away.
+IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation)
+    : impl(std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
+          topLevelOperation->getContext())) {
+  impl->run(topLevelOperation);
+}
+
+// Defaulted here, where IntRangeAnalysisImpl is a complete type, so that the
+// std::unique_ptr member can be moved and destroyed.
+IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
+IntRangeAnalysis::~IntRangeAnalysis() = default;
+
+/// Returns the range recorded for `v`. Values the analysis never created a
+/// lattice for, or never assigned a range to, yield `None`.
+Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
+  if (LatticeElement<IntRangeLattice> *lattice =
+          impl->lookupLatticeElement(v)) {
+    if (!lattice->isUninitialized())
+      return lattice->getValue().value;
+  }
+  return llvm::None;
+}

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 1178c40207566..2082ad41a7f27 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
   CopyOpInterface.cpp
   DataLayoutInterfaces.cpp
   DerivedAttributeOpInterface.cpp
+  InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
@@ -35,6 +36,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
+add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)

diff  --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
new file mode 100644
index 0000000000000..777ea18456551
--- /dev/null
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
+
+using namespace mlir;
+
+/// Two ranges are equal when they share a bit width and all four bounds
+/// match. The width is checked first because `APInt` equality requires both
+/// operands to have the same width.
+bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
+  if (umin().getBitWidth() != other.umin().getBitWidth())
+    return false;
+  if (umin() != other.umin() || umax() != other.umax())
+    return false;
+  return smin() == other.smin() && smax() == other.smax();
+}
+
+/// Inclusive lower bound when values are interpreted as unsigned.
+const APInt &ConstantIntRanges::umin() const { return uminVal; }
+
+/// Inclusive upper bound when values are interpreted as unsigned.
+const APInt &ConstantIntRanges::umax() const { return umaxVal; }
+
+/// Inclusive lower bound when values are interpreted as signed.
+const APInt &ConstantIntRanges::smin() const { return sminVal; }
+
+/// Inclusive upper bound when values are interpreted as signed.
+const APInt &ConstantIntRanges::smax() const { return smaxVal; }
+
+/// Returns the `APInt` width used to hold bounds for values of `type`. This
+/// is the internal storage width for `index`, the declared width for integer
+/// types, and 0 for every other type (meaning "not an integer").
+unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  unsigned width = 0;
+  if (type.isIndex())
+    width = IndexType::kInternalStorageBitWidth;
+  else if (auto intTy = type.dyn_cast<IntegerType>())
+    width = intTy.getWidth();
+  return width;
+}
+
+/// Builds a range whose unsigned and signed bounds are both `[min, max]`.
+/// No check is made that `[min, max]` is valid under both interpretations.
+ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
+  return ConstantIntRanges(/*umin=*/min, /*umax=*/max, /*smin=*/min,
+                           /*smax=*/max);
+}
+
+/// Builds a range from the signed bounds `[smin, smax]` and derives the
+/// unsigned bounds.
+///
+/// If both endpoints have the same sign, the endpoints are reused as unsigned
+/// bounds, ordered as unsigned values. If one is negative and the other is
+/// non-negative, the values wrap in the unsigned domain, so the full unsigned
+/// range is used.
+ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
+                                                const APInt &smax) {
+  bool endpointsShareSign = smin.isNonNegative() == smax.isNonNegative();
+  if (!endpointsShareSign) {
+    unsigned int width = smin.getBitWidth();
+    return {APInt::getMinValue(width), APInt::getMaxValue(width), smin, smax};
+  }
+  return {llvm::APIntOps::umin(smin, smax), llvm::APIntOps::umax(smin, smax),
+          smin, smax};
+}
+
+/// Builds a range from the unsigned bounds `[umin, umax]` and derives the
+/// signed bounds.
+///
+/// If both endpoints have the same top bit, the endpoints are reused as signed
+/// bounds, ordered as signed values. If the range crosses the sign-bit
+/// boundary, the values wrap in the signed domain, so the full signed range
+/// is used.
+ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
+                                                  const APInt &umax) {
+  bool endpointsShareSign = umin.isNonNegative() == umax.isNonNegative();
+  if (!endpointsShareSign) {
+    unsigned int width = umin.getBitWidth();
+    return {umin, umax, APInt::getSignedMinValue(width),
+            APInt::getSignedMaxValue(width)};
+  }
+  return {umin, umax, llvm::APIntOps::smin(umin, umax),
+          llvm::APIntOps::smax(umin, umax)};
+}
+
+/// Returns the smallest range containing both `*this` and `other`. The
+/// unsigned and signed bounds are widened independently of each other.
+ConstantIntRanges
+ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
+  // A width-0 ("not an integer") range absorbs the other operand. Such bounds
+  // must never reach the APInt comparisons below, so test for them first.
+  if (umin().getBitWidth() == 0)
+    return *this;
+  if (other.umin().getBitWidth() == 0)
+    return other;
+
+  return {llvm::APIntOps::umin(umin(), other.umin()),
+          llvm::APIntOps::umax(umax(), other.umax()),
+          llvm::APIntOps::smin(smin(), other.smin()),
+          llvm::APIntOps::smax(smax(), other.smax())};
+}
+
+/// Returns the single value admitted by this range, if any. The unsigned
+/// bounds are checked first, then the signed bounds. Width-0 ("not an
+/// integer") bounds compare equal trivially, so they are never reported as a
+/// constant.
+Optional<APInt> ConstantIntRanges::getConstantValue() const {
+  bool unsignedIsPoint = umin() == umax() && umin().getBitWidth() != 0;
+  if (unsignedIsPoint)
+    return umin();
+  bool signedIsPoint = smin() == smax() && smin().getBitWidth() != 0;
+  if (signedIsPoint)
+    return smin();
+  return None;
+}
+
+/// Prints `range` as `unsigned : [umin, umax] signed : [smin, smax]`.
+raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
+  os << "unsigned : [" << range.umin() << ", " << range.umax();
+  os << "] signed : [" << range.smin() << ", " << range.smax();
+  return os << "]";
+}

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
new file mode 100644
index 0000000000000..45d506d00d65d
--- /dev/null
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+
+// All four bounds of the input equal 3, so the value is a known constant and
+// the pass replaces it with a `test.constant`.
+// CHECK-LABEL: func @constant
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
+// CHECK: return %[[cst]]
+func.func @constant() -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  func.return %0 : index
+}
+
+// Only the unsigned bounds pin the input to 3. That is enough for a constant,
+// and after `test.increment` the result folds to the constant 4.
+// CHECK-LABEL: func @increment
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
+// CHECK: return %[[cst]]
+func.func @increment() -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = test.increment %0
+  func.return %1 : index
+}
+
+// The `scf.if` result joins the branch values 3 and 4, so the reflected range
+// is [3, 4] under both interpretations.
+// CHECK-LABEL: func @maybe_increment
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment(%arg0 : i1) -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  %1 = scf.if %arg0 -> index {
+    scf.yield %0 : index
+  } else {
+    %2 = test.increment %0
+    scf.yield %2 : index
+  }
+  %3 = test.reflect_bounds %1
+  func.return %3 : index
+}
+
+// Unstructured control flow: the argument of ^bb2 joins the values 3 and 4
+// passed by its two predecessors, giving [3, 4].
+// CHECK-LABEL: func @maybe_increment_br
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment_br(%arg0 : i1) -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  cf.cond_br %arg0, ^bb0, ^bb1
+^bb0:
+    %1 = test.increment %0
+    cf.br ^bb2(%1 : index)
+^bb1:
+    cf.br ^bb2(%0 : index)
+^bb2(%2 : index):
+  %3 = test.reflect_bounds %2
+  func.return %3 : index
+}
+
+// The induction variable of a 0-to-2 loop with step 1 lies in [0, 1] because
+// the upper bound is exclusive. The loop result joins the init value 0 with
+// the yielded induction variable.
+// CHECK-LABEL: func @for_bounds
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
+func.func @for_bounds() -> index {
+  %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                                smin = 0 : index, smax = 0 : index}
+  %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+                                smin = 1 : index, smax = 1 : index}
+  %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+                                smin = 2 : index, smax = 2 : index}
+
+  %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+    scf.yield %arg0 : index
+  }
+  %1 = test.reflect_bounds %0
+  func.return %1 : index
+}
+
+// The iteration argument is incremented on every trip. The analysis does not
+// bound such loop-carried values, so the result keeps the full index range.
+// CHECK-LABEL: func @no_analysis_of_loop_variants
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @no_analysis_of_loop_variants() -> index {
+  %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                                smin = 0 : index, smax = 0 : index}
+  %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+                                smin = 1 : index, smax = 1 : index}
+  %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+                                smin = 2 : index, smax = 2 : index}
+
+  %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+    %1 = test.increment %arg2
+    scf.yield %1 : index
+  }
+  %2 = test.reflect_bounds %0
+  func.return %2 : index
+}
+
+// `test.with_bounds_region` gives its region argument the range spelled by
+// its attributes.
+// CHECK-LABEL: func @region_args
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @region_args() {
+  test.with_bounds_region { umin = 3 : index, umax = 4 : index,
+                            smin = 3 : index, smax = 4 : index } %arg0 {
+    %0 = test.reflect_bounds %arg0
+  }
+  func.return
+}
+
+// Function arguments carry no known bounds, so they get the full index range.
+// CHECK-LABEL: func @func_args_unbound
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @func_args_unbound(%arg0 : index) -> index {
+  %0 = test.reflect_bounds %arg0
+  func.return %0 : index
+}

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index de0e6d2f7ad07..de495bf8fe3db 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
   MLIRFunc
   MLIRFuncTransforms
   MLIRIR
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRLinalg
   MLIRLinalgTransforms

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 24e5bc6123603..9c0616db3705c 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -14,15 +14,21 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
@@ -1396,6 +1402,66 @@ LogicalResult TestVerifiersOp::verifyRegions() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+/// Reports the bounds spelled in the `umin`/`umax`/`smin`/`smax` attributes
+/// as the range of the op's result. Operand ranges are not consulted.
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  ConstantIntRanges declared(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getResult(), declared);
+}
+
+/// Parses `attr-dict %arg { region }`. The single region argument always has
+/// `index` type, and that type is not spelled in the textual form.
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse the region's single argument. Its type is not printed, so set it
+  // to `index` before parsing the SSA name.
+  OpAsmParser::Argument argInfo;
+  argInfo.type = parser.getBuilder().getIndexType();
+  // Propagate a malformed argument instead of ignoring the error and parsing
+  // the region with an incomplete argument.
+  if (failed(parser.parseArgument(argInfo)))
+    return failure();
+
+  // Parse the body region, declaring `argInfo` as its entry block argument.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
+}
+
+/// Prints the op as `attr-dict %arg { region }`. The argument's type and the
+/// entry block's argument list are omitted, matching the custom parser.
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+  Region &body = getRegion();
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ' ';
+  p.printRegionArgument(body.getArgument(0), /*argAttrs=*/{},
+                        /*omitType=*/true);
+  p << ' ';
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+/// Assigns the attribute-declared bounds to the region's entry argument. The
+/// op has no results, and operand ranges are not consulted.
+void TestWithBoundsRegionOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  ConstantIntRanges declared(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getRegion().getArgument(0), declared);
+}
+
+/// The result range is the operand range with every bound raised by one.
+/// Additions saturate at the type's limits instead of wrapping.
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &input = argRanges[0];
+  unsigned width = input.umin().getBitWidth();
+  APInt one(width, 1);
+  APInt newUmin = input.umin().uadd_sat(one);
+  APInt newUmax = input.umax().uadd_sat(one);
+  APInt newSmin = input.smin().sadd_sat(one);
+  APInt newSmax = input.smax().sadd_sat(one);
+  setResultRanges(getResult(),
+                  ConstantIntRanges(newUmin, newUmax, newSmin, newSmax));
+}
+
+/// Writes the operand's inferred bounds onto this op's optional
+/// `umin`/`umax`/`smin`/`smax` attributes so they appear in the printed IR.
+/// The operand's range is then forwarded unchanged to the result.
+void TestReflectBoundsOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &input = argRanges[0];
+  Builder builder(getContext());
+  // Unsigned bounds are read zero-extended; signed bounds sign-extended.
+  setUminAttr(builder.getIndexAttr(input.umin().getZExtValue()));
+  setUmaxAttr(builder.getIndexAttr(input.umax().getZExtValue()));
+  setSminAttr(builder.getIndexAttr(input.smin().getSExtValue()));
+  setSmaxAttr(builder.getIndexAttr(input.smax().getSExtValue()));
+  setResultRanges(getResult(), input);
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 480585000a89f..a894524311030 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -33,6 +33,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0fd0566803df6..94556c4d59a79 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -789,7 +790,7 @@ def StringAttrPrettyNameOp
 def CustomResultsNameOp
  : TEST_Op<"custom_result_name",
            [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
-  let arguments = (ins 
+  let arguments = (ins
     Variadic<AnyInteger>:$optional,
     StrArrayAttr:$names
   );
@@ -2885,4 +2886,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+// Test op producing an index value whose inferred range is exactly the one
+// spelled by its `umin`/`umax`/`smin`/`smax` attributes.
+def TestWithBoundsOp : TEST_Op<"with_bounds",
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                           NoSideEffect]> {
+  let arguments = (ins IndexAttr:$umin,
+                       IndexAttr:$umax,
+                       IndexAttr:$smin,
+                       IndexAttr:$smax);
+  let results = (outs Index:$fakeVal);
+
+  let assemblyFormat = "attr-dict";
+}
+
+// Test op with a single-block, terminator-free region. The region's one
+// entry argument is given the range spelled by the op's attributes. The op
+// itself has no results.
+def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                           SingleBlock, NoTerminator]> {
+  let arguments = (ins IndexAttr:$umin,
+                       IndexAttr:$umax,
+                       IndexAttr:$smin,
+                       IndexAttr:$smax);
+  // The region has one argument of index type; the custom parser assigns
+  // that type because it is not printed.
+  let regions = (region SizedRegion<1>:$region);
+  let hasCustomAssemblyFormat = 1;
+}
+
+// Test op whose result range is the operand's range with each bound raised
+// by one (saturating).
+def TestIncrementOp : TEST_Op<"increment",
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                         NoSideEffect]> {
+  let arguments = (ins Index:$value);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $value";
+}
+
+// Test op that forwards its operand. During range inference it writes the
+// operand's inferred bounds into its optional `umin`/`umax`/`smin`/`smax`
+// attributes so they can be checked in the printed IR.
+// NOTE(review): it is presumably left without `NoSideEffect` so that the test
+// rewrite, which only erases trivially dead ops, keeps it in the output.
+// Confirm.
+def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+  let arguments = (ins Index:$value,
+                       OptionalAttr<IndexAttr>:$umin,
+                       OptionalAttr<IndexAttr>:$umax,
+                       OptionalAttr<IndexAttr>:$smin,
+                       OptionalAttr<IndexAttr>:$smax);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $value";
+}
 #endif // TEST_OPS

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 95c30c34f4950..00856562d8709 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
+  TestIntRangeInference.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -10,6 +11,8 @@ add_mlir_library(MLIRTestTransforms
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
+  MLIRInferIntRangeInterface
   MLIRTestDialect
   MLIRTransforms
   )

diff  --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
new file mode 100644
index 0000000000000..1bd2a24d3ce6c
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -0,0 +1,115 @@
+//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// TODO: This pass is needed to test integer range inference until that
+// functionality has been integrated into SCCP.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+
+/// Patterned after SCCP
+/// Patterned after SCCP. If `analysis` proves that `value` can only hold a
+/// single integer, materializes that integer as a constant and replaces all
+/// uses of `value` with it.
+///
+/// The constant is created through the dialect of `value`'s defining op or,
+/// for a block argument, of the op owning the enclosing region. Returns
+/// failure and leaves `value` untouched when no range or constant is known,
+/// when there is no dialect to materialize with, or when the folder cannot
+/// create the constant.
+static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
+                                         OpBuilder &b, OperationFolder &folder,
+                                         Value value) {
+  Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
+  if (!maybeInferredRange)
+    return failure();
+  Optional<APInt> maybeConstValue = maybeInferredRange->getConstantValue();
+  if (!maybeConstValue)
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  // `Operation::getDialect()` returns null when the op's dialect is not
+  // loaded. There is then no dialect to materialize the constant with, so
+  // bail out instead of handing a null dialect to the folder.
+  if (!valueDialect)
+    return failure();
+
+  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
+                                              value.getType(), value.getLoc());
+  if (!constant)
+    return failure();
+
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
+/// Visits every block nested under `initialRegions`. Each op result and block
+/// argument that `analysis` proves holds a single value is replaced by a
+/// constant. An op whose results were all replaced is erased when
+/// `wouldOpBeTriviallyDead` allows it, and its nested regions are then not
+/// visited.
+static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+                    MutableArrayRef<Region> initialRegions) {
+  OpBuilder builder(context);
+  OperationFolder folder(context);
+
+  // Blocks still to visit, used as a stack. Each region's blocks are pushed
+  // in reverse so that, within a region, they are popped in original order.
+  SmallVector<Block *> pending;
+  auto enqueueRegions = [&pending](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        pending.push_back(&block);
+  };
+  enqueueRegions(initialRegions);
+
+  while (!pending.empty()) {
+    Block *current = pending.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*current)) {
+      builder.setInsertionPoint(&op);
+
+      // Attempt every result, without short-circuiting, and count successes.
+      unsigned numReplaced = 0;
+      for (Value result : op.getResults())
+        if (succeeded(replaceWithConstant(analysis, builder, folder, result)))
+          ++numReplaced;
+
+      // Once every result is a constant, the op itself may be removable.
+      bool allReplaced =
+          op.getNumResults() != 0 && numReplaced == op.getNumResults();
+      if (allReplaced && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Queue the regions nested in this op for a later visit.
+      enqueueRegions(op.getRegions());
+    }
+
+    // Block arguments are handled after the block's operations.
+    builder.setInsertionPointToStart(current);
+    for (BlockArgument arg : current->getArguments())
+      (void)replaceWithConstant(analysis, builder, folder, arg);
+  }
+}
+
+namespace {
+/// Test pass that runs `IntRangeAnalysis` on the operation it is scheduled on.
+/// It then replaces values proven to hold a single integer with constants.
+struct TestIntRangeInference
+    : PassWrapper<TestIntRangeInference, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
+
+  StringRef getArgument() const final { return "test-int-range-inference"; }
+  StringRef getDescription() const final {
+    return "Test integer range inference analysis";
+  }
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    IntRangeAnalysis rangeAnalysis(root);
+    MLIRContext *ctx = root->getContext();
+    rewrite(rangeAnalysis, ctx, root->getRegions());
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `test-int-range-inference` pass so that tools such as
+/// mlir-opt can run it from the command line.
+void registerTestIntRangeInference() {
+  PassRegistration<TestIntRangeInference>();
+}
+} // end namespace test
+} // end namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e75c2758e88db..aa94294b4ea8d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -79,6 +79,7 @@ void registerTestDynamicPipelinePass();
 void registerTestExpandMathPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
+void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
@@ -175,6 +176,7 @@ void registerTestPasses() {
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
+  mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();

diff  --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index 54a6837b0ed88..c4201b135e917 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRInterfacesTests
   ControlFlowInterfacesTest.cpp
   DataLayoutInterfacesTest.cpp
+  InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
 )
 
@@ -10,6 +11,7 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTI
   MLIRFunc
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParser
 )

diff  --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
new file mode 100644
index 0000000000000..97c75b3680567
--- /dev/null
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/APInt.h"
+#include <limits>
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+// The four-bound constructor stores each bound verbatim.
+TEST(IntRangeAttrs, BasicConstructors) {
+  constexpr unsigned kWidth = 64;
+  const APInt lower = APInt::getZero(kWidth);
+  const APInt unsignedUpper(kWidth, 2);
+  const APInt signedUpper(kWidth, 3);
+  ConstantIntRanges ranges(lower, unsignedUpper, lower, signedUpper);
+  EXPECT_EQ(ranges.umin(), lower);
+  EXPECT_EQ(ranges.umax(), unsignedUpper);
+  EXPECT_EQ(ranges.smin(), lower);
+  EXPECT_EQ(ranges.smax(), signedUpper);
+}
+
+// `fromUnsigned` reuses the endpoints as signed bounds when they share a sign
+// bit, and falls back to the full signed range otherwise.
+TEST(IntRangeAttrs, FromUnsigned) {
+  const APInt zero = APInt::getZero(64);
+  const APInt sMax = APInt::getSignedMaxValue(64);
+  const APInt sMin = APInt::getSignedMinValue(64);
+  const APInt sMinPlusOne = sMin + 1;
+
+  // [0, INT_MAX] is non-negative throughout, so it carries over directly.
+  ConstantIntRanges nonNegative = ConstantIntRanges::fromUnsigned(zero, sMax);
+  EXPECT_EQ(nonNegative.smin(), zero);
+  EXPECT_EQ(nonNegative.smax(), sMax);
+
+  // [0, INT_MIN] crosses the sign bit, so the full signed range results.
+  ConstantIntRanges crossesSignBit =
+      ConstantIntRanges::fromUnsigned(zero, sMin);
+  EXPECT_EQ(crossesSignBit.smin(), sMin);
+  EXPECT_EQ(crossesSignBit.smax(), sMax);
+
+  // [INT_MIN, INT_MIN + 1] is negative throughout, so it also carries over.
+  ConstantIntRanges negative =
+      ConstantIntRanges::fromUnsigned(sMin, sMinPlusOne);
+  EXPECT_EQ(negative.smin(), sMin);
+  EXPECT_EQ(negative.smax(), sMinPlusOne);
+}
+
+// `fromSigned` reuses the endpoints as unsigned bounds when they share a sign,
+// and falls back to the full unsigned range when the range straddles zero.
+TEST(IntRangeAttrs, FromSigned) {
+  const APInt zero = APInt::getZero(64);
+  const APInt one = zero + 1;
+  const APInt minusOne = zero - 1;
+  const APInt sMax = APInt::getSignedMaxValue(64);
+  const APInt sMin = APInt::getSignedMinValue(64);
+  const APInt uMax = APInt::getMaxValue(64);
+
+  // [-1, 1] straddles zero, so no unsigned information survives.
+  ConstantIntRanges straddlesZero =
+      ConstantIntRanges::fromSigned(minusOne, one);
+  EXPECT_EQ(straddlesZero.umin(), zero);
+  EXPECT_EQ(straddlesZero.umax(), uMax);
+
+  // Strictly positive ranges carry over unchanged.
+  ConstantIntRanges allPositive = ConstantIntRanges::fromSigned(one, sMax);
+  EXPECT_EQ(allPositive.umin(), one);
+  EXPECT_EQ(allPositive.umax(), sMax);
+
+  // Negative ranges map to the top half of the unsigned space.
+  ConstantIntRanges allNegative = ConstantIntRanges::fromSigned(sMin, minusOne);
+  EXPECT_EQ(allNegative.umin(), sMin);
+  EXPECT_EQ(allNegative.umax(), minusOne);
+
+  // A non-negative range starting at zero is preserved too.
+  ConstantIntRanges fromZero = ConstantIntRanges::fromSigned(zero, one);
+  EXPECT_EQ(fromZero.umin(), zero);
+  EXPECT_EQ(fromZero.umax(), one);
+}
+
+// `rangeUnion` widens the unsigned and signed bounds independently.
+TEST(IntRangeAttrs, Join) {
+  const APInt zero = APInt::getZero(64);
+  const APInt one = zero + 1;
+  const APInt two = zero + 2;
+  const APInt sMin = APInt::getSignedMinValue(64);
+  const APInt sMax = APInt::getSignedMaxValue(64);
+  const APInt uMax = APInt::getMaxValue(64);
+
+  const ConstantIntRanges full(zero, uMax, sMin, sMax);
+  const ConstantIntRanges zeroToOne(zero, one, zero, one);
+
+  // The full range absorbs anything, in either operand order.
+  EXPECT_EQ(zeroToOne.rangeUnion(full), full);
+  EXPECT_EQ(full.rangeUnion(zeroToOne), full);
+
+  // Joining a range with itself changes nothing.
+  EXPECT_EQ(zeroToOne.rangeUnion(zeroToOne), zeroToOne);
+
+  // Overlapping ranges merge into their hull.
+  const ConstantIntRanges oneToTwo(one, two, one, two);
+  const ConstantIntRanges zeroToTwo(zero, two, zero, two);
+  EXPECT_EQ(zeroToOne.rangeUnion(oneToTwo), zeroToTwo);
+
+  // A range bounded only in the unsigned domain, joined with one bounded only
+  // in the signed domain, loses both bounds.
+  const ConstantIntRanges unsignedOnly(zero, one, sMin, sMax);
+  const ConstantIntRanges signedOnly(zero, uMax, zero, one);
+  EXPECT_EQ(unsignedOnly.rangeUnion(signedOnly), full);
+}


        


More information about the Mlir-commits mailing list