[Mlir-commits] [mlir] 95aff23 - Re-land "[mlir] Add integer range inference analysis""

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Jun 3 10:13:53 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-06-03T17:13:48Z
New Revision: 95aff23e29214543360d893f9a61df0ebd1b65d2

URL: https://github.com/llvm/llvm-project/commit/95aff23e29214543360d893f9a61df0ebd1b65d2
DIFF: https://github.com/llvm/llvm-project/commit/95aff23e29214543360d893f9a61df0ebd1b65d2.diff

LOG: Re-land "[mlir] Add integer range inference analysis""

This reverts commit 4e5ce2056e3e85f109a074e80bdd23a10ca2bed9.

This relands commit 1350c9887dca5ba80af8e3c1e61b29d6696eb240.

Reinstates the range analysis with the build issue fixed.

Differential Revision: https://reviews.llvm.org/D126926

Added: 
    mlir/include/mlir/Analysis/IntRangeAnalysis.h
    mlir/include/mlir/Interfaces/InferIntRangeInterface.h
    mlir/include/mlir/Interfaces/InferIntRangeInterface.td
    mlir/lib/Analysis/IntRangeAnalysis.cpp
    mlir/lib/Interfaces/InferIntRangeInterface.cpp
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
    mlir/test/lib/Transforms/TestIntRangeInference.cpp
    mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp

Modified: 
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    mlir/unittests/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
new file mode 100644
index 0000000000000..b2b604359b48b
--- /dev/null
+++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
@@ -0,0 +1,41 @@
+//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
+#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl;
+} // end namespace detail
+
+class IntRangeAnalysis {
+public:
+  /// Analyze all operations rooted under (but not including)
+  /// `topLevelOperation`.
+  IntRangeAnalysis(Operation *topLevelOperation);
+  IntRangeAnalysis(IntRangeAnalysis &&other);
+  ~IntRangeAnalysis();
+
+  /// Get inferred range for value `v` if one exists.
+  Optional<ConstantIntRanges> getResult(Value v);
+
+private:
+  std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
+};
+} // end namespace mlir
+
+#endif

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index cf075f728fdbf..918f3ea398e47 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
+add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)

diff  --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
new file mode 100644
index 0000000000000..9a393855d05ff
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -0,0 +1,98 @@
+//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer range inference interface
+// defined in `InferIntRangeInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+/// A set of arbitrary-precision integers representing bounds on a given integer
+/// value. These bounds are inclusive on both ends, so
+/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
+/// the unsigned and signed interpretations of values in order to enable more
+/// precise inference of the interplay between operations with signed and
+/// unsigned semantics.
+class ConstantIntRanges {
+public:
+  /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
+  /// Non-integer values should be bounded by APInts of bitwidth 0.
+  ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
+                    const APInt &smax)
+      : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+    assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
+           umaxVal.getBitWidth() == sminVal.getBitWidth() &&
+           sminVal.getBitWidth() == smaxVal.getBitWidth() &&
+           "All bounds in the ranges must have the same bitwidth");
+  }
+
+  bool operator==(const ConstantIntRanges &other) const;
+
+  /// The minimum value of an integer when it is interpreted as unsigned.
+  const APInt &umin() const;
+
+  /// The maximum value of an integer when it is interpreted as unsigned.
+  const APInt &umax() const;
+
+  /// The minimum value of an integer when it is interpreted as signed.
+  const APInt &smin() const;
+
+  /// The maximum value of an integer when it is interpreted as signed.
+  const APInt &smax() const;
+
+  /// Return the bitwidth that should be used for integer ranges describing
+  /// `type`. For concrete integer types, this is their bitwidth, for `index`,
+  /// this is the internal storage bitwidth of `index` attributes, and for
+  /// non-integer types this is 0.
+  static unsigned getStorageBitwidth(Type type);
+
+  /// Create a `ConstantIntRanges` where `min` is both the signed and
+  /// unsigned minimum and `max` is both the signed and unsigned maximum.
+  static ConstantIntRanges range(const APInt &min, const APInt &max);
+
+  /// Create a `ConstantIntRanges` with the signed minimum and maximum equal
+  /// to `smin` and `smax`, where the unsigned bounds are constructed from the
+  /// signed ones if they correspond to a contiguous range of bit patterns when
+  /// viewed as unsigned values and are left at [0, uint_max()] otherwise.
+  static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
+
+  /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal
+  /// to `umin` and `umax` and the signed part equal to `umin` and `umax`
+  /// unless the sign bit changes between the minimum and maximum.
+  static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
+
+  /// Returns the union (computed separately for signed and unsigned bounds)
+  /// of this range and `other`.
+  ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
+
+  /// If either the signed or unsigned interpretations of the range
+  /// indicate that the value it bounds is a constant, return that constant
+  /// value.
+  Optional<APInt> getConstantValue() const;
+
+  friend raw_ostream &operator<<(raw_ostream &os,
+                                 const ConstantIntRanges &range);
+
+private:
+  APInt uminVal, umaxVal, sminVal, smaxVal;
+};
+
+/// The type of the `setResultRanges` callback provided to ops implementing
+/// InferIntRangeInterface. It should be called once for each integer result
+/// value and be passed the ConstantIntRanges corresponding to that value.
+using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H

diff  --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
new file mode 100644
index 0000000000000..57f8d693b7916
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -0,0 +1,52 @@
+//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-----------------------------------------------------===//
+//
+// Defines the interface for range analysis on scalar integers
+//
+//===-----------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
+#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
+  let description = [{
+    Allows operations to participate in range analysis for scalar integer values by
+    providing a method that allows them to specify lower and upper bounds on their
+    result(s) given lower and upper bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the bounds on its arguments.
+      For each result value or block argument (that isn't a branch argument,
+      since the dataflow analysis handles those cases), the method should call
+      `setResultRanges` with that `Value` as an argument. When `setResultRanges`
+      is not called for some value, it will receive a default value of the minimum
+      and maximum values for its type (the unbounded range).
+
+      When called on an op that also implements the RegionBranchOpInterface
+      or BranchOpInterface, this method should not attempt to infer the values
+      of the branch results, as this will be handled by the analyses that use
+      this interface.
+
+      This function will only be called when at least one result of the op is a
+      scalar integer value or the op has a region.
+
+      `argRanges` contains one `ConstantIntRanges` for each argument to the op in ODS
+       order. Non-integer arguments will have an unbounded range of width-0
+       APInts in their `argRanges` element.
+    }],
+    "void", "inferResultRanges", (ins
+      "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
+      "::mlir::SetIntRangeFn":$setResultRanges)
+  >];
+}
+#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 29314ed535931..6c45e40efa9ab 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -16,6 +17,7 @@ add_mlir_library(MLIRAnalysis
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
+  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -31,7 +33,9 @@ add_mlir_library(MLIRAnalysis
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
+  MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   )

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index 9c10595dbb00a..239d9e4060bca 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -359,11 +359,20 @@ void ForwardDataFlowSolver::visitOperation(Operation *op) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
       return visitRegionBranchOperation(branch, operandLattices);
 
-    // If we can't, conservatively mark all regions as executable.
-    // TODO: Let the `visitOperation` method decide how to propagate
-    // information to the block arguments.
-    for (Region &region : op->getRegions())
-      markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+    for (Region &region : op->getRegions()) {
+      analysis.visitNonControlFlowArguments(op, RegionSuccessor(&region),
+                                            operandLattices);
+      // `visitNonControlFlowArguments` is required to define all of the region
+      // argument lattices.
+      assert(llvm::none_of(
+                 region.getArguments(),
+                 [&](Value value) {
+                   return analysis.getLatticeElement(value).isUninitialized();
+                 }) &&
+             "expected `visitNonControlFlowArguments` to define all argument "
+             "lattices");
+      markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/false);
+    }
   }
 
   // If this op produces no results, it can't produce any constants.
@@ -567,12 +576,45 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
       return;
 
+    // If the branch is a RegionBranchTerminatorOpInterface,
+    // construct the set of operand lattices as the set of non control-flow
+    // arguments of the parent and the values this op returns. This allows
+    // for the correct lattices to be passed to getSuccessorsForOperands()
+    // in cases such as scf.while.
+    ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
+    SmallVector<AbstractLatticeElement *, 0> parentLattices;
+    if (auto regionTerminator =
+            dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+      parentLattices.reserve(regionInterface->getNumOperands());
+      for (Value parentOperand : regionInterface->getOperands()) {
+        AbstractLatticeElement *operandLattice =
+            analysis.lookupLatticeElement(parentOperand);
+        if (!operandLattice || operandLattice->isUninitialized())
+          return;
+        parentLattices.push_back(operandLattice);
+      }
+      unsigned regionNumber = parentRegion->getRegionNumber();
+      OperandRange iterArgs =
+          regionInterface.getSuccessorEntryOperands(regionNumber);
+      OperandRange terminatorArgs =
+          regionTerminator.getSuccessorOperands(regionNumber);
+      assert(iterArgs.size() == terminatorArgs.size() &&
+             "Number of iteration arguments for region should equal number of "
+             "those arguments defined by terminator");
+      if (!iterArgs.empty()) {
+        unsigned iterStart = iterArgs.getBeginOperandIndex();
+        unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
+        for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
+          parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
+      }
+      branchOpLattices = parentLattices;
+    }
     // Query the set of successors of the current region using the current
     // optimistic lattice state.
     SmallVector<RegionSuccessor, 1> regionSuccessors;
     analysis.getSuccessorsForOperands(regionInterface,
                                       parentRegion->getRegionNumber(),
-                                      operandLattices, regionSuccessors);
+                                      branchOpLattices, regionSuccessors);
     if (regionSuccessors.empty())
       return;
 
@@ -584,7 +626,7 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
         // region index (if any).
         return *getRegionBranchSuccessorOperands(op, regionIndex);
       };
-      return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
+      return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
                                    getOperands);
     }
 

diff  --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
new file mode 100644
index 0000000000000..7e6d61ff89560
--- /dev/null
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -0,0 +1,325 @@
+//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+
+namespace {
+/// A wrapper around ConstantIntRanges that provides the lattice functions
+/// expected by dataflow analysis.
+struct IntRangeLattice {
+  IntRangeLattice(const ConstantIntRanges &value) : value(value){};
+  IntRangeLattice(ConstantIntRanges &&value) : value(value){};
+
+  bool operator==(const IntRangeLattice &other) const {
+    return value == other.value;
+  }
+
+  /// wrapper around rangeUnion()
+  static IntRangeLattice join(const IntRangeLattice &a,
+                              const IntRangeLattice &b) {
+    return a.value.rangeUnion(b.value);
+  }
+
+  /// Creates a range with bitwidth 0 to represent that we don't know if the
+  /// value being marked overdefined is even an integer.
+  static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
+    APInt noIntValue = APInt::getZeroWidth();
+    return ConstantIntRanges::range(noIntValue, noIntValue);
+  }
+
+  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+  /// that is used to mark the value v as unable to be analyzed further,
+  /// where t is the type of v.
+  static IntRangeLattice getPessimisticValueState(Value v) {
+    unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
+    APInt umin = APInt::getMinValue(width);
+    APInt umax = APInt::getMaxValue(width);
+    APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+    APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+    return ConstantIntRanges{umin, umax, smin, smax};
+  }
+
+  ConstantIntRanges value;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
+  using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
+
+public:
+  /// Define bounds on the results or block arguments of the operation
+  /// based on the bounds on the arguments given in `operands`
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+
+  /// Skip regions of branch ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  LogicalResult
+  getSuccessorsForOperands(BranchOpInterface branch,
+                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+                           SmallVectorImpl<Block *> &successors) final;
+
+  /// Skip regions of branch or loop ops when we can statically infer constant
+  /// values for operands to the branch op and said op tells us it's safe to do
+  /// so.
+  void
+  getSuccessorsForOperands(RegionBranchOpInterface branch,
+                           Optional<unsigned> sourceIndex,
+                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+                           SmallVectorImpl<RegionSuccessor> &successors) final;
+
+  /// Call the InferIntRangeInterface implementation for region-using ops
+  /// that implement it, and infer the bounds of loop induction variables
+  /// for ops that implement LoopLikeOpInterface.
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Given the results of getSingle{Lower,Upper}Bound()
+/// or getSingleStep() on a LoopLikeInterface return the lower/upper bound for
+/// that result if possible.
+static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
+                                  Type boundType,
+                                  detail::IntRangeAnalysisImpl &analysis,
+                                  bool getUpper) {
+  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+  if (loopBound.hasValue()) {
+    if (loopBound->is<Attribute>()) {
+      if (auto bound =
+              loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+        return bound.getValue();
+    } else if (loopBound->is<Value>()) {
+      LatticeElement<IntRangeLattice> *lattice =
+          analysis.lookupLatticeElement(loopBound->get<Value>());
+      if (lattice != nullptr)
+        return getUpper ? lattice->getValue().value.smax()
+                        : lattice->getValue().value.smin();
+    }
+  }
+  return getUpper ? APInt::getSignedMaxValue(width)
+                  : APInt::getSignedMinValue(width);
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+  ChangeResult result = ChangeResult::NoChange;
+  // Ignore non-integer outputs - return early if the op has no scalar
+  // integer results
+  bool hasIntegerResult = false;
+  for (Value v : op->getResults()) {
+    if (v.getType().isIntOrIndex())
+      hasIntegerResult = true;
+    else
+      result |= markAllPessimisticFixpoint(v);
+  }
+  if (!hasIntegerResult)
+    return result;
+
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<ConstantIntRanges> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
+          return val->getValue().value;
+        }));
+
+    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+      Optional<IntRangeLattice> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      result |= lattice.join(IntRangeLattice(attrs));
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedResult && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        result |= lattice.markPessimisticFixpoint();
+      }
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+    for (Value opResult : op->getResults()) {
+      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
+      // setResultRange() not called, make pessimistic.
+      if (lattice.isUninitialized())
+        result |= lattice.markPessimisticFixpoint();
+    }
+  } else if (op->getNumRegions() == 0) {
+    // No regions + no result inference method -> unbounded results (ex. memory
+    // ops)
+    result |= markAllPessimisticFixpoint(op->getResults());
+  }
+  return result;
+}
+
+LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    BranchOpInterface branch,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+    SmallVectorImpl<Block *> &successors) {
+  auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
+    Optional<APInt> maybeConstValue =
+        enumPair.value()->getValue().value.getConstantValue();
+
+    if (maybeConstValue) {
+      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+                              *maybeConstValue);
+    }
+    return {};
+  };
+  SmallVector<Attribute> inferredConsts(
+      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+  if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
+    successors.push_back(singleSucc);
+    return success();
+  }
+  return failure();
+}
+
+void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
+    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
+    Optional<APInt> maybeConstValue =
+        enumPair.value()->getValue().value.getConstantValue();
+
+    if (maybeConstValue) {
+      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
+                              *maybeConstValue);
+    }
+    return {};
+  };
+  SmallVector<Attribute> inferredConsts(
+      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
+  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+}
+
+ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &region,
+    ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
+    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
+    LLVM_DEBUG(llvm::dbgs() << "\n");
+    SmallVector<ConstantIntRanges> argRanges(
+        llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
+          return val->getValue().value;
+        }));
+
+    ChangeResult result = ChangeResult::NoChange;
+    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
+      Optional<IntRangeLattice> oldRange;
+      if (!lattice.isUninitialized())
+        oldRange = lattice.getValue();
+      result |= lattice.join(IntRangeLattice(attrs));
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedValue && oldRange.hasValue() &&
+          !(lattice.getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        result |= lattice.markPessimisticFixpoint();
+      }
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+    for (Value regionArg : region.getSuccessor()->getArguments()) {
+      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
+      // setResultRange() not called, make pessimistic.
+      if (lattice.isUninitialized())
+        result |= lattice.markPessimisticFixpoint();
+    }
+
+    return result;
+  }
+
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!iv.hasValue()) {
+      return ForwardDataFlowAnalysis<
+          IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
+    }
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    Optional<OpFoldResult> step = loop.getSingleStep();
+    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
+                                     /*getUpper=*/false);
+    APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
+                                     /*getUpper=*/true);
+    // Assume positivity for undiscoverable steps by way of getUpper = true.
+    APInt stepVal =
+        getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
+
+    if (stepVal.isNegative()) {
+      std::swap(min, max);
+    } else {
+      // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
+      // because loops do not generally include their upper bound.
+      max -= 1;
+    }
+
+    LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
+    return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
+  }
+  return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
+      op, region, operands);
+}
+
+IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
+  impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
+      topLevelOperation->getContext());
+  impl->run(topLevelOperation);
+}
+
+IntRangeAnalysis::~IntRangeAnalysis() = default;
+IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
+
+Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
+  LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
+  if (result == nullptr || result->isUninitialized())
+    return llvm::None;
+  return result->getValue().value;
+}

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 1178c40207566..2082ad41a7f27 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
   CopyOpInterface.cpp
   DataLayoutInterfaces.cpp
   DerivedAttributeOpInterface.cpp
+  InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
@@ -35,6 +36,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
+add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)

diff  --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
new file mode 100644
index 0000000000000..777ea18456551
--- /dev/null
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterface.cpp -  Integer range inference interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
+
+using namespace mlir;
+
+bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
+  return umin().getBitWidth() == other.umin().getBitWidth() && // Widths first.
+         umin() == other.umin() && umax() == other.umax() &&
+         smin() == other.smin() && smax() == other.smax();
+}
+
+const APInt &ConstantIntRanges::umin() const { return uminVal; } // Stored unsigned minimum.
+
+const APInt &ConstantIntRanges::umax() const { return umaxVal; } // Stored unsigned maximum.
+
+const APInt &ConstantIntRanges::smin() const { return sminVal; } // Stored signed minimum.
+
+const APInt &ConstantIntRanges::smax() const { return smaxVal; } // Stored signed maximum.
+
+unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
+  if (type.isIndex())
+    return IndexType::kInternalStorageBitWidth; // `index` uses a fixed storage width.
+  if (auto integerType = type.dyn_cast<IntegerType>())
+    return integerType.getWidth(); // Builtin integers use their declared width.
+  // Any other type is "not an integer": its bounds live in width-0 `APInt`s.
+  return 0;
+}
+
+ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
+  return {min, max, min, max}; // Same pair used as both unsigned and signed bounds.
+}
+
+ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
+                                                const APInt &smax) {
+  unsigned int width = smin.getBitWidth();
+  APInt umin, umax;
+  if (smin.isNonNegative() == smax.isNonNegative()) { // Both bounds share a sign.
+    umin = smin.ult(smax) ? smin : smax; // Unsigned-smaller of the two bounds.
+    umax = smin.ugt(smax) ? smin : smax; // Unsigned-larger of the two bounds.
+  } else { // Signs differ: fall back to the full unsigned range for `width`.
+    umin = APInt::getMinValue(width);
+    umax = APInt::getMaxValue(width);
+  }
+  return {umin, umax, smin, smax};
+}
+
+ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
+                                                  const APInt &umax) {
+  unsigned int width = umin.getBitWidth();
+  APInt smin, smax;
+  if (umin.isNonNegative() == umax.isNonNegative()) { // Same sign when read as signed.
+    smin = umin.slt(umax) ? umin : umax; // Signed-smaller of the two bounds.
+    smax = umin.sgt(umax) ? umin : umax; // Signed-larger of the two bounds.
+  } else { // Signs differ: fall back to the full signed range for `width`.
+    smin = APInt::getSignedMinValue(width);
+    smax = APInt::getSignedMaxValue(width);
+  }
+  return {umin, umax, smin, smax};
+}
+
+ConstantIntRanges
+ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
+  // A width-0 range means "not an integer": it poisons the union and cannot be
+  // fed to the comparison operators below, so it is returned unchanged.
+  if (umin().getBitWidth() == 0)
+    return *this;
+  if (other.umin().getBitWidth() == 0)
+    return other;
+
+  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+
+  return {uminUnion, umaxUnion, sminUnion, smaxUnion}; // Each bound is widened independently.
+}
+
+Optional<APInt> ConstantIntRanges::getConstantValue() const {
+  // Note: width-0 bounds always compare equal, so they must be excluded here.
+  if (umin() == umax() && umin().getBitWidth() != 0)
+    return umin(); // The unsigned bounds pin a single value.
+  if (smin() == smax() && smin().getBitWidth() != 0)
+    return smin(); // The signed bounds pin a single value.
+  return None; // Neither pair of bounds collapses to one value.
+}
+
+raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
+  return os << "unsigned : [" << range.umin() << ", " << range.umax()
+            << "] signed : [" << range.smin() << ", " << range.smax() << "]"; // Unsigned pair, then signed pair.
+}

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
new file mode 100644
index 0000000000000..45d506d00d65d
--- /dev/null
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+
+// CHECK-LABEL: func @constant
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
+// CHECK: return %[[cst]]
+func.func @constant() -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  func.return %0 : index
+}
+
+// CHECK-LABEL: func @increment
+// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
+// CHECK: return %[[cst]]
+func.func @increment() -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = test.increment %0
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @maybe_increment
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment(%arg0 : i1) -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  %1 = scf.if %arg0 -> index {
+    scf.yield %0 : index
+  } else {
+    %2 = test.increment %0
+    scf.yield %2 : index
+  }
+  %3 = test.reflect_bounds %1
+  func.return %3 : index
+}
+
+// CHECK-LABEL: func @maybe_increment_br
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @maybe_increment_br(%arg0 : i1) -> index {
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
+                               smin = 3 : index, smax = 3 : index}
+  cf.cond_br %arg0, ^bb0, ^bb1
+^bb0:
+    %1 = test.increment %0
+    cf.br ^bb2(%1 : index)
+^bb1:
+    cf.br ^bb2(%0 : index)
+^bb2(%2 : index):
+  %3 = test.reflect_bounds %2
+  func.return %3 : index
+}
+
+// CHECK-LABEL: func @for_bounds
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
+func.func @for_bounds() -> index {
+  %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                                smin = 0 : index, smax = 0 : index}
+  %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+                                smin = 1 : index, smax = 1 : index}
+  %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+                                smin = 2 : index, smax = 2 : index}
+
+  %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+    scf.yield %arg0 : index
+  }
+  %1 = test.reflect_bounds %0
+  func.return %1 : index
+}
+
+// CHECK-LABEL: func @no_analysis_of_loop_variants
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @no_analysis_of_loop_variants() -> index {
+  %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
+                                smin = 0 : index, smax = 0 : index}
+  %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
+                                smin = 1 : index, smax = 1 : index}
+  %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
+                                smin = 2 : index, smax = 2 : index}
+
+  %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
+    %1 = test.increment %arg2
+    scf.yield %1 : index
+  }
+  %2 = test.reflect_bounds %0
+  func.return %2 : index
+}
+
+// CHECK-LABEL: func @region_args
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
+func.func @region_args() {
+  test.with_bounds_region { umin = 3 : index, umax = 4 : index,
+                            smin = 3 : index, smax = 4 : index } %arg0 {
+    %0 = test.reflect_bounds %arg0
+  }
+  func.return
+}
+
+// CHECK-LABEL: func @func_args_unbound
+// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
+func.func @func_args_unbound(%arg0 : index) -> index {
+  %0 = test.reflect_bounds %arg0
+  func.return %0 : index
+}

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index de0e6d2f7ad07..de495bf8fe3db 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
   MLIRFunc
   MLIRFuncTransforms
   MLIRIR
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRLinalg
   MLIRLinalgTransforms

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 24e5bc6123603..415b545a6eae0 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -14,15 +14,21 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
@@ -1396,6 +1402,67 @@ LogicalResult TestVerifiersOp::verifyRegions() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); // Bounds come from the op's attributes; `argRanges` is unused.
+}
+
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse the region's argument; its type is set to `index` before parsing.
+  OpAsmParser::Argument argInfo;
+  argInfo.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseArgument(argInfo)))
+    return failure();
+
+  // Parse the body region, passing the argument parsed above as its argument.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
+}
+
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+  p.printOptionalAttrDict((*this)->getAttrs()); // Attribute dictionary comes first.
+  p << ' ';
+  p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
+                        /*omitType=*/true);
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); // The argument was already printed above.
+}
+
+void TestWithBoundsRegionOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  Value arg = getRegion().getArgument(0); // Bounds are attached to the region argument, not to an op result.
+  setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
+}
+
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &range = argRanges[0];
+  APInt one(range.umin().getBitWidth(), 1); // Constant 1 at the argument's bit width.
+  setResultRanges(getResult(),
+                  {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
+                   range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); // Every bound is shifted up by one using saturating adds.
+}
+
+void TestReflectBoundsOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &range = argRanges[0];
+  MLIRContext *ctx = getContext();
+  Builder b(ctx);
+  setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); // Unsigned bounds are zero-extended.
+  setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
+  setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); // Signed bounds are sign-extended.
+  setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
+  setResultRanges(getResult(), range); // The result carries the argument's range unchanged.
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 480585000a89f..a894524311030 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -33,6 +33,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0fd0566803df6..94556c4d59a79 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -789,7 +790,7 @@ def StringAttrPrettyNameOp
 def CustomResultsNameOp
  : TEST_Op<"custom_result_name",
            [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
-  let arguments = (ins 
+  let arguments = (ins
     Variadic<AnyInteger>:$optional,
     StrArrayAttr:$names
   );
@@ -2885,4 +2886,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+def TestWithBoundsOp : TEST_Op<"with_bounds",
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                           NoSideEffect]> {
+  let arguments = (ins IndexAttr:$umin,
+                       IndexAttr:$umax,
+                       IndexAttr:$smin,
+                       IndexAttr:$smax);
+  let results = (outs Index:$fakeVal);
+
+  let assemblyFormat = "attr-dict";
+}
+
+def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
+                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                           SingleBlock, NoTerminator]> {
+  let arguments = (ins IndexAttr:$umin,
+                       IndexAttr:$umax,
+                       IndexAttr:$smin,
+                       IndexAttr:$smax);
+  // The region has one argument of index type
+  let regions = (region SizedRegion<1>:$region);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def TestIncrementOp : TEST_Op<"increment",
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>,
+                         NoSideEffect]> {
+  let arguments = (ins Index:$value);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $value";
+}
+
+def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
+  let arguments = (ins Index:$value,
+                       OptionalAttr<IndexAttr>:$umin,
+                       OptionalAttr<IndexAttr>:$umax,
+                       OptionalAttr<IndexAttr>:$smin,
+                       OptionalAttr<IndexAttr>:$smax);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $value";
+}
 #endif // TEST_OPS

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 95c30c34f4950..00856562d8709 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
+  TestIntRangeInference.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -10,6 +11,8 @@ add_mlir_library(MLIRTestTransforms
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
+  MLIRInferIntRangeInterface
   MLIRTestDialect
   MLIRTransforms
   )

diff  --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
new file mode 100644
index 0000000000000..1bd2a24d3ce6c
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -0,0 +1,115 @@
+//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// TODO: This pass is needed to test integer range inference until that
+// functionality has been integrated into SCCP.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+
+/// Replaces all uses of `value` with a constant if its inferred range has a single constant value; fails otherwise. Patterned after SCCP.
+static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
+                                         OpBuilder &b, OperationFolder &folder,
+                                         Value value) {
+  Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
+  if (!maybeInferredRange)
+    return failure(); // The analysis has no range for `value`.
+  const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
+  Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.hasValue())
+    return failure(); // The range does not collapse to a single value.
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect(); // No defining op: use the dialect of the op owning the region.
+  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
+                                              value.getType(), value.getLoc());
+  if (!constant)
+    return failure(); // The folder could not produce a constant.
+
+  value.replaceAllUsesWith(constant);
+  return success();
+}
+
+static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+                    MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region)) // Reversed so popping from the back visits a region's blocks in order.
+        worklist.push_back(&block);
+  };
+
+  OpBuilder builder(context);
+  OperationFolder folder(context);
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) { // `op` may be erased inside the loop body.
+      builder.setInsertionPoint(&op);
+
+      // Replace every result that has a known constant value with a constant.
+      bool replacedAll = op.getNumResults() != 0;
+      for (Value res : op.getResults())
+        replacedAll &=
+            succeeded(replaceWithConstant(analysis, builder, folder, res));
+
+      // If all of the results of the operation were replaced, try to erase
+      // the operation completely.
+      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Add all of the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments that have a known constant value.
+    builder.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(analysis, builder, folder, arg);
+  }
+}
+
+namespace {
+struct TestIntRangeInference
+    : PassWrapper<TestIntRangeInference, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
+
+  StringRef getArgument() const final { return "test-int-range-inference"; }
+  StringRef getDescription() const final {
+    return "Test integer range inference analysis";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    IntRangeAnalysis analysis(op); // Constructing the analysis runs it over `op`.
+    rewrite(analysis, op->getContext(), op->getRegions()); // Then replace values with a known constant value.
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestIntRangeInference() {
+  PassRegistration<TestIntRangeInference>();
+}
+} // end namespace test
+} // end namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e75c2758e88db..aa94294b4ea8d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -79,6 +79,7 @@ void registerTestDynamicPipelinePass();
 void registerTestExpandMathPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
+void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
@@ -175,6 +176,7 @@ void registerTestPasses() {
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
+  mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();

diff  --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index 54a6837b0ed88..c4201b135e917 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRInterfacesTests
   ControlFlowInterfacesTest.cpp
   DataLayoutInterfacesTest.cpp
+  InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
 )
 
@@ -10,6 +11,7 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTI
   MLIRFunc
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParser
 )

diff  --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
new file mode 100644
index 0000000000000..97c75b3680567
--- /dev/null
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -0,0 +1,99 @@
+//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/APInt.h"
+#include <limits>
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+TEST(IntRangeAttrs, BasicConstructors) {
+  APInt zero = APInt::getZero(64);
+  APInt two(64, 2);
+  APInt three(64, 3);
+  ConstantIntRanges boundedAbove(zero, two, zero, three); // Argument order: umin, umax, smin, smax.
+  EXPECT_EQ(boundedAbove.umin(), zero);
+  EXPECT_EQ(boundedAbove.umax(), two);
+  EXPECT_EQ(boundedAbove.smin(), zero);
+  EXPECT_EQ(boundedAbove.smax(), three);
+}
+
+TEST(IntRangeAttrs, FromUnsigned) {
+  APInt zero = APInt::getZero(64);
+  APInt maxInt = APInt::getSignedMaxValue(64);
+  APInt minInt = APInt::getSignedMinValue(64);
+  APInt minIntPlusOne = minInt + 1;
+
+  ConstantIntRanges canPortToSigned =
+      ConstantIntRanges::fromUnsigned(zero, maxInt); // Both bounds are non-negative as signed values.
+  EXPECT_EQ(canPortToSigned.smin(), zero);
+  EXPECT_EQ(canPortToSigned.smax(), maxInt);
+
+  ConstantIntRanges cantPortToSigned =
+      ConstantIntRanges::fromUnsigned(zero, minInt); // Bounds differ in sign when read as signed.
+  EXPECT_EQ(cantPortToSigned.smin(), minInt); // Expect the full signed range.
+  EXPECT_EQ(cantPortToSigned.smax(), maxInt);
+
+  ConstantIntRanges signedNegative =
+      ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne); // Both bounds are negative as signed values.
+  EXPECT_EQ(signedNegative.smin(), minInt);
+  EXPECT_EQ(signedNegative.smax(), minIntPlusOne);
+}
+
+TEST(IntRangeAttrs, FromSigned) {
+  APInt zero = APInt::getZero(64);
+  APInt one = zero + 1;
+  APInt negOne = zero - 1;
+  APInt intMax = APInt::getSignedMaxValue(64);
+  APInt intMin = APInt::getSignedMinValue(64);
+  APInt uintMax = APInt::getMaxValue(64);
+
+  ConstantIntRanges noUnsignedBound =
+      ConstantIntRanges::fromSigned(negOne, one); // Bounds differ in sign.
+  EXPECT_EQ(noUnsignedBound.umin(), zero); // Expect the full unsigned range.
+  EXPECT_EQ(noUnsignedBound.umax(), uintMax);
+
+  ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax);
+  EXPECT_EQ(positive.umin(), one);
+  EXPECT_EQ(positive.umax(), intMax);
+
+  ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne); // Both bounds negative: unsigned bounds keep the same values.
+  EXPECT_EQ(negative.umin(), intMin);
+  EXPECT_EQ(negative.umax(), negOne);
+
+  ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one);
+  EXPECT_EQ(preserved.umin(), zero);
+  EXPECT_EQ(preserved.umax(), one);
+}
+
+TEST(IntRangeAttrs, Join) {
+  APInt zero = APInt::getZero(64);
+  APInt one = zero + 1;
+  APInt two = zero + 2;
+  APInt intMin = APInt::getSignedMinValue(64);
+  APInt intMax = APInt::getSignedMaxValue(64);
+  APInt uintMax = APInt::getMaxValue(64);
+
+  ConstantIntRanges maximal(zero, uintMax, intMin, intMax);
+  ConstantIntRanges zeroOne(zero, one, zero, one);
+
+  EXPECT_EQ(zeroOne.rangeUnion(maximal), maximal); // The union is checked in both argument orders.
+  EXPECT_EQ(maximal.rangeUnion(zeroOne), maximal);
+
+  EXPECT_EQ(zeroOne.rangeUnion(zeroOne), zeroOne); // A range united with itself is unchanged.
+
+  ConstantIntRanges oneTwo(one, two, one, two);
+  ConstantIntRanges zeroTwo(zero, two, zero, two);
+  EXPECT_EQ(zeroOne.rangeUnion(oneTwo), zeroTwo);
+
+  ConstantIntRanges zeroOneUnsignedOnly(zero, one, intMin, intMax);
+  ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
+  EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal); // Unsigned and signed bounds are widened independently.
+}


        


More information about the Mlir-commits mailing list