[Mlir-commits] [mlir] [mlir][Arith] Generalize and improve -int-range-optimizations (PR #94712)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Jun 6 18:25:16 PDT 2024
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/94712
When the integer range analysis was first developed, a pass that did integer range-based constant folding was developed and used as a test pass. There was an intent to add such a folding to SCCP, but that hasn't happened.
Meanwhile, -int-range-optimizations was added to the arith dialect's transformations. The cmpi simplification in that pass is a strict subset of the constant folding that lived in
-test-int-range-inference.
This commit moves the former test pass into -int-range-optimizations, subsuming its previous contents. It also adds an optimization from rocMLIR where `rem{s,u}i` operations that are no-ops are replaced by their left operands.
>From e1c84a3a243c6b8c963fa98f916fc70612f6093c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 7 Jun 2024 00:54:48 +0000
Subject: [PATCH] [mlir][Arith] Generalize and improve -int-range-optimizations
When the integer range analysis was first developed, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that
hasn't happened.
Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict
subset of the constant folding that lived in
-test-int-range-inference.
This commit moves the former test pass into -int-range-optimizations,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are no-ops are replaced by
their left operands.
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 4 -
.../mlir/Dialect/Arith/Transforms/Passes.td | 9 +-
.../Transforms/IntRangeOptimizations.cpp | 287 ++++++++----------
.../Dialect/Arith/int-range-interface.mlir | 2 +-
mlir/test/Dialect/Arith/int-range-opts.mlir | 36 +++
.../test/Dialect/GPU/int-range-interface.mlir | 2 +-
.../Dialect/Index/int-range-inference.mlir | 2 +-
.../infer-int-range-test-ops.mlir | 10 +-
mlir/test/lib/Transforms/CMakeLists.txt | 1 -
.../lib/Transforms/TestIntRangeInference.cpp | 125 --------
mlir/tools/mlir-opt/mlir-opt.cpp | 2 -
11 files changed, 181 insertions(+), 299 deletions(-)
delete mode 100644 mlir/test/lib/Transforms/TestIntRangeInference.cpp
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..b8a7d0c78d323 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -64,10 +64,6 @@ void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
/// equivalent.
std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
-/// Add patterns for int range based optimizations.
-void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
- DataFlowSolver &solver);
-
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
let summary = "Do optimizations based on integer range analysis";
let description = [{
This pass runs integer range analysis and apllies optimizations based on its
- results. e.g. replace arith.cmpi with const if it can be inferred from
- args ranges.
+ results. It replaces operations with known-constant results with said constants
+ and rewrites `(0 <= %x < D) mod D` to `%x`.
}];
+ // Explicitly depend on "arith" because this pass could create operations in
+ // `arith` out of thin air in some cases.
+ let dependentDialects = [
+ "::mlir::arith::ArithDialect"
+ ];
}
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..e991d0fbe7410 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -13,7 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/FoldUtils.h"
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
@@ -24,155 +25,145 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs) {
- return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
- (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+/// Replace all uses of `value` with a constant if the integer range analysis
+/// proved that it can only hold a single value. Patterned after SCCP.
+///
+/// The constant is created through `folder`, which uniques constants, using
+/// the dialect of the value's defining operation (or of the parent operation
+/// of a block argument). If that dialect cannot materialize an integer
+/// constant, an `arith.constant` is created instead.
+///
+/// Returns success() if the uses of `value` were replaced, and failure() if no
+/// single value is known for `value` or no constant could be materialized.
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
+                                         RewriterBase &rewriter,
+                                         OperationFolder &folder, Value value) {
+  // The constant is built as an IntegerAttr below, which only accepts integer
+  // and index types; skip anything else instead of building an invalid
+  // attribute.
+  if (!value.getType().isIntOrIndex())
+    return failure();
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
+  if (!maybeConstValue.has_value())
+    return failure();
+
+  Operation *maybeDefiningOp = value.getDefiningOp();
+  Dialect *valueDialect =
+      maybeDefiningOp ? maybeDefiningOp->getDialect()
+                      : value.getParentRegion()->getParentOp()->getDialect();
+  Attribute constAttr =
+      rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+  // Operations from unregistered dialects report a null dialect, which must
+  // not be handed to the folder's constant materialization; go straight to
+  // the arith fallback for them.
+  Value constant;
+  if (valueDialect)
+    constant = folder.getOrCreateConstant(rewriter.getInsertionBlock(),
+                                          valueDialect, constAttr,
+                                          value.getType());
+  // Fall back to arith.constant if the dialect materializer doesn't know what
+  // to do with an integer constant.
+  if (!constant)
+    constant = folder.getOrCreateConstant(
+        rewriter.getInsertionBlock(),
+        rewriter.getContext()->getLoadedDialect<ArithDialect>(), constAttr,
+        value.getType());
+  if (!constant)
+    return failure();
+
+  rewriter.replaceAllUsesWith(value, constant);
+  return success();
}
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return true;
-
- return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().slt(rhs.smin()))
- return true;
-
- if (lhs.smin().sge(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().sle(rhs.smin()))
- return true;
-
- if (lhs.smin().sgt(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ult(rhs.umin()))
- return true;
-
- if (lhs.umin().uge(rhs.umax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ule(rhs.umin()))
- return true;
-
- if (lhs.umin().ugt(rhs.umax()))
- return false;
-
+/// Replace the uses of every result of `op` that was inferred to be a constant
+/// integer with that constant. If all results were replaced and `op` is
+/// trivially dead afterwards, erase it.
+///
+/// Returns success() only if `op` was erased; callers must not touch `op`
+/// afterwards in that case. Returns failure() otherwise, even if some or all
+/// of its results had their uses replaced.
+static LogicalResult foldResultsToConstants(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            OperationFolder &folder,
+                                            Operation &op) {
+  // An operation without results has nothing to fold, so it never counts as
+  // "fully replaced".
+  bool replacedAll = op.getNumResults() != 0;
+  // Deliberately a non-short-circuiting `&=`: every result is given a chance
+  // to be replaced even after an earlier one failed.
+  for (Value res : op.getResults())
+    replacedAll &=
+        succeeded(replaceWithConstant(solver, rewriter, folder, res));
+
+  // If all of the results of the operation were replaced, try to erase
+  // the operation completely.
+  if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+    assert(op.use_empty() && "expected all uses to be replaced");
+    rewriter.eraseOp(&op);
+    return success();
+  }
return failure();
}
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUlt(std::move(rhs), std::move(lhs));
+/// Replace an `arith.remsi`/`arith.remui` whose result always equals its left
+/// operand with that operand. This is the case when the right operand is a
+/// constant modulus `D` and the inferred range of the left operand `x` lies
+/// within [0, D):
+///   - `remui`: D != 0 and umin(x) <=u umax(x) <u D.
+///   - `remsi`: D >s 0 and 0 <=s smin(x) <=s smax(x) <s D.
+/// Correctness relies on the integer range analysis producing sound bounds.
+/// Returns success() if `op` was replaced (and therefore erased).
+static LogicalResult deleteTrivialRemainder(DataFlowSolver &solver,
+                                            RewriterBase &rewriter,
+                                            Operation &op) {
+  if (!isa<RemSIOp, RemUIOp>(op))
+    return failure();
+  Value lhs = op.getOperand(0);
+  Value rhs = op.getOperand(1);
+  // Only scalar integers/indices: ranges for other types are not assumed to
+  // describe individual elements.
+  if (!lhs.getType().isIntOrIndex())
+    return failure();
+  // Match any integer or index constant (not only signless-integer
+  // `arith.constant`s) and keep it as an APInt, so moduli that do not fit in
+  // an int64_t, or that have the sign bit set, are handled without asserting.
+  APInt modulus;
+  if (!matchPattern(rhs, m_ConstantInt(&modulus)))
+    return failure();
+  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
+    return failure();
+  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+  const bool isUnsigned = isa<RemUIOp>(op);
+  const APInt &min = isUnsigned ? lhsRange.umin() : lhsRange.smin();
+  const APInt &max = isUnsigned ? lhsRange.umax() : lhsRange.smax();
+  // APInt comparisons require equal bit widths; bail out instead of asserting
+  // if the analysis tracked a different width than the constant carries.
+  if (modulus.getBitWidth() != min.getBitWidth() ||
+      modulus.getBitWidth() != max.getBitWidth())
+    return failure();
+  // The minima and maxima here are given as closed ranges, so `max` must be
+  // strictly less than the modulus. Compare with the op's signedness.
+  if (isUnsigned) {
+    if (modulus.isZero() || min.ugt(max) || !max.ult(modulus))
+      return failure();
+  } else {
+    if (!modulus.isStrictlyPositive() || min.isNegative() || min.sgt(max) ||
+        !max.slt(modulus))
+      return failure();
+  }
+
+  // With all those conditions out of the way, we know that this invocation of
+  // a remainder is a noop because the input is strictly within the range
+  // [0, modulus), so get rid of it.
+  rewriter.replaceOp(&op, ValueRange{lhs});
+  return success();
}
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUle(std::move(rhs), std::move(lhs));
+/// Visit every block of `initialRegions` and of all regions nested in them and
+/// apply the range-based rewrites to each operation: results inferred to be
+/// constant are replaced with constants (erasing the operation when possible),
+/// no-op remainders are removed, and block arguments inferred to be constant
+/// are replaced with constants.
+static void doRewrites(DataFlowSolver &solver, MLIRContext *context,
+                       MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *> worklist;
+  // Blocks are pushed in reverse so that the blocks of each region are popped
+  // off the back of the worklist in their original order.
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : llvm::reverse(region))
+        worklist.push_back(&block);
+  };
+
+  IRRewriter rewriter(context);
+  OperationFolder folder(context, rewriter.getListener());
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    // Early-increment iteration: the current operation may be erased below.
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      // Register pre-existing constants with the folder so that materialized
+      // constants are uniqued against them rather than duplicated.
+      if (matchPattern(&op, m_Constant())) {
+        if (auto arithConstant = dyn_cast<ConstantOp>(op))
+          folder.insertKnownConstant(&op, arithConstant.getValue());
+        else
+          folder.insertKnownConstant(&op);
+        continue;
+      }
+      rewriter.setInsertionPoint(&op);
+
+      // Try rewrites. Success means that the underlying operation was erased.
+      if (succeeded(foldResultsToConstants(solver, rewriter, folder, op)))
+        continue;
+      // deleteTrivialRemainder() filters for remsi/remui itself.
+      if (succeeded(deleteTrivialRemainder(solver, rewriter, op)))
+        continue;
+      // Add the regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+
+    // Replace any block arguments with constants.
+    rewriter.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      (void)replaceWithConstant(solver, rewriter, folder, arg);
+  }
}
namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
-class DataFlowListener : public RewriterBase::Listener {
-public:
- DataFlowListener(DataFlowSolver &s) : s(s) {}
-
-protected:
- void notifyOperationErased(Operation *op) override {
- s.eraseState(op);
- for (Value res : op->getResults())
- s.eraseState(res);
- }
-
- DataFlowSolver &s;
-};
-
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
-
- ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
- : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
-
- LogicalResult matchAndRewrite(arith::CmpIOp op,
- PatternRewriter &rewriter) const override {
- auto *lhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
- if (!lhsResult || lhsResult->getValue().isUninitialized())
- return failure();
-
- auto *rhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
- if (!rhsResult || rhsResult->getValue().isUninitialized())
- return failure();
-
- using HandlerFunc =
- FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
- std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
- handlers{};
- using Pred = arith::CmpIPredicate;
- handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
- handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
- handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
- handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
- handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
- handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
- handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
- handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
- handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
- handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
- HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
- if (!handler)
- return failure();
-
- ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
- ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
- FailureOr<bool> result = handler(lhsValue, rhsValue);
-
- if (failed(result))
- return failure();
-
- rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
- op, static_cast<int64_t>(*result), /*width*/ 1);
- return success();
- }
-
-private:
- DataFlowSolver &solver;
-};
-
struct IntRangeOptimizationsPass
: public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
@@ -185,25 +176,11 @@ struct IntRangeOptimizationsPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
- DataFlowListener listener(solver);
-
- RewritePatternSet patterns(ctx);
- populateIntRangeOptimizationsPatterns(patterns, solver);
-
- GreedyRewriteConfig config;
- config.listener = &listener;
-
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- signalPassFailure();
+ doRewrites(solver, ctx, op->getRegions());
}
};
} // namespace
-void mlir::arith::populateIntRangeOptimizationsPatterns(
- RewritePatternSet &patterns, DataFlowSolver &solver) {
- patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
-}
-
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
return std::make_unique<IntRangeOptimizationsPass>();
}
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @add_min_max
// CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
return %1: i8
}
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+ %mod = arith.remsi %val, %c64 : i8
+ return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+ %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+ %mod = arith.remui %val, %rhs : i8
+ return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+ %mod = arith.remsi %val, %c64 : i8
+ return %mod : i8
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @launch_func
func.func @launch_func(%arg0 : index) {
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// Most operations are covered by the `arith` tests, which use the same code
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+ // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
%false = arith.constant false
- // CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%false) %0 : index
} do {
^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+ // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
- // CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%arg0) %0 : index
} do {
^bb0(%i1: index):
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
- TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
${MLIRTestTransformsPDLSrc}
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include <optional>
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-/// Patterned after SCCP
-static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
- OperationFolder &folder, Value value) {
- auto *maybeInferredRange =
- solver.lookupState<IntegerValueRangeLattice>(value);
- if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
- return failure();
- const ConstantIntRanges &inferredRange =
- maybeInferredRange->getValue().getValue();
- std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
- if (!maybeConstValue.has_value())
- return failure();
-
- Operation *maybeDefiningOp = value.getDefiningOp();
- Dialect *valueDialect =
- maybeDefiningOp ? maybeDefiningOp->getDialect()
- : value.getParentRegion()->getParentOp()->getDialect();
- Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
- Value constant = folder.getOrCreateConstant(
- b.getInsertionBlock(), valueDialect, constAttr, value.getType());
- if (!constant)
- return failure();
-
- value.replaceAllUsesWith(constant);
- return success();
-}
-
-static void rewrite(DataFlowSolver &solver, MLIRContext *context,
- MutableArrayRef<Region> initialRegions) {
- SmallVector<Block *> worklist;
- auto addToWorklist = [&](MutableArrayRef<Region> regions) {
- for (Region ®ion : regions)
- for (Block &block : llvm::reverse(region))
- worklist.push_back(&block);
- };
-
- OpBuilder builder(context);
- OperationFolder folder(context);
-
- addToWorklist(initialRegions);
- while (!worklist.empty()) {
- Block *block = worklist.pop_back_val();
-
- for (Operation &op : llvm::make_early_inc_range(*block)) {
- builder.setInsertionPoint(&op);
-
- // Replace any result with constants.
- bool replacedAll = op.getNumResults() != 0;
- for (Value res : op.getResults())
- replacedAll &=
- succeeded(replaceWithConstant(solver, builder, folder, res));
-
- // If all of the results of the operation were replaced, try to erase
- // the operation completely.
- if (replacedAll && wouldOpBeTriviallyDead(&op)) {
- assert(op.use_empty() && "expected all uses to be replaced");
- op.erase();
- continue;
- }
-
- // Add any the regions of this operation to the worklist.
- addToWorklist(op.getRegions());
- }
-
- // Replace any block arguments with constants.
- builder.setInsertionPointToStart(block);
- for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(solver, builder, folder, arg);
- }
-}
-
-namespace {
-struct TestIntRangeInference
- : PassWrapper<TestIntRangeInference, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
-
- StringRef getArgument() const final { return "test-int-range-inference"; }
- StringRef getDescription() const final {
- return "Test integer range inference analysis";
- }
-
- void runOnOperation() override {
- Operation *op = getOperation();
- DataFlowSolver solver;
- solver.load<DeadCodeAnalysis>();
- solver.load<SparseConstantPropagation>();
- solver.load<IntegerRangeAnalysis>();
- if (failed(solver.initializeAndRun(op)))
- return signalPassFailure();
- rewrite(solver, op->getContext(), op->getRegions());
- }
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestIntRangeInference() {
- PassRegistration<TestIntRangeInference>();
-}
-} // end namespace test
-} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..b50cae1056ba4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -97,7 +97,6 @@ void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
-void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestInterfaces();
@@ -226,7 +225,6 @@ void registerTestPasses() {
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
- mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
More information about the Mlir-commits
mailing list