[Mlir-commits] [mlir] 4722911 - [mlir][Arith] Generalize and improve -int-range-optimizations (#94712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 10 07:56:37 PDT 2024
Author: Krzysztof Drewniak
Date: 2024-06-10T09:56:33-05:00
New Revision: 472291111d9135961305afebe4e283e3e4e7eebc
URL: https://github.com/llvm/llvm-project/commit/472291111d9135961305afebe4e283e3e4e7eebc
DIFF: https://github.com/llvm/llvm-project/commit/472291111d9135961305afebe4e283e3e4e7eebc.diff
LOG: [mlir][Arith] Generalize and improve -int-range-optimizations (#94712)
When the integer range analysis was first developed, a pass that performed
integer range-based constant folding was written and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.
Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.
This commit moves the former test pass into -int-range-optimizations,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/test/Dialect/Arith/int-range-interface.mlir
mlir/test/Dialect/Arith/int-range-opts.mlir
mlir/test/Dialect/GPU/int-range-interface.mlir
mlir/test/Dialect/Index/int-range-inference.mlir
mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/test/lib/Transforms/TestIntRangeInference.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..1517f71f1a7c9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
let summary = "Do optimizations based on integer range analysis";
let description = [{
- This pass runs integer range analysis and apllies optimizations based on its
- results. e.g. replace arith.cmpi with const if it can be inferred from
- args ranges.
+ This pass runs integer range analysis and applies optimizations based on its
+ results. It replaces operations with known-constant results with said constants
+ and rewrites `(0 <= %x < D) mod D` to `%x`.
}];
+ // Explicitly depend on "arith": if a value's own dialect cannot materialize an
+ // integer constant, the pass falls back to creating `arith.constant`.
+ let dependentDialects = [
+ "::mlir::arith::ArithDialect"
+ ];
}
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 2473169962b95..8005f9103b235 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,11 +8,17 @@
#include <utility>
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::arith {
@@ -24,88 +30,50 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
-/// Returns true if 2 integer ranges have intersection.
-static bool intersects(const ConstantIntRanges &lhs,
- const ConstantIntRanges &rhs) {
- return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
- (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
+static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
+ Value value) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return std::nullopt;
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+ return inferredRange.getConstantValue();
}
-static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (!intersects(lhs, rhs))
- return true;
-
- return failure();
-}
-
-static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().slt(rhs.smin()))
- return true;
-
- if (lhs.smin().sge(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.smax().sle(rhs.smin()))
- return true;
-
- if (lhs.smin().sgt(rhs.smax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleSle(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ult(rhs.umin()))
- return true;
-
- if (lhs.umin().uge(rhs.umax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- if (lhs.umax().ule(rhs.umin()))
- return true;
-
- if (lhs.umin().ugt(rhs.umax()))
- return false;
-
- return failure();
-}
-
-static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUlt(std::move(rhs), std::move(lhs));
-}
-
-static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
- return handleUle(std::move(rhs), std::move(lhs));
+/// Replace `value`'s uses with its inferred constant (patterned after SCCP).
+static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
+ PatternRewriter &rewriter,
+ Value value) {
+ if (value.use_empty())
+ return failure();
+ std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
+ if (!maybeConstValue.has_value())
+ return failure();
+ // Materialize via the defining op's dialect, else the region owner's dialect.
+ Operation *maybeDefiningOp = value.getDefiningOp();
+ Dialect *valueDialect =
+ maybeDefiningOp ? maybeDefiningOp->getDialect()
+ : value.getParentRegion()->getParentOp()->getDialect();
+ Attribute constAttr =
+ rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
+ Operation *constOp = valueDialect->materializeConstant(
+ rewriter, constAttr, value.getType(), value.getLoc());
+ // Fall back to arith.constant if the dialect materializer doesn't know what
+ // to do with an integer constant.
+ if (!constOp)
+ constOp = rewriter.getContext()
+ ->getLoadedDialect<ArithDialect>()
+ ->materializeConstant(rewriter, constAttr, value.getType(),
+ value.getLoc());
+ if (!constOp)
+ return failure();
+
+ rewriter.replaceAllUsesWith(value, constOp->getResult(0));
+ return success();
}
namespace {
-/// This class listens on IR transformations performed during a pass relying on
-/// information from a `DataflowSolver`. It erases state associated with the
-/// erased operation and its results from the `DataFlowSolver` so that Patterns
-/// do not accidentally query old state information for newly created Ops.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}
@@ -120,52 +88,95 @@ class DataFlowListener : public RewriterBase::Listener {
DataFlowSolver &s;
};
-struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+/// Materialize constants for the results of `op` that were inferred to be
+/// constant integers and replace their uses. If `op` has results, all of them
+/// end up replaced or unused, and `op` is trivially dead, erase it; otherwise
+/// also replace block arguments of `op`'s regions with their constant values.
+struct MaterializeKnownConstantValues : public RewritePattern {
+ MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
+ : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
+ solver(s) {}
+ // Never match constants: re-materializing them would loop forever.
+ LogicalResult match(Operation *op) const override {
+ if (matchPattern(op, m_Constant()))
+ return failure();
- ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
- : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
+ auto needsReplacing = [&](Value v) {
+ return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+ };
+ bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
+ if (op->getNumRegions() == 0)
+ return success(hasConstantResults);
+ bool hasConstantRegionArgs = false;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ hasConstantRegionArgs |=
+ llvm::any_of(block.getArguments(), needsReplacing);
+ }
+ }
+ return success(hasConstantResults || hasConstantRegionArgs);
+ }
- LogicalResult matchAndRewrite(arith::CmpIOp op,
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ bool replacedAll = (op->getNumResults() != 0);
+ for (Value v : op->getResults())
+ replacedAll &=
+ (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
+ v.use_empty());
+ if (replacedAll && isOpTriviallyDead(op)) {
+ rewriter.eraseOp(op);
+ return;
+ }
+
+ PatternRewriter::InsertionGuard guard(rewriter);
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ rewriter.setInsertionPointToStart(&block);
+ for (BlockArgument &arg : block.getArguments()) {
+ (void)maybeReplaceWithConstant(solver, rewriter, arg);
+ }
+ }
+ }
+ }
+
+private:
+ DataFlowSolver &solver;
+};
+
+template <typename RemOp>
+struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
+ DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
+ : OpRewritePattern<RemOp>(context), solver(s) {}
+ // Fold `lhs rem C` to `lhs` when C > 0 is constant and 0 <= lhs < C.
+ LogicalResult matchAndRewrite(RemOp op,
PatternRewriter &rewriter) const override {
- auto *lhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
- if (!lhsResult || lhsResult->getValue().isUninitialized())
+ Value lhs = op.getOperand(0);
+ Value rhs = op.getOperand(1);
+ auto maybeModulus = getConstantIntValue(rhs);
+ if (!maybeModulus.has_value())
return failure();
-
- auto *rhsResult =
- solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
- if (!rhsResult || rhsResult->getValue().isUninitialized())
+ int64_t modulus = *maybeModulus;
+ if (modulus <= 0)
return failure();
-
- using HandlerFunc =
- FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
- std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
- handlers{};
- using Pred = arith::CmpIPredicate;
- handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
- handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
- handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
- handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
- handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
- handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
- handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
- handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
- handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
- handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
-
- HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
- if (!handler)
+ auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
+ if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
return failure();
-
- ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
- ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
- FailureOr<bool> result = handler(lhsValue, rhsValue);
-
- if (failed(result))
+ const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
+ const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
+ const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
+ // The bounds are inclusive, so both must be non-negative and strictly less
+ // than the modulus for the remainder to be the identity.
+ if (min.isNegative() || min.uge(modulus))
+ return failure();
+ if (max.isNegative() || max.uge(modulus))
+ return failure();
+ if (!min.ule(max))
return failure();
- rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
- op, static_cast<int64_t>(*result), /*width*/ 1);
+ // With all those conditions out of the way, we know that this invocation of
+ // a remainder is a noop because the input is strictly within the range
+ // [0, modulus), so get rid of it.
+ rewriter.replaceOp(op, ValueRange{lhs});
return success();
}
@@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass
void mlir::arith::populateIntRangeOptimizationsPatterns(
RewritePatternSet &patterns, DataFlowSolver &solver) {
- patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
+ patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
+ DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 60f0ab41afa48..e00b7692fe396 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @add_min_max
// CHECK: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index dd62a481a1246..ea5969a100258 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -96,3 +96,39 @@ func.func @test() -> i8 {
return %1: i8
}
+// -----
+
+// CHECK-LABEL: func @trivial_rem
+// CHECK: [[val:%.+]] = test.with_bounds
+// CHECK: return [[val]]
+func.func @trivial_rem() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
+ %mod = arith.remsi %val, %c64 : i8
+ return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @non_const_rhs
+// CHECK: [[mod:%.+]] = arith.remui
+// CHECK: return [[mod]]
+func.func @non_const_rhs() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
+ %rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
+ %mod = arith.remui %val, %rhs : i8
+ return %mod : i8
+}
+
+// -----
+
+// CHECK-LABEL: func @wraps
+// CHECK: [[mod:%.+]] = arith.remsi
+// CHECK: return [[mod]]
+func.func @wraps() -> i8 {
+ %c64 = arith.constant 64 : i8
+ %val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
+ %mod = arith.remsi %val, %c64 : i8
+ return %mod : i8
+}
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 980f7e5873e0c..a0917a2fdf110 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @launch_func
func.func @launch_func(%arg0 : index) {
diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir
index 2784d5fd5cf70..951624d573a64 100644
--- a/mlir/test/Dialect/Index/int-range-inference.mlir
+++ b/mlir/test/Dialect/Index/int-range-inference.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
// Most operations are covered by the `arith` tests, which use the same code
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 2106eeefdca4d..1ec3441b1fde8 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
@@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {
// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+ // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
%false = arith.constant false
- // CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%false) %0 : index
} do {
^bb0(%i1: index):
@@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {
// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
- // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
- // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
+ // CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
- // CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%arg0) %0 : index
} do {
^bb0(%i1: index):
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..66b1faf78e2d8 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
- TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
${MLIRTestTransformsPDLSrc}
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
deleted file mode 100644
index 5758f6acf2f0f..0000000000000
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-// TODO: This pass is needed to test integer range inference until that
-// functionality has been integrated into SCCP.
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/TypeID.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include <optional>
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-/// Patterned after SCCP
-static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
- OperationFolder &folder, Value value) {
- auto *maybeInferredRange =
- solver.lookupState<IntegerValueRangeLattice>(value);
- if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
- return failure();
- const ConstantIntRanges &inferredRange =
- maybeInferredRange->getValue().getValue();
- std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
- if (!maybeConstValue.has_value())
- return failure();
-
- Operation *maybeDefiningOp = value.getDefiningOp();
- Dialect *valueDialect =
- maybeDefiningOp ? maybeDefiningOp->getDialect()
- : value.getParentRegion()->getParentOp()->getDialect();
- Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
- Value constant = folder.getOrCreateConstant(
- b.getInsertionBlock(), valueDialect, constAttr, value.getType());
- if (!constant)
- return failure();
-
- value.replaceAllUsesWith(constant);
- return success();
-}
-
-static void rewrite(DataFlowSolver &solver, MLIRContext *context,
- MutableArrayRef<Region> initialRegions) {
- SmallVector<Block *> worklist;
- auto addToWorklist = [&](MutableArrayRef<Region> regions) {
- for (Region &region : regions)
- for (Block &block : llvm::reverse(region))
- worklist.push_back(&block);
- };
-
- OpBuilder builder(context);
- OperationFolder folder(context);
-
- addToWorklist(initialRegions);
- while (!worklist.empty()) {
- Block *block = worklist.pop_back_val();
-
- for (Operation &op : llvm::make_early_inc_range(*block)) {
- builder.setInsertionPoint(&op);
-
- // Replace any result with constants.
- bool replacedAll = op.getNumResults() != 0;
- for (Value res : op.getResults())
- replacedAll &=
- succeeded(replaceWithConstant(solver, builder, folder, res));
-
- // If all of the results of the operation were replaced, try to erase
- // the operation completely.
- if (replacedAll && wouldOpBeTriviallyDead(&op)) {
- assert(op.use_empty() && "expected all uses to be replaced");
- op.erase();
- continue;
- }
-
- // Add any the regions of this operation to the worklist.
- addToWorklist(op.getRegions());
- }
-
- // Replace any block arguments with constants.
- builder.setInsertionPointToStart(block);
- for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(solver, builder, folder, arg);
- }
-}
-
-namespace {
-struct TestIntRangeInference
- : PassWrapper<TestIntRangeInference, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
-
- StringRef getArgument() const final { return "test-int-range-inference"; }
- StringRef getDescription() const final {
- return "Test integer range inference analysis";
- }
-
- void runOnOperation() override {
- Operation *op = getOperation();
- DataFlowSolver solver;
- solver.load<DeadCodeAnalysis>();
- solver.load<SparseConstantPropagation>();
- solver.load<IntegerRangeAnalysis>();
- if (failed(solver.initializeAndRun(op)))
- return signalPassFailure();
- rewrite(solver, op->getContext(), op->getRegions());
- }
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace test {
-void registerTestIntRangeInference() {
- PassRegistration<TestIntRangeInference>();
-}
-} // end namespace test
-} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d2ba3d06835fb..d0de74dd6eaf4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -97,9 +97,8 @@ void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestGenericIRVisitorsPass();
void registerTestInterfaces();
-void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
@@ -226,9 +225,8 @@ void registerTestPasses() {
mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
- mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgDecomposeOps();
More information about the Mlir-commits
mailing list