[Mlir-commits] [mlir] 4722911 - [mlir][Arith] Generalize and improve -int-range-optimizations (#94712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 10 07:56:37 PDT 2024


Author: Krzysztof Drewniak
Date: 2024-06-10T09:56:33-05:00
New Revision: 472291111d9135961305afebe4e283e3e4e7eebc

URL: https://github.com/llvm/llvm-project/commit/472291111d9135961305afebe4e283e3e4e7eebc
DIFF: https://github.com/llvm/llvm-project/commit/472291111d9135961305afebe4e283e3e4e7eebc.diff

LOG: [mlir][Arith] Generalize and improve -int-range-optimizations (#94712)

When the integer range analysis was first developed, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizations,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
    mlir/test/Dialect/Arith/int-range-interface.mlir
    mlir/test/Dialect/Arith/int-range-opts.mlir
    mlir/test/Dialect/GPU/int-range-interface.mlir
    mlir/test/Dialect/Index/int-range-inference.mlir
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/test/lib/Transforms/TestIntRangeInference.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   let summary = "Do optimizations based on integer range analysis";
   let description = [{
     This pass runs integer range analysis and apllies optimizations based on its
-    results. e.g. replace arith.cmpi with const if it can be inferred from
-    args ranges.
+    results. It replaces operations with known-constant results with said constants,
+    rewrites `(0 <= %x < D) mod D` to `%x`.
   }];
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
 }
 
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {

diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..8005f9103b235 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,11 +8,17 @@
 
 #include <utility>
 
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::arith {
@@ -24,88 +30,50 @@ using namespace mlir;
 using namespace mlir::arith;
 using namespace mlir::dataflow;
 
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
-                       const ConstantIntRanges &rhs) {
-  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
-           (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
+                                                  Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return std::nullopt;
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.getConstantValue();
 }
 
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (!intersects(lhs, rhs))
-    return true;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().slt(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sge(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.smax().sle(rhs.smin()))
-    return true;
-
-  if (lhs.smin().sgt(rhs.smax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ult(rhs.umin()))
-    return true;
-
-  if (lhs.umin().uge(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  if (lhs.umax().ule(rhs.umin()))
-    return true;
-
-  if (lhs.umin().ugt(rhs.umax()))
-    return false;
-
-  return failure();
-}
-
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
-  return handleUle(std::move(rhs), std::move(lhs));
+/// Replace uses of `value` with its inferred constant (patterned after SCCP).
+static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
+                                              PatternRewriter &rewriter,
+                                              Value value) {
+  if (value.use_empty())
+    return failure();
+  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  Operation *constOp = valueDialect->materializeConstant(
+      rewriter, constAttr, value.getType(), value.getLoc());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constOp)
+    constOp = rewriter.getContext()
+                  ->getLoadedDialect<ArithDialect>()
+                  ->materializeConstant(rewriter, constAttr, value.getType(),
+                                        value.getLoc());
+  if (!constOp)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
+  return success();
 }
 
 namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
 class DataFlowListener : public RewriterBase::Listener {
 public:
   DataFlowListener(DataFlowSolver &s) : s(s) {}
@@ -120,52 +88,95 @@ class DataFlowListener : public RewriterBase::Listener {
   DataFlowSolver &s;
 };
 
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+/// Replace the uses of any results of `op` that were inferred to be constant
+/// integers with that constant. If all results were thus replaced and the
+/// operation is trivially dead, it is erased. Also replace any block arguments
+/// of the op's regions with their constant values.
+struct MaterializeKnownConstantValues : public RewritePattern {
+  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
+      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
+        solver(s) {}
+
+  LogicalResult match(Operation *op) const override {
+    if (matchPattern(op, m_Constant()))
+      return failure();
 
-  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
-      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
+    auto needsReplacing = [&](Value v) {
+      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+    };
+    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
+    if (op->getNumRegions() == 0)
+      return success(hasConstantResults);
+    bool hasConstantRegionArgs = false;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        hasConstantRegionArgs |=
+            llvm::any_of(block.getArguments(), needsReplacing);
+      }
+    }
+    return success(hasConstantResults || hasConstantRegionArgs);
+  }
 
-  LogicalResult matchAndRewrite(arith::CmpIOp op,
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    bool replacedAll = (op->getNumResults() != 0);
+    for (Value v : op->getResults())
+      replacedAll &=
+          (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
+           v.use_empty());
+    if (replacedAll && isOpTriviallyDead(op)) {
+      rewriter.eraseOp(op);
+      return;
+    }
+
+    PatternRewriter::InsertionGuard guard(rewriter);
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        rewriter.setInsertionPointToStart(&block);
+        for (BlockArgument &arg : block.getArguments()) {
+          (void)maybeReplaceWithConstant(solver, rewriter, arg);
+        }
+      }
+    }
+  }
+
+private:
+  DataFlowSolver &solver;
+};
+
+template <typename RemOp>
+struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
+  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<RemOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(RemOp op,
                                 PatternRewriter &rewriter) const override {
-    auto *lhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
-    if (!lhsResult || lhsResult->getValue().isUninitialized())
+    Value lhs = op.getOperand(0);
+    Value rhs = op.getOperand(1);
+    auto maybeModulus = getConstantIntValue(rhs);
+    if (!maybeModulus.has_value())
       return failure();
-
-    auto *rhsResult =
-        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
-    if (!rhsResult || rhsResult->getValue().isUninitialized())
+    int64_t modulus = *maybeModulus;
+    if (modulus <= 0)
       return failure();
-
-    using HandlerFunc =
-        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
-    std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
-        handlers{};
-    using Pred = arith::CmpIPredicate;
-    handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
-    handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
-    handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
-    handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
-    handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
-    handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
-    handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
-    handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
-    handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
-    handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
-    HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
-    if (!handler)
+    auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+    if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
       return failure();
-
-    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
-    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
-    FailureOr<bool> result = handler(lhsValue, rhsValue);
-
-    if (failed(result))
+    const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+    const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+    const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+    // The minima and maxima here are given as closed ranges, so both must be
+    // non-negative and strictly less than the modulus.
+    if (min.isNegative() || min.uge(modulus))
+      return failure();
+    if (max.isNegative() || max.uge(modulus))
+      return failure();
+    if (!min.ule(max))
       return failure();
 
-    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
-        op, static_cast<int64_t>(*result), /*width*/ 1);
+    // With all those conditions out of the way, we know that this invocation of
+    // a remainder is a noop because the input is strictly within the range
+    // [0, modulus), so get rid of it.
+    rewriter.replaceOp(op, ValueRange{lhs});
     return success();
   }
 
@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
 
 void mlir::arith::populateIntRangeOptimizationsPatterns(
     RewritePatternSet &patterns, DataFlowSolver &solver) {
-  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
+  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
+               DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

diff  --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @add_min_max
 // CHECK: %[[c3:.*]] = arith.constant 3 : index

diff  --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
   return %1: i8
 }
 
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+  %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+  %mod = arith.remui %val, %rhs : i8
+  return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+  %c64 = arith.constant 64 : i8
+  %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+  %mod = arith.remsi %val, %c64 : i8
+  return %mod : i8
+}

diff  --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @launch_func
 func.func @launch_func(%arg0 : index) {

diff  --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
 
 // Most operations are covered by the `arith` tests, which use the same code
 // Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
 
 // CHECK-LABEL: func @constant
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop_false()
 func.func @propagate_across_while_loop_false() -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%false) %0 : index
   } do {
   ^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
 
 // CHECK-LABEL: func @propagate_across_while_loop
 func.func @propagate_across_while_loop(%arg0 : i1) -> index {
-  // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
-  // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+  // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
                           smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
-    // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
   } do {
   ^bb0(%i1: index):

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
-  TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
   ${MLIRTestTransformsPDLSrc}
 

diff  --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include <optional>
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-/// Patterned after SCCP
-static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
-                                         OperationFolder &folder, Value value) {
-  auto *maybeInferredRange =
-      solver.lookupState<IntegerValueRangeLattice>(value);
-  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-    return failure();
-  const ConstantIntRanges &inferredRange =
-      maybeInferredRange->getValue().getValue();
-  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
-  if (!maybeConstValue.has_value())
-    return failure();
-
-  Operation *maybeDefiningOp = value.getDefiningOp();
-  Dialect *valueDialect =
-      maybeDefiningOp ? maybeDefiningOp->getDialect()
-                      : value.getParentRegion()->getParentOp()->getDialect();
-  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
-  Value constant = folder.getOrCreateConstant(
-      b.getInsertionBlock(), valueDialect, constAttr, value.getType());
-  if (!constant)
-    return failure();
-
-  value.replaceAllUsesWith(constant);
-  return success();
-}
-
-static void rewrite(DataFlowSolver &solver, MLIRContext *context,
-                    MutableArrayRef<Region> initialRegions) {
-  SmallVector<Block *> worklist;
-  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
-    for (Region &region : regions)
-      for (Block &block : llvm::reverse(region))
-        worklist.push_back(&block);
-  };
-
-  OpBuilder builder(context);
-  OperationFolder folder(context);
-
-  addToWorklist(initialRegions);
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    for (Operation &op : llvm::make_early_inc_range(*block)) {
-      builder.setInsertionPoint(&op);
-
-      // Replace any result with constants.
-      bool replacedAll = op.getNumResults() != 0;
-      for (Value res : op.getResults())
-        replacedAll &=
-            succeeded(replaceWithConstant(solver, builder, folder, res));
-
-      // If all of the results of the operation were replaced, try to erase
-      // the operation completely.
-      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
-        assert(op.use_empty() && "expected all uses to be replaced");
-        op.erase();
-        continue;
-      }
-
-      // Add any the regions of this operation to the worklist.
-      addToWorklist(op.getRegions());
-    }
-
-    // Replace any block arguments with constants.
-    builder.setInsertionPointToStart(block);
-    for (BlockArgument arg : block->getArguments())
-      (void)replaceWithConstant(solver, builder, folder, arg);
-  }
-}
-
-namespace {
-struct TestIntRangeInference
-    : PassWrapper<TestIntRangeInference, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
-
-  StringRef getArgument() const final { return "test-int-range-inference"; }
-  StringRef getDescription() const final {
-    return "Test integer range inference analysis";
-  }
-
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    DataFlowSolver solver;
-    solver.load<DeadCodeAnalysis>();
-    solver.load<SparseConstantPropagation>();
-    solver.load<IntegerRangeAnalysis>();
-    if (failed(solver.initializeAndRun(op)))
-      return signalPassFailure();
-    rewrite(solver, op->getContext(), op->getRegions());
-  }
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestIntRangeInference() {
-  PassRegistration<TestIntRangeInference>();
-}
-} // end namespace test
-} // end namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d2ba3d06835fb..d0de74dd6eaf4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -97,9 +97,8 @@ void registerTestDynamicPipelinePass();
 void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestInterfaces();
-void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
 void registerTestLastModifiedPass();
 void registerTestLinalgDecomposeOps();
@@ -226,9 +225,8 @@ void registerTestPasses() {
   mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
-  mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgDecomposeOps();


        


More information about the Mlir-commits mailing list