[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Jun 6 18:25:16 PDT 2024


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/94712

When the integer range analysis was first developed, a pass that performed integer range-based constant folding was written and used as a test pass. There was an intent to add such folding to SCCP, but that hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizations, subsuming its previous contents. It also adds an optimization from rocMLIR where `rem{s,u}i` operations that are no-ops are replaced by their left operands.

>From e1c84a3a243c6b8c963fa98f916fc70612f6093c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 7 Jun 2024 00:54:48 +0000
Subject: [PATCH] [mlir][Arith] Generalize and improve -int-range-optimizations

When the integer range analysis was first developed, a pass that did
integer range-based constant folding was written and used as a test
pass. There was an intent to add such folding to SCCP, but that
hasn't happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict
subset of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizations,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are no-ops are replaced by
their left operands.
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   4 -
 .../mlir/Dialect/Arith/Transforms/Passes.td   |   9 +-
 .../Transforms/IntRangeOptimizations.cpp      | 287 ++++++++----------
 .../Dialect/Arith/int-range-interface.mlir    |   2 +-
 mlir/test/Dialect/Arith/int-range-opts.mlir   |  36 +++
 .../test/Dialect/GPU/int-range-interface.mlir |   2 +-
 .../Dialect/Index/int-range-inference.mlir    |   2 +-
 .../infer-int-range-test-ops.mlir             |  10 +-
 mlir/test/lib/Transforms/CMakeLists.txt       |   1 -
 .../lib/Transforms/TestIntRangeInference.cpp  | 125 --------
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 -
 11 files changed, 181 insertions(+), 299 deletions(-)
 delete mode 100644 mlir/test/lib/Transforms/TestIntRangeInference.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..b8a7d0c78d323 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 /// equivalent.
 std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 
-/// Add patterns for int range based optimizations.
-void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
-                                           DataFlowSolver &solver);
-
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   let summary = "Do optimizations based on integer range analysis";
   let description = [{
     This pass runs integer range analysis and apllies optimizations based on its
-    results. e.g. replace arith.cmpi with const if it can be inferred from
-    args ranges.
+    results. It replaces operations with known-constant results with said
+    constants and rewrites `(0 <= %x < D) mod D` to `%x`.
   }];
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
 }
 
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..e991d0fbe7410 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -13,7 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
@@ -24,155 +25,145 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::dataflow;
 
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
-                       const ConstantIntRanges &rhs) {
-  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
-           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// Replace uses of `value` with its inferred constant (patterned after SCCP).
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+                                         RewriterBase &rewriter,
+                                         OperationFolder &folder, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  Value constant = folder.getOrCreateConstant(
+      rewriter.getInsertionBlock(), valueDialect, constAttr, value.getType());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constant)
+    constant = folder.getOrCreateConstant(
+        rewriter.getInsertionBlock(),
+        rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+        value.getType());
+  if (!constant)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constant);
+  return success();
 }
 
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return true;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().slt(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sge(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().sle(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sgt(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ult(rhs.umin()))
-    return true;
-
-  if (lhs.umin().uge(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ule(rhs.umin()))
-    return true;
-
-  if (lhs.umin().ugt(rhs.umax()))
-    return false;
-
+/// Replace each result of `op` that was inferred to be a constant integer
+/// with that constant. Return success() if all results were thus replaced
+/// and the now trivially-dead operation was erased.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            OperationFolder &folder,
+                                            Operation &op) {
+  bool replacedAll = op.getNumResults() != 0;
+  for (Value res : op.getResults())
+    replacedAll &=
+        succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+  // Ops with no results never count as fully replaced. If every result was
+  // replaced and the op would now be trivially dead, erase it.
+  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+    assert(op.use_empty() && "expected all uses to be replaced");
+    rewriter.eraseOp(&op);
+    return success();
+  }
   return failure();
 }
 
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUlt(std::move(rhs), std::move(lhs));
+/// Replace `rem{s,u}i %x, C` by `%x` when the inferred range of `%x` proves
+/// 0 <= %x < C, in which case the remainder is a no-op.
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            Operation &op) {
+  if (!isa<RemSIOp, RemUIOp>(op))
+    return failure();
+  Value lhs = op.getOperand(0);
+  Value rhs = op.getOperand(1);
+  // Match index as well as integer constants, keeping wide values as APInt.
+  APInt modulus;
+  if (!matchPattern(rhs, m_ConstantInt(&modulus)))
+    return failure();
+  if (modulus.isNegative() || modulus.isZero())
+    return failure();
+  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+  const APInt &min = llvm::isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+  const APInt &max = llvm::isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+  // The bounds are inclusive, so both must be non-negative and strictly less
+  // than the modulus.
+  if (min.isNegative() || min.uge(modulus))
+    return failure();
+  if (max.isNegative() || max.uge(modulus))
+    return failure();
+  if (!min.ule(max))
+    return failure();
+
+  // With all those conditions out of the way, we know that this invocation of
+  // a remainder is a noop because the input is strictly within the range
+  // [0, modulus), so get rid of it.
+  rewriter.replaceOp(&op, ValueRange{lhs});
+  return success();
 }
 
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUle(std::move(rhs), std::move(lhs));
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+                       MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  IRRewriter rewriter(context);
+  OperationFolder folder(context, rewriter.getListener());
+  // Blocks are pushed in reverse, so each region's blocks pop in order.
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    // Early-inc iteration lets the rewrites below erase `op` safely.
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      if (matchPattern(&op, m_Constant())) {
+        if (auto arithConstant = dyn_cast<ConstantOp>(op))
+          folder.insertKnownConstant(&op, arithConstant.getValue());
+        else
+          folder.insertKnownConstant(&op);
+        continue;
+      }
+      rewriter.setInsertionPoint(&op);
+
+      // Try rewrites. Success means that the underlying operation was erased.
+      if (succeeded(foldResultsToConstants(solver, rewriter, folder, op)))
+        continue;
+      if (isa<RemSIOp, RemUIOp>(op) &&
+          succeeded(deleteTrivialRemainder(solver, rewriter, op)))
+        continue;
+      // Otherwise, queue this operation's regions for processing.
+      addToWorklist(op.getRegions());
+    }
+
+    // Finally, replace block arguments inferred to be constants.
+    rewriter.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(solver, rewriter, folder, arg);
+  }
 }
 
 namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
-class DataFlowListener : public RewriterBase::Listener {
-public:
-  DataFlowListener(DataFlowSolver &s) : s(s) {}
-
-protected:
-  void notifyOperationErased(Operation *op) override {
-    s.eraseState(op);
-    for (Value res : op->getResults())
-      s.eraseState(res);
-  }
-
-  DataFlowSolver &s;
-};
-
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
-
-  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
-      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
-  LogicalResult matchAndRewrite(arith::CmpIOp op,
-                                PatternRewriter &rewriter) const override {
-    auto *lhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
-    if (!lhsResult || lhsResult->getValue().isUninitialized())
-      return failure();
-
-    auto *rhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
-    if (!rhsResult || rhsResult->getValue().isUninitialized())
-      return failure();
-
-    using HandlerFunc =
-        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
-    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
-        handlers{};
-    using Pred = arith::CmpIPredicate;
-    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
-    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
-    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
-    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
-    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
-    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
-    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
-    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
-    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
-    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
-    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
-    if (!handler)
-      return failure();
-
-    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
-    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
-    FailureOr<bool> result = handler(lhsValue, rhsValue);
-
-    if (failed(result))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
-        op, static_cast<int64_t>(*result), /*width*/ 1);
-    return success();
-  }
-
-private:
-  DataFlowSolver &solver;
-};
-
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 
-    DataFlowListener listener(solver);
-
-    RewritePatternSet patterns(ctx);
-    populateIntRangeOptimizationsPatterns(patterns, solver);
-
-    GreedyRewriteConfig config;
-    config.listener = &listener;
-
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-      signalPassFailure();
+    doRewrites(solver, ctx, op->getRegions());
   }
 };
 } // namespace
 
-void mlir::arith::populateIntRangeOptimizationsPatterns(
-    RewritePatternSet &patterns, DataFlowSolver &solver) {
-  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
-}
-
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @add_min_max
 // CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
   return %1: i8
 }
 
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+  %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+  %mod = arith.remui %val, %rhs : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @launch_func
 func.func @launch_func(%arg0 : index) {
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // Most operations are covered by the `arith` tests, which use the same code
 // Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
 
 // CHECK-LABEL: func @constant
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop_false()
 func.func @propagate_across_while_loop_false() -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%false) %0 : index
   } do {
   ^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop
 func.func @propagate_across_while_loop(%arg0 : i1) -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
   } do {
   ^bb0(%i1: index):
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
-  TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
   ${MLIRTestTransformsPDLSrc}
 
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include <optional>
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-/// Patterned after SCCP
-static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
-                                         OperationFolder &folder, Value value) {
-  auto *maybeInferredRange =
-      solver.lookupState<IntegerValueRangeLattice>(value);
-  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-    return failure();
-  const ConstantIntRanges &inferredRange =
-      maybeInferredRange->getValue().getValue();
-  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
-  if (!maybeConstValue.has_value())
-    return failure();
-
-  Operation *maybeDefiningOp = value.getDefiningOp();
-  Dialect *valueDialect =
-      maybeDefiningOp ? maybeDefiningOp->getDialect()
-                      : value.getParentRegion()->getParentOp()->getDialect();
-  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
-  Value constant = folder.getOrCreateConstant(
-      b.getInsertionBlock(), valueDialect, constAttr, value.getType());
-  if (!constant)
-    return failure();
-
-  value.replaceAllUsesWith(constant);
-  return success();
-}
-
-static void rewrite(DataFlowSolver &solver, MLIRContext *context,
-                    MutableArrayRef<Region> initialRegions) {
-  SmallVector<Block *> worklist;
-  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
-    for (Region &region : regions)
-      for (Block &block : llvm::reverse(region))
-        worklist.push_back(&block);
-  };
-
-  OpBuilder builder(context);
-  OperationFolder folder(context);
-
-  addToWorklist(initialRegions);
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    for (Operation &op : llvm::make_early_inc_range(*block)) {
-      builder.setInsertionPoint(&op);
-
-      // Replace any result with constants.
-      bool replacedAll = op.getNumResults() != 0;
-      for (Value res : op.getResults())
-        replacedAll &=
-            succeeded(replaceWithConstant(solver, builder, folder, res));
-
-      // If all of the results of the operation were replaced, try to erase
-      // the operation completely.
-      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
-        assert(op.use_empty() && "expected all uses to be replaced");
-        op.erase();
-        continue;
-      }
-
-      // Add any the regions of this operation to the worklist.
-      addToWorklist(op.getRegions());
-    }
-
-    // Replace any block arguments with constants.
-    builder.setInsertionPointToStart(block);
-    for (BlockArgument arg : block->getArguments())
-      (void)replaceWithConstant(solver, builder, folder, arg);
-  }
-}
-
-namespace {
-struct TestIntRangeInference
-    : PassWrapper<TestIntRangeInference, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
-
-  StringRef getArgument() const final { return "test-int-range-inference"; }
-  StringRef getDescription() const final {
-    return "Test integer range inference analysis";
-  }
-
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    DataFlowSolver solver;
-    solver.load<DeadCodeAnalysis>();
-    solver.load<SparseConstantPropagation>();
-    solver.load<IntegerRangeAnalysis>();
-    if (failed(solver.initializeAndRun(op)))
-      return signalPassFailure();
-    rewrite(solver, op->getContext(), op->getRegions());
-  }
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestIntRangeInference() {
-  PassRegistration<TestIntRangeInference>();
-}
-} // end namespace test
-} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..b50cae1056ba4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -97,7 +97,6 @@ void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
-void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestInterfaces();
@@ -226,7 +225,6 @@ void registerTestPasses() {
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
-  mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();



More information about the Mlir-commits mailing list