[Mlir-commits] [mlir] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ing (PR #89552)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 23 15:43:29 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/89552

>From 9bb74b5e1d3403ef83058a181a89763744966597 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 23 Apr 2024 22:41:44 +0000
Subject: [PATCH] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to
 control constant CSE'ing

By default, the greedy pattern rewrite driver CSEs constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the past few years, users have described situations where this is undesirable or unnecessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
---
 .../mlir/Transforms/GreedyPatternRewriteDriver.h   |  4 ++++
 mlir/include/mlir/Transforms/Passes.td             |  2 ++
 mlir/lib/Transforms/Canonicalizer.cpp              |  2 ++
 .../Utils/GreedyPatternRewriteDriver.cpp           |  4 ++--
 mlir/test/Pass/run-reproducer.mlir                 |  2 +-
 mlir/test/Transforms/test-canonicalize.mlir        | 14 ++++++++++++++
 6 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9c..880426c2411bcf 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,10 @@ class GreedyRewriteConfig {
   /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
+  /// If set to "true", constants are CSE'd (even across multiple regions that
+  /// are in an ancestor-descendant relationship).
+  bool cseConstants = true;
+
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27e..549161c96030d3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
            "Perform control flow optimizations to the region tree">,
+    Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
+           "CSE constant operations">,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..2600df32b69c1d 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
       : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
+    this->cseConstants = config.cseConstants;
     this->maxIterations = config.maxIterations;
     this->maxNumRewrites = config.maxNumRewrites;
     this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     // Set the config from possible pass options set in the meantime.
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
+    config.cseConstants = cseConstants;
     config.maxIterations = maxIterations;
     config.maxNumRewrites = maxNumRewrites;
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff2..cf4a192a0281d7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
-        if (!insertKnownConstant(op))
+        if (!config.cseConstants || !insertKnownConstant(op))
           addToWorklist(op);
       });
     } else {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-        if (!insertKnownConstant(op)) {
+        if (!config.cseConstants || !insertKnownConstant(op)) {
           addToWorklist(op);
           return WalkResult::advance();
         }
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96f..220ea2468eed7d 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,7 +14,7 @@ func.func @bar() {
   external_resources: {
     mlir_reproducer: {
       verify_each: true,
-      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
+      // CHECK:  builtin.module(func.func(cse,canonicalize{cse-constants=true max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
       pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
       disable_threading: true
     }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf4..98eae142d1870e 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
 ^bb1:
   return
 }
+
+// CHECK-LABEL: do_not_cse_constant
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: return %[[c0]], %[[c0]]
+// NO-CSE-LABEL: do_not_cse_constant
+// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
+// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
+// NO-CSE: return %[[c0]], %[[c1]]
+func.func @do_not_cse_constant() -> (index, index) {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 0 : index
+  return %0, %1 : index, index
+}
\ No newline at end of file



More information about the Mlir-commits mailing list