[Mlir-commits] [mlir] [mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion (PR #112454)

Ivan Butygin llvmlistbot at llvm.org
Tue Oct 15 17:16:02 PDT 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/112454

`UnsignedWhenEquivalent` doesn't really need any dialect conversion features, and switching it to normal patterns makes it more composable with other patterns-based transformations (and probably faster).

>From ac8de1ebc6dc09be184d1d692a1e37e153557a07 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 02:13:42 +0200
Subject: [PATCH] [mlir] UnsignedWhenEquivalent: use greedy rewriter instead of
 dialect conversion

UnsignedWhenEquivalent doesn't really need any dialect conversion features, and switching it to normal patterns makes it more composable with other patterns-based transformations.
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |  4 +
 .../Transforms/UnsignedWhenEquivalent.cpp     | 95 ++++++++++++-------
 .../Arith/unsigned-when-equivalent.mlir       | 20 ++--
 3 files changed, 73 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..e866ac518dbbcb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -70,6 +70,10 @@ std::unique_ptr<Pass> createArithUnsignedWhenEquivalentPass();
 void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
                                            DataFlowSolver &solver);
 
+/// Replace signed ops with unsigned ones where they are proven equivalent.
+void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
+                                            DataFlowSolver &solver);
+
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
index 4edce84bafd416..c76f56279db706 100644
--- a/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
@@ -13,7 +13,8 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace arith {
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
 }
 
 namespace {
+class DataFlowListener : public RewriterBase::Listener {
+public:
+  DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+  void notifyOperationErased(Operation *op) override {
+    s.eraseState(s.getProgramPointAfter(op));
+    for (Value res : op->getResults())
+      s.eraseState(res);
+  }
+
+  DataFlowSolver &s;
+};
+
 template <typename Signed, typename Unsigned>
-struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
-  using OpConversionPattern<Signed>::OpConversionPattern;
+struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
+  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<Signed>(context), solver(s) {}
 
-  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
-    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
-                                    adaptor.getOperands(), op->getAttrs());
+  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
+    if (failed(
+            staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
+      return failure();
+
+    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
+                                    op->getAttrs());
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
 
-struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
-  using OpConversionPattern<CmpIOp>::OpConversionPattern;
+struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> {
+  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
+      : OpRewritePattern<CmpIOp>(context), solver(s) {}
+
+  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
+    if (failed(isCmpIConvertable(this->solver, op)))
+      return failure();
 
-  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
-                                ConversionPatternRewriter &rw) const override {
     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
                                   op.getLhs(), op.getRhs());
     return success();
   }
+
+private:
+  DataFlowSolver &solver;
 };
 
 struct ArithUnsignedWhenEquivalentPass
     : public arith::impl::ArithUnsignedWhenEquivalentBase<
           ArithUnsignedWhenEquivalentPass> {
-  /// Implementation structure: first find all equivalent ops and collect them,
-  /// then perform all the rewrites in a second pass over the target op. This
-  /// ensures that analysis results are not invalidated during rewriting.
+
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 
-    ConversionTarget target(*ctx);
-    target.addLegalDialect<ArithDialect>();
-    target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
-                                 MinSIOp, MaxSIOp, ExtSIOp>(
-        [&solver](Operation *op) -> std::optional<bool> {
-          return failed(staticallyNonNegative(solver, op));
-        });
-    target.addDynamicallyLegalOp<CmpIOp>(
-        [&solver](CmpIOp op) -> std::optional<bool> {
-          return failed(isCmpIConvertable(solver, op));
-        });
+    DataFlowListener listener(solver);
 
     RewritePatternSet patterns(ctx);
-    patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
-                 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
-                 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
-                 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
-                 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
-                 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
-                 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
-        ctx);
-
-    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+    populateUnsignedWhenEquivalentPatterns(patterns, solver);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
       signalPassFailure();
-    }
   }
 };
 } // end anonymous namespace
 
+void mlir::arith::populateUnsignedWhenEquivalentPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver) {
+  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+               ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+               ConvertOpToUnsigned<RemSIOp, RemUIOp>,
+               ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+               ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+               ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+      patterns.getContext(), solver);
+}
+
 std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
   return std::make_unique<ArithUnsignedWhenEquivalentPass>();
 }
diff --git a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
index 49bd74cfe9124a..e015d2d7543c93 100644
--- a/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
+++ b/mlir/test/Dialect/Arith/unsigned-when-equivalent.mlir
@@ -12,7 +12,7 @@
 // CHECK: arith.cmpi slt
 // CHECK: arith.cmpi sge
 // CHECK: arith.cmpi sgt
-func.func @not_with_maybe_overflow(%arg0 : i32) {
+func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
     %ci32_smax = arith.constant 0x7fffffff : i32
     %c1 = arith.constant 1 : i32
     %c4 = arith.constant 4 : i32
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
     %10 = arith.cmpi slt, %1, %c4 : i32
     %11 = arith.cmpi sge, %1, %c4 : i32
     %12 = arith.cmpi sgt, %1, %c4 : i32
-    func.return
+    func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @yes_with_no_overflow
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) {
 // CHECK: arith.cmpi ult
 // CHECK: arith.cmpi uge
 // CHECK: arith.cmpi ugt
-func.func @yes_with_no_overflow(%arg0 : i32) {
+func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) {
     %ci32_almost_smax = arith.constant 0x7ffffffe : i32
     %c1 = arith.constant 1 : i32
     %c4 = arith.constant 4 : i32
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) {
     %10 = arith.cmpi slt, %1, %c4 : i32
     %11 = arith.cmpi sge, %1, %c4 : i32
     %12 = arith.cmpi sgt, %1, %c4 : i32
-    func.return
+    func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: func @preserves_structure
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) {
 func.func private @external() -> i8
 
 // CHECK-LABEL: @dead_code
-func.func @dead_code() {
+func.func @dead_code() -> i8 {
   %0 = call @external() : () -> i8
   // CHECK: arith.floordivsi
   %1 = arith.floordivsi %0, %0 : i8
-  return
+  return %1 : i8
 }
 
 // Make sure not crash.
 // CHECK-LABEL: @no_integer_or_index
-func.func @no_integer_or_index() { 
+func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> {
   // CHECK: arith.cmpi
   %cst_0 = arith.constant dense<[0]> : vector<1xi32> 
-  %cmp = arith.cmpi slt, %cst_0, %cst_0 : vector<1xi32> 
-  return
+  %cmp = arith.cmpi slt, %cst_0, %arg0 : vector<1xi32>
+  return %cmp : vector<1xi1>
 }
 
 // CHECK-LABEL: @gpu_func
@@ -113,4 +113,4 @@ func.func @gpu_func(%arg0: memref<2x32xf32>, %arg1: memref<2x32xf32>, %arg2: mem
     gpu.terminator
   } 
   return %arg1 : memref<2x32xf32> 
-}  
+}



More information about the Mlir-commits mailing list