[Mlir-commits] [mlir] 1a867bf - [mlir][arith] Optimize arith.cmpi based on integer range analysis.
Ivan Butygin
llvmlistbot at llvm.org
Wed Jan 11 03:16:06 PST 2023
Author: Ivan Butygin
Date: 2023-01-11T12:15:58+01:00
New Revision: 1a867bf1c7ccd4fe38ca59346f4b6268643940bb
URL: https://github.com/llvm/llvm-project/commit/1a867bf1c7ccd4fe38ca59346f4b6268643940bb
DIFF: https://github.com/llvm/llvm-project/commit/1a867bf1c7ccd4fe38ca59346f4b6268643940bb.diff
LOG: [mlir][arith] Optimize arith.cmpi based on integer range analysis.
Add a pass that performs arith dialect op optimizations based on integer range analysis (only cmpi for now).
Differential Revision: https://reviews.llvm.org/D140629
Added:
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/test/Dialect/Arith/int-range-opts.mlir
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index d087ac69828a9..257a62aa39f78 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -12,10 +12,14 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class DataFlowSolver;
+
namespace arith {
#define GEN_PASS_DECL
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+#define GEN_PASS_DECL_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
class WideIntEmulationConverter;
@@ -44,6 +48,13 @@ std::unique_ptr<Pass> createArithExpandOpsPass();
/// equivalent.
std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
+/// Add patterns for int range based optimizations.
+void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver);
+
+/// Create a pass that performs optimizations based on integer range analysis.
+std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 16ef294a90d28..ee561e655965f 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,6 +49,15 @@ def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
let constructor = "mlir::arith::createArithUnsignedWhenEquivalentPass()";
}
+def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
+  let summary = "Do optimizations based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and applies optimizations based on its
+    results. For example, it replaces `arith.cmpi` with a constant when the
+    comparison result can be inferred from the ranges of its arguments.
+  }];
+}
+
def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let summary = "Emulate 2*N-bit integer operations using N-bit operations";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index b45ae48e80181..9f098f006d2e8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRArithTransforms
Bufferize.cpp
EmulateWideInt.cpp
ExpandOps.cpp
+ IntRangeOptimizations.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
new file mode 100644
index 0000000000000..7f34c0a20d517
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -0,0 +1,183 @@
+//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHINTRANGEOPTS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::dataflow;
+
+/// Returns false when the two ranges provably share no value, true otherwise
+/// (a conservative "may intersect").
+///
+/// Each pair of bounds (signed and unsigned) is sound on its own. The slt/ult
+/// folders decide comparisons from a single pair. It is therefore enough for
+/// *either* the signed or the unsigned intervals to be disjoint to prove that
+/// no lhs value can equal any rhs value.
+static bool intersects(const ConstantIntRanges &lhs,
+                       const ConstantIntRanges &rhs) {
+  bool signedDisjoint =
+      lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax());
+  bool unsignedDisjoint =
+      lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax());
+  return !(signedDisjoint || unsignedDisjoint);
+}
+
+/// Folds `eq`: statically false when the operand ranges cannot share a value,
+/// otherwise undecided (failure).
+static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (intersects(lhs, rhs))
+    return failure();
+  return false;
+}
+
+/// Folds `ne`: statically true when the operand ranges cannot share a value,
+/// otherwise undecided (failure).
+static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  if (intersects(lhs, rhs))
+    return failure();
+  return true;
+}
+
+/// Folds `slt` from the signed bounds.
+///
+/// Returns true when the largest lhs value is below the smallest rhs value.
+/// Returns false when the smallest lhs value is at or above the largest rhs
+/// value. Returns failure otherwise.
+static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  const auto &lhsMax = lhs.smax();
+  const auto &rhsMin = rhs.smin();
+  if (lhsMax.slt(rhsMin))
+    return true;
+
+  const auto &lhsMin = lhs.smin();
+  const auto &rhsMax = rhs.smax();
+  if (lhsMin.sge(rhsMax))
+    return false;
+
+  return failure();
+}
+
+/// Folds `sle` from the signed bounds; failure when the ranges do not decide it.
+static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  // Every lhs value is <= every rhs value.
+  bool alwaysTrue = lhs.smax().sle(rhs.smin());
+  // Every lhs value is > every rhs value.
+  bool alwaysFalse = lhs.smin().sgt(rhs.smax());
+
+  if (alwaysTrue)
+    return true;
+  if (alwaysFalse)
+    return false;
+  return failure();
+}
+
+/// Folds `sgt` by evaluating `lhs > rhs` as `rhs < lhs`.
+static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleSlt(/*lhs=*/rhs, /*rhs=*/lhs);
+}
+
+/// Folds `sge` by evaluating `lhs >= rhs` as `rhs <= lhs`.
+static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleSle(/*lhs=*/rhs, /*rhs=*/lhs);
+}
+
+/// Folds `ult` from the unsigned bounds.
+///
+/// Returns true when the largest lhs value is below the smallest rhs value.
+/// Returns false when the smallest lhs value is at or above the largest rhs
+/// value. Returns failure otherwise.
+static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  const auto &lhsMax = lhs.umax();
+  const auto &rhsMin = rhs.umin();
+  if (lhsMax.ult(rhsMin))
+    return true;
+
+  const auto &lhsMin = lhs.umin();
+  const auto &rhsMax = rhs.umax();
+  if (lhsMin.uge(rhsMax))
+    return false;
+
+  return failure();
+}
+
+/// Folds `ule` from the unsigned bounds; failure when the ranges do not decide
+/// it.
+static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  // Every lhs value is <= every rhs value.
+  bool alwaysTrue = lhs.umax().ule(rhs.umin());
+  // Every lhs value is > every rhs value.
+  bool alwaysFalse = lhs.umin().ugt(rhs.umax());
+
+  if (alwaysTrue)
+    return true;
+  if (alwaysFalse)
+    return false;
+  return failure();
+}
+
+/// Folds `ugt` by evaluating `lhs > rhs` as `rhs < lhs` (unsigned).
+static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleUlt(/*lhs=*/rhs, /*rhs=*/lhs);
+}
+
+/// Folds `uge` by evaluating `lhs >= rhs` as `rhs <= lhs` (unsigned).
+static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
+  return handleUle(/*lhs=*/rhs, /*rhs=*/lhs);
+}
+
+namespace {
+/// Rewrites `arith.cmpi` into an `i1` constant when the integer ranges the
+/// solver inferred for both operands decide the predicate.
+struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
+
+  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    // The replacement built below is a scalar `i1` constant. A comparison of
+    // vectors/tensors yields a shaped `i1` result, and replacing it with a
+    // scalar would produce ill-typed IR. Only scalar comparisons are folded.
+    if (!op.getResult().getType().isInteger(1))
+      return failure();
+
+    // Both operands need an initialized range from the solver.
+    auto *lhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
+    if (!lhsResult || lhsResult->getValue().isUninitialized())
+      return failure();
+
+    auto *rhsResult =
+        solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
+    if (!rhsResult || rhsResult->getValue().isUninitialized())
+      return failure();
+
+    HandlerFunc handler = getHandler(op.getPredicate());
+    if (!handler)
+      return failure();
+
+    ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
+    ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
+    FailureOr<bool> result = handler(lhsValue, rhsValue);
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
+        op, static_cast<int64_t>(*result), /*width*/ 1);
+    return success();
+  }
+
+private:
+  /// Signature shared by the per-predicate range folders (handleEq etc.).
+  using HandlerFunc =
+      FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
+
+  /// Maps a predicate to its range-based folder. Returns nullptr when none
+  /// exists.
+  ///
+  /// A switch replaces a table rebuilt on every match. It also lets -Wswitch
+  /// flag predicates added later without a folder.
+  static HandlerFunc getHandler(arith::CmpIPredicate pred) {
+    using Pred = arith::CmpIPredicate;
+    switch (pred) {
+    case Pred::eq:
+      return &handleEq;
+    case Pred::ne:
+      return &handleNe;
+    case Pred::slt:
+      return &handleSlt;
+    case Pred::sle:
+      return &handleSle;
+    case Pred::sgt:
+      return &handleSgt;
+    case Pred::sge:
+      return &handleSge;
+    case Pred::ult:
+      return &handleUlt;
+    case Pred::ule:
+      return &handleUle;
+    case Pred::ugt:
+      return &handleUgt;
+    case Pred::uge:
+      return &handleUge;
+    }
+    return nullptr;
+  }
+
+  DataFlowSolver &solver;
+};
+
+/// Runs DeadCodeAnalysis and IntegerRangeAnalysis on the root operation.
+/// It then greedily applies the range-based patterns, which query the
+/// solver's results.
+struct IntRangeOptimizationsPass
+    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    // NOTE(review): DeadCodeAnalysis is presumably loaded because the sparse
+    // range analysis depends on its liveness results — confirm.
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeOptimizationsPatterns(patterns, solver);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers the range-based `arith.cmpi` folding pattern. The pattern keeps
+/// a reference to `solver`, so the solver must outlive any use of `patterns`.
+void mlir::arith::populateIntRangeOptimizationsPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConvertCmpOp>(context, solver);
+}
+
+/// Factory for the `int-range-optimizations` pass.
+std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<IntRangeOptimizationsPass>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
new file mode 100644
index 0000000000000..be0a7e8ccd70b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// An operand known to lie in [0, 2^63 - 1] can never equal -1, so `eq`
+// folds to false.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %minus_one = arith.constant -1 : index
+  %bounded = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %cmp = arith.cmpi eq, %bounded, %minus_one : index
+  return %cmp : i1
+}
+
+// -----
+
+// An operand known to lie in [0, 2^63 - 1] always differs from -1, so `ne`
+// folds to true.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %minus_one = arith.constant -1 : index
+  %bounded = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %cmp = arith.cmpi ne, %bounded, %minus_one : index
+  return %cmp : i1
+}
+
+// -----
+
+
+// Every value in [0, 2^63 - 1] is signed-greater-or-equal to 0, so `sge`
+// folds to true.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %zero = arith.constant 0 : index
+  %bounded = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %cmp = arith.cmpi sge, %bounded, %zero : index
+  return %cmp : i1
+}
+
+// -----
+
+// No value in [0, 2^63 - 1] is signed-less than 0, so `slt` folds to false.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %zero = arith.constant 0 : index
+  %bounded = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %cmp = arith.cmpi slt, %bounded, %zero : index
+  return %cmp : i1
+}
+
+// -----
+
+
+// Every value in [0, 2^63 - 1] is signed-greater than -1, so `sgt` folds to
+// true.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %minus_one = arith.constant -1 : index
+  %bounded = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %cmp = arith.cmpi sgt, %bounded, %minus_one : index
+  return %cmp : i1
+}
+
+// -----
+
+// No value in [0, 2^63 - 1] is signed-less-or-equal to -1, so `sle` folds to
+// false.
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi sle, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// Unsigned predicates. Viewed as unsigned, -1 is the maximum index value,
+// which lies above every value in [0, 2^63 - 1].
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi ult, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant true
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi ule, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi ugt, %0, %cst1 : index
+  return %1: i1
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+// CHECK: %[[C:.*]] = arith.constant false
+// CHECK: return %[[C]]
+func.func @test() -> i1 {
+  %cst1 = arith.constant -1 : index
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %1 = arith.cmpi uge, %0, %cst1 : index
+  return %1: i1
+}
More information about the Mlir-commits
mailing list