[Mlir-commits] [mlir] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ing (PR #89552)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 23 15:43:29 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/89552
>From 9bb74b5e1d3403ef83058a181a89763744966597 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 23 Apr 2024 22:41:44 +0000
Subject: [PATCH] [mlir][Transforms] GreedyPatternRewriteDriver: Add flag to
control constant CSE'ing
By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the past few years, users have described situations where this hoisting is neither desirable nor necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, the flag is also exposed as a canonicalizer pass option.
---
.../mlir/Transforms/GreedyPatternRewriteDriver.h | 4 ++++
mlir/include/mlir/Transforms/Passes.td | 2 ++
mlir/lib/Transforms/Canonicalizer.cpp | 2 ++
.../Utils/GreedyPatternRewriteDriver.cpp | 4 ++--
mlir/test/Pass/run-reproducer.mlir | 2 +-
mlir/test/Transforms/test-canonicalize.mlir | 14 ++++++++++++++
6 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9c..880426c2411bcf 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,10 @@ class GreedyRewriteConfig {
/// Note: Only applicable when simplifying entire regions.
bool enableRegionSimplification = true;
+ /// If set to "true", constants are CSE'd (even across multiple regions that
+ /// are in a parent-ancestor relationship).
+ bool cseConstants = true;
+
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27e..549161c96030d3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"enableRegionSimplification", "region-simplify", "bool",
/*default=*/"true",
"Perform control flow optimizations to the region tree">,
+ Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
+ "CSE constant operations">,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..2600df32b69c1d 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
: config(config) {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
+ this->cseConstants = config.cseConstants;
this->maxIterations = config.maxIterations;
this->maxNumRewrites = config.maxNumRewrites;
this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
// Set the config from possible pass options set in the meantime.
config.useTopDownTraversal = topDownProcessingEnabled;
config.enableRegionSimplification = enableRegionSimplification;
+ config.cseConstants = cseConstants;
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff2..cf4a192a0281d7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
- if (!insertKnownConstant(op))
+ if (!config.cseConstants || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!insertKnownConstant(op)) {
+ if (!config.cseConstants || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96f..220ea2468eed7d 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,7 +14,7 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
- // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
+ // CHECK: builtin.module(func.func(cse,canonicalize{cse-constants=true max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
disable_threading: true
}
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf4..98eae142d1870e 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
^bb1:
return
}
+
+// CHECK-LABEL: do_not_cse_constant
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: return %[[c0]], %[[c0]]
+// NO-CSE-LABEL: do_not_cse_constant
+// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
+// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
+// NO-CSE: return %[[c0]], %[[c1]]
+func.func @do_not_cse_constant() -> (index, index) {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 0 : index
+ return %0, %1 : index, index
+}
\ No newline at end of file
More information about the Mlir-commits mailing list