[Mlir-commits] [mlir] [mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion (PR #112454)
Ivan Butygin
llvmlistbot at llvm.org
Tue Oct 15 17:34:13 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/112454
>From ac8de1ebc6dc09be184d1d692a1e37e153557a07 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 02:13:42 +0200
Subject: [PATCH 1/2] [mlir] UnsignedWhenEquivalent: use greedy rewriter
instead of dialect conversion
UnsignedWhenEquivalent doesn't really need any dialect conversion features, and switching it to normal patterns makes it more composable with other pattern-based transformations.
---
.../mlir/Dialect/Arith/Transforms/Passes.h | 4 +
.../Transforms/UnsignedWhenEquivalent.cpp | 95 ++++++++++++-------
.../Arith/unsigned-when-equivalent.mlir | 20 ++--
3 files changed, 73 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..e866ac518dbbcb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver);
+/// Replace signed ops with unsigned ones where they are proven equivalent.
+/// `solver` is expected to already hold integer range analysis results; the
+/// patterns keep a reference to it, so it must outlive `patterns`. Callers
+/// should keep the solver consistent with the IR while rewriting, e.g. by
+/// erasing the state of ops removed by the rewriter.
+void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
+ DataFlowSolver &solver);
+
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 4edce84bafd416..c76f56279db706 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,7 +13,8 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace arith {
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
}
namespace {
+/// Rewriter listener that keeps a DataFlowSolver consistent with the IR: when
+/// an op is erased, the solver state attached to the program point after it
+/// and to each of its results is dropped, so it is not consulted for IR that
+/// no longer exists.
+class DataFlowListener : public RewriterBase::Listener {
+public:
+  explicit DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+  void notifyOperationErased(Operation *op) override {
+    s.eraseState(s.getProgramPointAfter(op));
+    for (Value res : op->getResults())
+      s.eraseState(res);
+  }
+
+  /// Solver kept in sync with the IR; must outlive this listener.
+  DataFlowSolver &s;
+};
+
+/// Rewrites a signed op into its unsigned counterpart when integer range
+/// analysis proves that all of its operands and results are non-negative, in
+/// which case both forms compute the same values. The solver is held by
+/// reference and must outlive the pattern.
 template <typename Signed, typename Unsigned>
-struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
-  using OpConversionPattern<Signed>::OpConversionPattern;
+struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
+  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<Signed>(context), solver(s) {}
 
-  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
-    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
-                                    adaptor.getOperands(), op->getAttrs());
+  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
+    if (failed(
+            staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
+      return failure();
+
+    auto newOp = rw.create<Unsigned>(op.getLoc(), op->getResultTypes(),
+                                     op->getOperands(), op->getAttrs());
+
+    // Both ops produce identical values here, so the ranges known for the old
+    // results also hold for the new ones. Without copying them, the new
+    // results would have no lattice state, and their users (which require
+    // known non-negative operand ranges) could not be converted later in the
+    // same rewrite run.
+    DataFlowSolver &s = this->solver;
+    for (auto [oldRes, newRes] :
+         llvm::zip_equal(op->getResults(), newOp->getResults())) {
+      if (const auto *oldState = s.lookupState<IntegerValueRangeLattice>(oldRes))
+        (void)s.getOrCreateState<IntegerValueRangeLattice>(newRes)->join(
+            oldState->getValue());
+    }
+
+    rw.replaceOp(op, newOp->getResults());
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
-struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
-  using OpConversionPattern<CmpIOp>::OpConversionPattern;
+/// Switches an `arith.cmpi` to the unsigned form of its predicate when
+/// `isCmpIConvertable` succeeds, i.e. range analysis shows the signed and
+/// unsigned comparisons agree. The solver is held by reference and must
+/// outlive the pattern.
+struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> {
+  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
+    if (failed(isCmpIConvertable(this->solver, op)))
+      return failure();
 
-  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
-    rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
-                                  op.getLhs(), op.getRhs());
+    // Only the predicate changes, so update the op in place. This keeps the
+    // result Value (and therefore its lattice state in the solver) and any
+    // discardable attributes; recreating the op would drop both.
+    rw.modifyOpInPlace(
+        op, [&] { op.setPredicate(toUnsignedPred(op.getPredicate())); });
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
struct ArithUnsignedWhenEquivalentPass
: public arith::impl::ArithUnsignedWhenEquivalentBase<
ArithUnsignedWhenEquivalentPass> {
- /// Implementation structure: first find all equivalent ops and collect them,
- /// then perform all the rewrites in a second pass over the target op. This
- /// ensures that analysis results are not invalidated during rewriting.
+
+ /// Runs integer range analysis over the target op once, then greedily
+ /// applies the signed-to-unsigned patterns, which consult the solver's
+ /// precomputed ranges. A rewriter listener drops solver state for erased
+ /// ops so that it is not consulted for IR that no longer exists.
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
- ConversionTarget target(*ctx);
- target.addLegalDialect<ArithDialect>();
- target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
- MinSIOp, MaxSIOp, ExtSIOp>(
- [&solver](Operation *op) -> std::optional<bool> {
- return failed(staticallyNonNegative(solver, op));
- });
- target.addDynamicallyLegalOp<CmpIOp>(
- [&solver](CmpIOp op) -> std::optional<bool> {
- return failed(isCmpIConvertable(solver, op));
- });
+ // Keep the solver in sync with ops erased by the rewriter.
+ DataFlowListener listener(solver);
RewritePatternSet patterns(ctx);
- patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
- ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
- ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
- ConvertOpToUnsigned<RemSIOp, RemUIOp>,
- ConvertOpToUnsigned<MinSIOp, MinUIOp>,
- ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
- ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
- ctx);
-
- if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+ populateUnsignedWhenEquivalentPatterns(patterns, solver);
+
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
- }
}
};
} // end anonymous namespace
+/// Registers the rewrites that turn signed division, remainder, min/max,
+/// sign-extension and comparison ops into unsigned equivalents when `solver`
+/// proves them equivalent. The patterns keep a reference to `solver`, which
+/// must outlive them.
+void mlir::arith::populateUnsignedWhenEquivalentPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // Division and remainder.
+  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+               ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<RemSIOp, RemUIOp>>(ctx, solver);
+
+  // Min/max, sign extension and integer comparison.
+  patterns.add<ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+               ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+               ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+      ctx, solver);
+}
+
 std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
   return std::make_unique<ArithUnsignedWhenEquivalentPass>();
 }
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..e015d2d7543c93 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -12,7 +12,7 @@
// CHECK: arith.cmpi slt
// CHECK: arith.cmpi sge
// CHECK: arith.cmpi sgt
+// Each function returns its results so the ops under test stay live: the pass
+// now uses the greedy rewrite driver, which also erases trivially dead ops.
-func.func @not_with_maybe_overflow(%arg0 : i32) {
+func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_smax = arith.constant 0x7fffffff : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
- func.return
+ func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}
// CHECK-LABEL: func @yes_with_no_overflow
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
// CHECK: arith.cmpi ult
// CHECK: arith.cmpi uge
// CHECK: arith.cmpi ugt
-func.func @yes_with_no_overflow(%arg0 : i32) {
+func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
%ci32_almost_smax = arith.constant 0x7ffffffe : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
%10 = arith.cmpi slt, %1, %c4 : i32
%11 = arith.cmpi sge, %1, %c4 : i32
%12 = arith.cmpi sgt, %1, %c4 : i32
- func.return
+ func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
}
// CHECK-LABEL: func @preserves_structure
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
func.func private @external() -> i8
// CHECK-LABEL: @dead_code
-func.func @dead_code() {
+func.func @dead_code() -> i8 {
%0 = call @external() : () -> i8
// CHECK: arith.floordivsi
%1 = arith.floordivsi %0, %0 : i8
- return
+ return %1 : i8
}
// Make sure not crash.
+// The comparison takes %arg0 rather than two constants so the greedy driver
+// cannot constant-fold it away before the pass sees it.
// CHECK-LABEL: @no_integer_or_index
-func.func @no_integer_or_index() {
+func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
// CHECK: arith.cmpi
%cst_0 = arith.constant dense<[0]> : vector<1xi32>
- %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32>
- return
+ %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
+ return %cmp : vector<1xi1>
}
// CHECK-LABEL: @gpu_func
@@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
gpu.terminator
}
return %arg1 : memref<2x32xf32>
-}
+}
>From 0329fad82869d72bd217754069ad25e40a4d0b23 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 02:33:54 +0200
Subject: [PATCH 2/2] comment
---
mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index c76f56279db706..d6b2ee9f313a99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -30,6 +30,9 @@ using namespace mlir::dataflow;
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
+// TODO: IntegerRangeAnalysis internally assumes that index is 64-bit, and this
+// pattern relies on that. These transformations may not be valid for a 32-bit
+// index; this needs more investigation.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
if (!result || result->getValue().isUninitialized())
More information about the Mlir-commits
mailing list