[Mlir-commits] [mlir] [mlir] Do not merge blocks during canonicalization by default (PR #95057)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 10 16:12:27 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
This is a heavy process, and it can trigger a massive explosion in adding block arguments. While potentially reducing the code size, the resulting merged blocks with arguments are hiding some of the def-use chain and can even hinder some further analyses/optimizations: a merge block does not have its own path-sensitive context, instead the context is merged from all the predecessors.
Previous behavior can be restored by passing:
{test-convergence region-simplify=aggressive}
to the canonicalize pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/95057.diff
12 Files Affected:
- (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+12-1)
- (modified) mlir/include/mlir/Transforms/Passes.h (+1)
- (modified) mlir/include/mlir/Transforms/Passes.td (+11-3)
- (modified) mlir/include/mlir/Transforms/RegionUtils.h (+4-1)
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+7-2)
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+5-3)
- (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+1-1)
- (modified) mlir/test/Pass/run-reproducer.mlir (+2-2)
- (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+17-1)
- (modified) mlir/test/Transforms/canonicalize-dce.mlir (+1-1)
- (modified) mlir/test/Transforms/canonicalize.mlir (+6-1)
- (modified) mlir/test/Transforms/test-canonicalize.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..eaff85804f6b3 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
ExistingOps
};
+enum class GreedySimplifyRegionLevel {
+ /// Disable region control-flow simplification.
+ Disabled,
+ /// Run the normal simplification (e.g. dead args elimination).
+ Normal,
+ /// Run extra simplifications (e.g. block merging), these can be
+ /// more costly or have some tradeoffs associated.
+ Aggressive
+};
+
/// This class allows control over how the GreedyPatternRewriteDriver works.
class GreedyRewriteConfig {
public:
@@ -45,7 +55,8 @@ class GreedyRewriteConfig {
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
- bool enableRegionSimplification = true;
+ GreedySimplifyRegionLevel enableRegionSimplification =
+ GreedySimplifyRegionLevel::Aggressive;
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..8e4a43c3f2458 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -15,6 +15,7 @@
#define MLIR_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..000d9f697618e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
- Option<"enableRegionSimplification", "region-simplify", "bool",
- /*default=*/"true",
- "Perform control flow optimizations to the region tree">,
+ Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+ /*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
+ "Perform control flow optimizations to the region tree",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
+ "Don't run any control-flow simplification."),
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
+ "Perform simple control-flow simplifications (e.g. dead args elimination)."),
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
+ "Perform aggressive control-flow simplification (e.g. block merging).")
+ )}]>,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 192ff71384059..86b22839f6335 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise. The provided rewriter is
/// used to notify callers of operation and block deletion.
+/// Structurally similar blocks will be merged if the `mergeBlocks` argument is
+/// true. Note this can lead to merged blocks with extra arguments.
LogicalResult simplifyRegions(RewriterBase &rewriter,
- MutableArrayRef<Region> regions);
+ MutableArrayRef<Region> regions,
+ bool mergeBlocks = true);
/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..d22b3d3672425 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -871,8 +871,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
- if (config.enableRegionSimplification)
- continueRewrites |= succeeded(simplifyRegions(*this, region));
+ if (config.enableRegionSimplification !=
+ GreedySimplifyRegionLevel::Disabled) {
+ continueRewrites |= succeeded(simplifyRegions(
+ *this, region,
+ /*mergeBlocks=*/config.enableRegionSimplification ==
+ GreedySimplifyRegionLevel::Aggressive));
+ }
},
{&region}, iteration);
} while (continueRewrites);
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e25867b527b71..a1bebc4809c45 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -828,11 +828,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
- MutableArrayRef<Region> regions) {
+ MutableArrayRef<Region> regions,
+ bool mergeBlocks) {
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
- bool mergedIdenticalBlocks =
- succeeded(mergeIdenticalBlocks(rewriter, regions));
+ bool mergedIdenticalBlocks = false;
+ if (mergeBlocks)
+ mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 977d31a6bfe54..d07389d6822ce 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s
//===----------------------------------------------------------------------===//
// spirv.AccessChain
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96..bf3ab2dae2ff8 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,8 +14,8 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
- // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
- pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
+ // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
+ pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
disable_threading: true
}
}
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index bf44973ab646c..122bfcca66a63 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
// Check the simple case of single operation blocks with a return.
@@ -275,3 +275,19 @@ func.func @mismatch_dominance() -> i32 {
^bb4(%3: i32):
return %3 : i32
}
+
+
+// CHECK-LABEL: func @dead_dealloc_fold_multi_use
+func.func @dead_dealloc_fold_multi_use(%cond : i1) {
+ // CHECK-NEXT: return
+ %a = memref.alloc() : memref<4xf32>
+ cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ memref.dealloc %a: memref<4xf32>
+ return
+
+^bb2:
+ memref.dealloc %a: memref<4xf32>
+ return
+}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 3048a7fed636b..ac034d567a26a 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s
// Test case: Simple case of deleting a dead pure op.
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index d2c2c12d32389..6927189fc626f 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {
// CHECK-LABEL: func @dead_dealloc_fold_multi_use
func.func @dead_dealloc_fold_multi_use(%cond : i1) {
- // CHECK-NEXT: return
+ // CHECK-NOT: alloc
%a = memref.alloc() : memref<4xf32>
+ // CHECK: cond_br
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
+ // CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
+ // CHECK: return
return
^bb2:
+ // CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
+ // CHECK: return
return
}
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf..0fc822b0a23ae 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
func.func @remove_op_with_inner_ops_pattern() {
``````````
</details>
https://github.com/llvm/llvm-project/pull/95057
More information about the Mlir-commits
mailing list