[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)
Ivan Butygin
llvmlistbot at llvm.org
Wed Oct 16 05:13:48 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/112404
>From 584e8e5fc5ade2df915f02d85ff119b676b77d8b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 1/5] [mlir] Add `arith-int-range-narrowing` pass
This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly performs local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
Ideally, the two passes should be unified, but that is left for future work.
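To illustrate, here is a minimal sketch of the rewrite, adapted from the `test_addi` case in the test file added below (the `test.with_bounds` test op only feeds known ranges into the analysis; the SSA names in the "after" form are illustrative, not the exact pass output):

  // Before: both operands are known to lie in small non-negative ranges.
  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
  %2 = arith.addi %0, %1 : index

  // After --arith-int-range-narrowing with target-bitwidth=32: the addition is
  // performed on i32 and the result is cast back to index.
  %lhs = arith.index_castui %0 : index to i32
  %rhs = arith.index_castui %1 : index to i32
  %sum = arith.addi %lhs, %rhs : i32
  %res = arith.index_castui %sum : i32 to index

Because results are zero-extended (or `index_castui`'d) back to the original width, the patterns only fire when the inferred range is provably non-negative and fits into the target width.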
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 6 +
.../mlir/Dialect/Arith/Transforms/Passes.td | 19 ++
.../Transforms/IntRangeOptimizations.cpp | 260 ++++++++++++++++++
.../Dialect/Arith/int-range-narrowing.mlir | 57 ++++
4 files changed, 342 insertions(+)
create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..beeefea497d9b2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,6 +73,12 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
/// Create a pass which does optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver,
+ unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
/// Add patterns for integer bitwidth narrowing.
void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
];
}
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+ let summary = "Reduce the bitwidth of integer operations based on integer range analysis";
+ let description = [{
+ This pass runs integer range analysis and tries to narrow arith ops to the
+ specified bitwidth based on the analysis results.
+ }];
+
+ let options = [
+ Option<"targetBitwidth", "target-bitwidth", "unsigned",
+ /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+ ];
+
+ // Explicitly depend on "arith" because this pass could create operations in
+ // `arith` out of thin air in some cases.
+ let dependentDialects = [
+ "::mlir::arith::ArithDialect"
+ ];
+}
+
def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
let summary = "Emulate operations on unsupported floats with extf/truncf";
let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..4e7b44c939e796 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
using namespace mlir;
@@ -184,6 +189,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+ type = getElementTypeOrSelf(type);
+ if (isa<IndexType>(type))
+ return type;
+
+ if (auto intType = dyn_cast<IntegerType>(type))
+ if (intType.getWidth() > targetBitwidth)
+ return type;
+
+ return nullptr;
+}
+
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+ if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+ return nullptr;
+
+ Type type;
+ for (auto range :
+ {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+ for (Value val : range) {
+ if (!type) {
+ type = val.getType();
+ continue;
+ } else if (type != val.getType()) {
+ return nullptr;
+ }
+ }
+ }
+
+ return checkArithType(type, targetBitwidth);
+}
+
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+ ValueRange results) {
+ std::optional<ConstantIntRanges> ret;
+ for (Value value : results) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return std::nullopt;
+
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+
+ if (!ret) {
+ ret = inferredRange;
+ } else {
+ ret = ret->rangeUnion(inferredRange);
+ }
+ }
+ return ret;
+}
+
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+ auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+ if (auto shaped = dyn_cast<ShapedType>(srcType))
+ return shaped.clone(dstType);
+
+ assert(srcType.isIntOrIndex() && "Invalid src type");
+ return dstType;
+}
+
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+ APInt umin, APInt umax) {
+ auto sge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sge(val2);
+ };
+ auto sle = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sle(val2);
+ };
+ auto uge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.uge(val2);
+ };
+ auto ule = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.ule(val2);
+ };
+ return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+ uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+ Type srcType = src.getType();
+ assert(srcType.isIntOrIndex() && "Invalid src type");
+ assert(dstType.isIntOrIndex() && "Invalid dst type");
+ if (srcType == dstType)
+ return src;
+
+ if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+ return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+ auto srcInt = cast<IntegerType>(srcType);
+ auto dstInt = cast<IntegerType>(dstType);
+ if (dstInt.getWidth() < srcInt.getWidth()) {
+ return builder.create<arith::TruncIOp>(loc, dstType, src);
+ } else {
+ return builder.create<arith::ExtUIOp>(loc, dstType, src);
+ }
+}
+
+struct NarrowElementwise final
+ : public OpTraitRewritePattern<OpTrait::Elementwise> {
+ NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+ : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+ targetBitwidth(target) {}
+
+ using OpTraitRewritePattern::OpTraitRewritePattern;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ Type srcType = checkElementwiseOpType(op, targetBitwidth);
+ if (!srcType)
+ return failure();
+
+ std::optional<ConstantIntRanges> range =
+ getOperandsRange(solver, op->getResults());
+ if (!range)
+ return failure();
+
+ // We are truncating op args to the desired bitwidth before the op and then
+ // extending op results back to the original width after.
+ // extui and extsi will produce different results for negative values, so
+ // limit signed range to non-negative values.
+ auto smin = APInt::getZero(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ return failure();
+
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ return failure();
+
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
+ }
+
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.modifyOpInPlace(newOp, [&]() {
+ for (OpResult res : newOp->getResults()) {
+ res.setType(targetType);
+ }
+ });
+ SmallVector<Value> newResults;
+ for (Value res : newOp->getResults())
+ newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+
+private:
+ DataFlowSolver &solver;
+ unsigned targetBitwidth;
+};
+
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+ NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+ unsigned target)
+ : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+ LogicalResult matchAndRewrite(arith::CmpIOp op,
+ PatternRewriter &rewriter) const override {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+ if (!srcType)
+ return failure();
+
+ std::optional<ConstantIntRanges> range =
+ getOperandsRange(solver, {lhs, rhs});
+ if (!range)
+ return failure();
+
+ auto smin = APInt::getSignedMinValue(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ return failure();
+
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ return failure();
+
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
+ }
+
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+
+private:
+ DataFlowSolver &solver;
+ unsigned targetBitwidth;
+};
+
struct IntRangeOptimizationsPass
: public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
@@ -208,6 +430,32 @@ struct IntRangeOptimizationsPass
signalPassFailure();
}
};
+
+struct IntRangeNarrowingPass
+ : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+ using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+
+ DataFlowListener listener(solver);
+
+ RewritePatternSet patterns(ctx);
+ populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ signalPassFailure();
+ }
+};
} // namespace
void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -216,6 +464,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver,
+ unsigned targetBitwidth) {
+ // Cmpi uses args ranges instead of results, run it with higher benefit,
+ // as its arguments can potentially be replaced.
+ patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+ targetBitwidth);
+
+ patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+ targetBitwidth);
+}
+
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
return std::make_unique<IntRangeOptimizationsPass>();
}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK: return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+ %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+ %2 = arith.addi %0, %1 : index
+ return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+// CHECK: return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+ %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+ %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+ %2 = arith.addi %0, %1 : index
+ return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+// CHECK: return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+ %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+ %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+ %2 = arith.addi %0, %1 : i64
+ return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+ %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+ %2 = arith.cmpi slt, %0, %1 : index
+ return %2 : i1
+}
>From 696b8a5fc822e9beeee4e7b4e18ac12a27f10c06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 18:59:07 +0200
Subject: [PATCH 2/5] typo
---
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index beeefea497d9b2..ce8e819708b190 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,7 +73,7 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
/// Create a pass which does optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
-/// Add patterns for int range based norrowing.
+/// Add patterns for int range based narrowing.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
unsigned targetBitwidth);
>From 106beda0062847cc8ae6d624b825081e2702b032 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:05:29 +0200
Subject: [PATCH 3/5] use list of bitwidths instead of 1
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
.../mlir/Dialect/Arith/Transforms/Passes.td | 4 +-
.../Transforms/IntRangeOptimizations.cpp | 141 ++++++++++--------
.../Dialect/Arith/int-range-narrowing.mlir | 2 +-
4 files changed, 79 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index ce8e819708b190..4d386827571dfa 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> createIntRangeOptimizationsPass();
/// Add patterns for int range based narrowing.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
- unsigned targetBitwidth);
+ ArrayRef<unsigned> bitwidthsSupported);
// TODO: merge these two narrowing passes.
/// Add patterns for integer bitwidth narrowing.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 8c565f6489638e..898d74249af618 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -58,8 +58,8 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
}];
let options = [
- Option<"targetBitwidth", "target-bitwidth", "unsigned",
- /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+ ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+ "Integer bitwidths supported">,
];
// Explicitly depend on "arith" because this pass could create operations in
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 4e7b44c939e796..861602f2032172 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -302,108 +302,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
struct NarrowElementwise final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
- NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+ NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+ ArrayRef<unsigned> target)
: OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
- targetBitwidth(target) {}
+ targetBitwidths(target) {}
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- Type srcType = checkElementwiseOpType(op, targetBitwidth);
- if (!srcType)
- return failure();
std::optional<ConstantIntRanges> range =
getOperandsRange(solver, op->getResults());
if (!range)
return failure();
- // We are truncating op args to the desired bitwidth before the op and then
- // extending op results back to the original width after.
- // extui and exti will produce different results for negative values, so
- // limit signed range to non-negative values.
- auto smin = APInt::getZero(targetBitwidth);
- auto smax = APInt::getSignedMaxValue(targetBitwidth);
- auto umin = APInt::getMinValue(targetBitwidth);
- auto umax = APInt::getMaxValue(targetBitwidth);
- if (!checkRange(*range, smin, smax, umin, umax))
- return failure();
+ for (unsigned targetBitwidth : targetBitwidths) {
+ Type srcType = checkElementwiseOpType(op, targetBitwidth);
+ if (!srcType)
+ continue;
- Type targetType = getTargetType(srcType, targetBitwidth);
- if (targetType == srcType)
- return failure();
+ // We are truncating op args to the desired bitwidth before the op and
+ // then extending op results back to the original width after. extui and
+ // extsi will produce different results for negative values, so limit
+ // signed range to non-negative values.
+ auto smin = APInt::getZero(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ continue;
- Location loc = op->getLoc();
- IRMapping mapping;
- for (Value arg : op->getOperands()) {
- Value newArg = doCast(rewriter, loc, arg, targetType);
- mapping.map(arg, newArg);
- }
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ continue;
- Operation *newOp = rewriter.clone(*op, mapping);
- rewriter.modifyOpInPlace(newOp, [&]() {
- for (OpResult res : newOp->getResults()) {
- res.setType(targetType);
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
}
- });
- SmallVector<Value> newResults;
- for (Value res : newOp->getResults())
- newResults.emplace_back(doCast(rewriter, loc, res, srcType));
- rewriter.replaceOp(op, newResults);
- return success();
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.modifyOpInPlace(newOp, [&]() {
+ for (OpResult res : newOp->getResults()) {
+ res.setType(targetType);
+ }
+ });
+ SmallVector<Value> newResults;
+ for (Value res : newOp->getResults())
+ newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+ return failure();
}
private:
DataFlowSolver &solver;
- unsigned targetBitwidth;
+ SmallVector<unsigned, 4> targetBitwidths;
};
struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
- unsigned target)
- : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+ ArrayRef<unsigned> target)
+ : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
+ }
LogicalResult matchAndRewrite(arith::CmpIOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
- Type srcType = checkArithType(lhs.getType(), targetBitwidth);
- if (!srcType)
- return failure();
-
std::optional<ConstantIntRanges> range =
getOperandsRange(solver, {lhs, rhs});
if (!range)
return failure();
- auto smin = APInt::getSignedMinValue(targetBitwidth);
- auto smax = APInt::getSignedMaxValue(targetBitwidth);
- auto umin = APInt::getMinValue(targetBitwidth);
- auto umax = APInt::getMaxValue(targetBitwidth);
- if (!checkRange(*range, smin, smax, umin, umax))
- return failure();
+ for (unsigned targetBitwidth : targetBitwidths) {
+ Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+ if (!srcType)
+ continue;
- Type targetType = getTargetType(srcType, targetBitwidth);
- if (targetType == srcType)
- return failure();
+ auto smin = APInt::getSignedMinValue(targetBitwidth);
+ auto smax = APInt::getSignedMaxValue(targetBitwidth);
+ auto umin = APInt::getMinValue(targetBitwidth);
+ auto umax = APInt::getMaxValue(targetBitwidth);
+ if (!checkRange(*range, smin, smax, umin, umax))
+ continue;
- Location loc = op->getLoc();
- IRMapping mapping;
- for (Value arg : op->getOperands()) {
- Value newArg = doCast(rewriter, loc, arg, targetType);
- mapping.map(arg, newArg);
- }
+ Type targetType = getTargetType(srcType, targetBitwidth);
+ if (targetType == srcType)
+ continue;
- Operation *newOp = rewriter.clone(*op, mapping);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
+ Location loc = op->getLoc();
+ IRMapping mapping;
+ for (Value arg : op->getOperands()) {
+ Value newArg = doCast(rewriter, loc, arg, targetType);
+ mapping.map(arg, newArg);
+ }
+
+ Operation *newOp = rewriter.clone(*op, mapping);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+ return failure();
}
private:
DataFlowSolver &solver;
- unsigned targetBitwidth;
+ SmallVector<unsigned, 4> targetBitwidths;
};
struct IntRangeOptimizationsPass
@@ -447,7 +456,7 @@ struct IntRangeNarrowingPass
DataFlowListener listener(solver);
RewritePatternSet patterns(ctx);
- populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+ populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
GreedyRewriteConfig config;
config.listener = &listener;
@@ -464,16 +473,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}
-void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
- DataFlowSolver &solver,
- unsigned targetBitwidth) {
+void mlir::arith::populateIntRangeNarrowingPatterns(
+ RewritePatternSet &patterns, DataFlowSolver &solver,
+ ArrayRef<unsigned> bitwidthsSupported) {
// Cmpi uses args ranges instead of results, run it with higher benefit,
// as its arguments can potentially be replaced.
patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
- targetBitwidth);
+ bitwidthsSupported);
patterns.add<NarrowElementwise>(patterns.getContext(), solver,
- targetBitwidth);
+ bitwidthsSupported);
}
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index d85cb3384061be..6283284567dce8 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
// Do not truncate negative values
// CHECK-LABEL: func @test_addi_neg
>From b7053b25a2840cbff6878b7c252c97ff02497591 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:12:00 +0200
Subject: [PATCH 4/5] update tests
---
.../Dialect/Arith/int-range-narrowing.mlir | 24 +++++++++----------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 6283284567dce8..8b73f02fd214a2 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
// Do not truncate negative values
// CHECK-LABEL: func @test_addi_neg
@@ -14,10 +14,10 @@ func.func @test_addi_neg() -> index {
// CHECK-LABEL: func @test_addi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+// CHECK: %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
// CHECK: return %[[RES_CASTED]] : index
func.func @test_addi() -> index {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -30,10 +30,10 @@ func.func @test_addi() -> index {
// CHECK-LABEL: func @test_addi_i64
// CHECK: %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
// CHECK: %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
-// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
-// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
-// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+// CHECK: %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
+// CHECK: %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
+// CHECK: %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+// CHECK: %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64
// CHECK: return %[[RES_CASTED]] : i64
func.func @test_addi_i64() -> i64 {
%0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
@@ -45,9 +45,9 @@ func.func @test_addi_i64() -> i64 {
// CHECK-LABEL: func @test_cmpi
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+// CHECK: %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+// CHECK: %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+// CHECK: %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
// CHECK: return %[[RES]] : i1
func.func @test_cmpi() -> i1 {
%0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
>From b5f6014f35be7f17bcf9688d46ab209ad288529b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 14:13:06 +0200
Subject: [PATCH 5/5] more tests
---
.../Dialect/Arith/int-range-narrowing.mlir | 142 ++++++++++++++++++
1 file changed, 142 insertions(+)
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8b73f02fd214a2..cd0a4c449913e1 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// Some basic tests
+//===----------------------------------------------------------------------===//
+
// Do not truncate negative values
// CHECK-LABEL: func @test_addi_neg
// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
@@ -55,3 +59,141 @@ func.func @test_cmpi() -> i1 {
%2 = arith.cmpi slt, %0, %1 : index
return %2 : i1
}
+
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT: return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+ %a = arith.extsi %lhs : i8 to i16
+ %b = arith.extsi %rhs : i8 to i16
+ %r = arith.addi %a, %b : i16
+ return %r : i16
+}
+
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[ADD]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// TODO: This should be optimized into i16
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// We do not expect this case to be optimized because given n-bit operands,
+// arith.muli produces 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+ %a = arith.extsi %lhs : i16 to i32
+ %b = arith.extsi %rhs : i16 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[MUL]] : i32
+func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}