[Mlir-commits] [mlir] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ing (PR #89552)

Matthias Springer llvmlistbot at llvm.org
Sun Apr 21 10:54:31 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/89552

By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the past few years, users have described situations where this is undesirable or unnecessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.

>From 29f45ee1c6b1c74cc465f2b372ea2fe2dae78e7e Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sun, 21 Apr 2024 17:52:56 +0000
Subject: [PATCH] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to
 control constant CSE'ing

By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
---
 .../mlir/Transforms/GreedyPatternRewriteDriver.h   |  4 ++++
 mlir/include/mlir/Transforms/Passes.td             |  2 ++
 mlir/lib/Transforms/Canonicalizer.cpp              |  2 ++
 .../Utils/GreedyPatternRewriteDriver.cpp           |  4 ++--
 mlir/test/Transforms/test-canonicalize.mlir        | 14 ++++++++++++++
 5 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9c..880426c2411bcf 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,10 @@ class GreedyRewriteConfig {
   /// Note: Only applicable when simplifying entire regions.
   bool enableRegionSimplification = true;
 
+  /// If set to "true", constants are CSE'd (even across multiple regions that
+  /// are in a parent-ancestor relationship).
+  bool cseConstants = true;
+
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
   /// disable this iteration limit.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27e..549161c96030d3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"enableRegionSimplification", "region-simplify", "bool",
            /*default=*/"true",
            "Perform control flow optimizations to the region tree">,
+    Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
+           "CSE constant operations">,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..2600df32b69c1d 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
       : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
+    this->cseConstants = config.cseConstants;
     this->maxIterations = config.maxIterations;
     this->maxNumRewrites = config.maxNumRewrites;
     this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     // Set the config from possible pass options set in the meantime.
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
+    config.cseConstants = cseConstants;
     config.maxIterations = maxIterations;
     config.maxNumRewrites = maxNumRewrites;
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff2..cf4a192a0281d7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
-        if (!insertKnownConstant(op))
+        if (!config.cseConstants || !insertKnownConstant(op))
           addToWorklist(op);
       });
     } else {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-        if (!insertKnownConstant(op)) {
+        if (!config.cseConstants || !insertKnownConstant(op)) {
           addToWorklist(op);
           return WalkResult::advance();
         }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf4..98eae142d1870e 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
 ^bb1:
   return
 }
+
+// CHECK-LABEL: do_not_cse_constant
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: return %[[c0]], %[[c0]]
+// NO-CSE-LABEL: do_not_cse_constant
+// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
+// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
+// NO-CSE: return %[[c0]], %[[c1]]
+func.func @do_not_cse_constant() -> (index, index) {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 0 : index
+  return %0, %1 : index, index
+}
\ No newline at end of file



More information about the Mlir-commits mailing list