[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)
Ivan Butygin
llvmlistbot at llvm.org
Tue Oct 15 10:00:32 PDT 2024
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/112404
This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does only local analysis, while `IntegerRangeAnalysis` analyzes the entire program. Ideally the two should be unified, but that is a task for the future.
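For example, with `target-bitwidth=32`, an `i64` addition whose operands and result are proven by the analysis to fit into 32 bits is rewritten as follows (adapted from the tests included in this PR):

```mlir
// %0 and %1 are i64 values known to be in [4, 5] and [6, 7]:
%2 = arith.addi %0, %1 : i64
// becomes:
%3 = arith.trunci %0 : i64 to i32
%4 = arith.trunci %1 : i64 to i32
%5 = arith.addi %3, %4 : i32
%6 = arith.extui %5 : i32 to i64
```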
From 584e8e5fc5ade2df915f02d85ff119b676b77d8b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 1/2] [mlir] Add `arith-int-range-narrowing` pass
This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does only local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
Ideally the two should be unified, but that is a task for the future.
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 6 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 19 ++
.../Transforms/IntRangeOptimizations.cpp | 260 ++++++++++++++++++
.../Dialect/Arith/int-range-narrowing.mlir | 57 ++++
4 files changed, 342 insertions(+)
create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..beeefea497d9b2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,6 +73,12 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver,
+ unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
/// Add patterns for integer bitwidth narrowing.
void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
];
}
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+ let summary = "Reduce integer operations bitwidth based on integer range analysis";
+ let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on the analysis results.
+ }];
+
+ let options = [
+ Option<"targetBitwidth", "target-bitwidth", "unsigned",
+ /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+ ];
+
+ // Explicitly depend on "arith" because this pass could create operations in
+ // `arith` out of thin air in some cases.
+ let dependentDialects = [
+ "::mlir::arith::ArithDialect"
+ ];
+}
+
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
let summary = "Emulate operations on unsupported floats with extf/truncf";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..4e7b44c939e796 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
using namespace mlir;
@@ -184,6 +189,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};
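+/// Returns the scalar element type of `type` if it is an index type or an
+/// integer type wider than `targetBitwidth`; returns null otherwise.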
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+ type = getElementTypeOrSelf(type);
+ if (isa<IndexType>(type))
+ return type;
+
+ if (auto intType = dyn_cast<IntegerType>(type))
+ if (intType.getWidth() > targetBitwidth)
+ return type;
+
+ return nullptr;
+}
+
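+/// Returns the common type of an elementwise op if all of its operands and
+/// results share it and it is eligible for narrowing per checkArithType;
+/// returns null otherwise.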
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+ if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+ return nullptr;
+
+ Type type;
+ for (auto range :
+ {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+ for (Value val : range) {
+ if (!type) {
+ type = val.getType();
+ continue;
+ } else if (type != val.getType()) {
+ return nullptr;
+ }
+ }
+ }
+
+ return checkArithType(type, targetBitwidth);
+}
+
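+/// Returns the union of the inferred ranges of all given values, or
+/// std::nullopt if the range of any value is missing or uninitialized.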
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+ ValueRange results) {
+ std::optional<ConstantIntRanges> ret;
+ for (Value value : results) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return std::nullopt;
+
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+
+ if (!ret) {
+ ret = inferredRange;
+ } else {
+ ret = ret->rangeUnion(inferredRange);
+ }
+ }
+ return ret;
+}
+
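+/// Builds an integer type of `targetBitwidth`, preserving the shape when
+/// `srcType` is a shaped type.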
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+ auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+ if (auto shaped = dyn_cast<ShapedType>(srcType))
+ return shaped.clone(dstType);
+
+ assert(srcType.isIntOrIndex() && "Invalid src type");
+ return dstType;
+}
+
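+/// Returns true if `range` fits into [smin, smax] signed and [umin, umax]
+/// unsigned, extending all values to a common bitwidth before comparing.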
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+ APInt umin, APInt umax) {
+ auto sge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sge(val2);
+ };
+ auto sle = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sle(val2);
+ };
+ auto uge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.uge(val2);
+ };
+ auto ule = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.ule(val2);
+ };
+ return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+ uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
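+/// Materializes a cast from `src` to `dstType`, using index_castui when
+/// either side is an index type and trunci/extui otherwise.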
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+ Type srcType = src.getType();
+ assert(srcType.isIntOrIndex() && "Invalid src type");
+ assert(dstType.isIntOrIndex() && "Invalid dst type");
+ if (srcType == dstType)
+ return src;
+
+ if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+ return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+ auto srcInt = cast<IntegerType>(srcType);
+ auto dstInt = cast<IntegerType>(dstType);
+ if (dstInt.getWidth() < srcInt.getWidth()) {
+ return builder.create<arith::TruncIOp>(loc, dstType, src);
+ } else {
+ return builder.create<arith::ExtUIOp>(loc, dstType, src);
+ }
+}
+
+struct NarrowElementwise final
+ : public OpTraitRewritePattern<OpTrait::Elementwise> {
+ NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+ : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+ targetBitwidth(target) {}
+
+ using OpTraitRewritePattern::OpTraitRewritePattern;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ Type srcType = checkElementwiseOpType(op, targetBitwidth);
+ if (!srcType)
+ return failure();
+
+ std::optional<ConstantIntRanges> range =
+ getOperandsRange(solver, op->getResults());
+ if (!range)
+ return failure();
+
+    // We are truncating op args to the desired bitwidth before the op and then
+    // extending op results back to the original width after.
+    // extui and extsi produce different results for negative values, so
+    // limit the signed range to non-negative values.
+ auto smin = APInt::getZero(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ return failure();
+
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ return failure();
+
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
+ }
+
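+    // Clone the op over the truncated operands and retype its results to the
+    // narrow type in place.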
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.modifyOpInPlace(newOp, [&]() {
+ for (OpResult res : newOp->getResults()) {
+ res.setType(targetType);
+ }
+ });
+ SmallVector<Value> newResults;
+ for (Value res : newOp->getResults())
+ newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+
+private:
+ DataFlowSolver &solver;
+ unsigned targetBitwidth;
+};
+
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+ NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+ unsigned target)
+ : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+ LogicalResult matchAndRewrite(arith::CmpIOp op,
+ PatternRewriter &rewriter) const override {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+ if (!srcType)
+ return failure();
+
+ std::optional<ConstantIntRanges> range =
+ getOperandsRange(solver, {lhs, rhs});
+ if (!range)
+ return failure();
+
+ auto smin = APInt::getSignedMinValue(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ return failure();
+
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ return failure();
+
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
+ }
+
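+    // Cmpi always produces an i1 result, so cloning with the narrowed
+    // operands is enough and no result casts are needed.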
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+
+private:
+ DataFlowSolver &solver;
+ unsigned targetBitwidth;
+};
+
struct IntRangeOptimizationsPass
: public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
@@ -208,6 +430,32 @@ struct IntRangeOptimizationsPass
signalPassFailure();
}
};
+
+struct IntRangeNarrowingPass
+ : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+ using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+
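+    // Keep the solver state in sync with IR changes made during rewriting.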
+ DataFlowListener listener(solver);
+
+ RewritePatternSet patterns(ctx);
+ populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ signalPassFailure();
+ }
+};
} // namespace
void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -216,6 +464,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver,
+ unsigned targetBitwidth) {
+  // Cmpi uses its args' ranges instead of its results', so run it with a
+  // higher benefit, as its arguments can potentially be replaced.
+ patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+ targetBitwidth);
+
+ patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+ targetBitwidth);
+}
+
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
return std::make_unique<IntRangeOptimizationsPass>();
}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK: return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+ %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+ %2 = arith.addi %0, %1 : index
+ return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+// CHECK: return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+ %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+ %2 = arith.addi %0, %1 : index
+ return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+// CHECK: return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+ %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+ %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+ %2 = arith.addi %0, %1 : i64
+ return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+ %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+ %2 = arith.cmpi slt, %0, %1 : index
+ return %2 : i1
+}
From 696b8a5fc822e9beeee4e7b4e18ac12a27f10c06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 18:59:07 +0200
Subject: [PATCH 2/2] typo
---
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index beeefea497d9b2..ce8e819708b190 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,7 +73,7 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
-/// Add patterns for int range based norrowing.
+/// Add patterns for int range based narrowing.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
unsigned targetBitwidth);