[flang-commits] [flang] [mlir] [mlir] Do not merge blocks during canonicalization by default (PR #95057)

Mehdi Amini via flang-commits flang-commits at lists.llvm.org
Wed Jun 12 05:00:12 PDT 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/95057

>From 02df4d111ee7f9c6cdc4cddbf09744df5b51afd2 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 10 Jun 2024 23:15:28 +0200
Subject: [PATCH] [mlir] Do not merge blocks during canonicalization.

This is a heavy process, and it can trigger a massive explosion
in adding block arguments. While potentially reducing the code
size, the resulting merged blocks with arguments are hiding some
of the def-use chain and can even hinder some further
analyses/optimizations: a merge block does not have its own
path-sensitive context, instead the context is merged from all the
predecessors.

Previous behavior can be restored by passing:

  {test-convergence region-simplify=aggressive}

to the canonicalize pass.
---
 flang/include/flang/Tools/CLOptions.inc         |  4 ++--
 .../HLFIR/Transforms/InlineElementals.cpp       |  2 +-
 .../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp   |  2 +-
 .../HLFIR/Transforms/OptimizedBufferization.cpp |  2 +-
 .../Transforms/AssumedRankOpConversion.cpp      |  2 +-
 flang/lib/Optimizer/Transforms/StackArrays.cpp  |  2 +-
 .../Transforms/GreedyPatternRewriteDriver.h     | 13 ++++++++++++-
 mlir/include/mlir/Transforms/Passes.h           |  1 +
 mlir/include/mlir/Transforms/Passes.td          | 14 +++++++++++---
 mlir/include/mlir/Transforms/RegionUtils.h      |  5 ++++-
 .../Utils/GreedyPatternRewriteDriver.cpp        |  9 +++++++--
 mlir/lib/Transforms/Utils/RegionUtils.cpp       |  8 +++++---
 .../Dialect/SPIRV/Transforms/canonicalize.mlir  |  2 +-
 mlir/test/Pass/run-reproducer.mlir              |  4 ++--
 .../Transforms/canonicalize-block-merge.mlir    | 17 ++++++++++++++++-
 mlir/test/Transforms/canonicalize-dce.mlir      |  2 +-
 mlir/test/Transforms/canonicalize.mlir          |  7 ++++++-
 mlir/test/Transforms/test-canonicalize.mlir     |  2 +-
 18 files changed, 74 insertions(+), 24 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 2a0cfc04aa350..16f681ef2e33c 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -144,7 +144,7 @@ namespace fir {
 static void addCanonicalizerPassWithoutRegionSimplification(
     mlir::OpPassManager &pm) {
   mlir::GreedyRewriteConfig config;
-  config.enableRegionSimplification = false;
+  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
   pm.addPass(mlir::createCanonicalizerPass(config));
 }
 
@@ -260,7 +260,7 @@ inline void createDefaultFIROptimizerPassPipeline(
 
   // simplify the IR
   mlir::GreedyRewriteConfig config;
-  config.enableRegionSimplification = false;
+  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
   pm.addPass(mlir::createCSEPass());
   fir::addAVC(pm, pc.OptLevel);
   addNestedPassToAllTopLevelOperations(pm, fir::createCharacterConversion);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
index 6c8e3e1193747..eb4ddbdf9a3be 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
@@ -119,7 +119,7 @@ class InlineElementalsPass
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks.
-    config.enableRegionSimplification = false;
+    config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
     mlir::RewritePatternSet patterns(context);
     patterns.insert<InlineElementalConversion>(context);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 707c0feffbb36..f1d8418365074 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -486,7 +486,7 @@ class LowerHLFIRIntrinsics
     // Pattern rewriting only requires that the resulting IR is still valid
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks
-    config.enableRegionSimplification = false;
+    config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             module, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 3c8424ca564e2..90847cdfecbb4 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1042,7 +1042,7 @@ class OptimizedBufferizationPass
 
     mlir::GreedyRewriteConfig config;
     // Prevent the pattern driver from merging blocks
-    config.enableRegionSimplification = false;
+    config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
     mlir::RewritePatternSet patterns(context);
     // TODO: right now the patterns are non-conflicting,
diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
index 2c545d66ebd8e..8c1892a57f61b 100644
--- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
@@ -150,7 +150,7 @@ class AssumedRankOpConversion
     patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
     patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
     mlir::GreedyRewriteConfig config;
-    config.enableRegionSimplification = false;
+    config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
     (void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
   }
 };
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 16bbb1c356460..7157b55b47c93 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -767,7 +767,7 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
   mlir::RewritePatternSet patterns(&context);
   mlir::GreedyRewriteConfig config;
   // prevent the pattern driver form merging blocks
-  config.enableRegionSimplification = false;
+  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
   patterns.insert<AllocMemConversion>(&context, *candidateOps);
   if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..eaff85804f6b3 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
   ExistingOps
 };
 
+enum class GreedySimplifyRegionLevel {
+  /// Disable region control-flow simplification.
+  Disabled,
+  /// Run the normal simplification (e.g. dead args elimination).
+  Normal,
+  /// Run extra simplifications (e.g. block merging), these can be
+  /// more costly or have some tradeoffs associated.
+  Aggressive
+};
+
 /// This class allows control over how the GreedyPatternRewriteDriver works.
 class GreedyRewriteConfig {
 public:
@@ -45,7 +55,8 @@ class GreedyRewriteConfig {
   /// patterns.
   ///
   /// Note: Only applicable when simplifying entire regions.
-  bool enableRegionSimplification = true;
+  GreedySimplifyRegionLevel enableRegionSimplification =
+      GreedySimplifyRegionLevel::Aggressive;
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..8e4a43c3f2458 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -15,6 +15,7 @@
 #define MLIR_TRANSFORMS_PASSES_H
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/ViewOpGraph.h"
 #include "llvm/Support/Debug.h"
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..000d9f697618e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",
            "Seed the worklist in general top-down order">,
-    Option<"enableRegionSimplification", "region-simplify", "bool",
-           /*default=*/"true",
-           "Perform control flow optimizations to the region tree">,
+    Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+           /*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
+           "Perform control flow optimizations to the region tree",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
+                "Don't run any control-flow simplification."),
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
+                "Perform simple control-flow simplifications (e.g. dead args elimination)."),
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
+                "Perform aggressive control-flow simplification (e.g. block merging).")
+              )}]>,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 06eebff201d1b..5c57dd5b7532a 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise. The provided rewriter is
 /// used to notify callers of operation and block deletion.
+/// Structurally similar blocks will be merged if the `mergeBlocks` argument is
+/// true. Note this can lead to merged blocks with extra arguments.
 LogicalResult simplifyRegions(RewriterBase &rewriter,
-                              MutableArrayRef<Region> regions);
+                              MutableArrayRef<Region> regions,
+                              bool mergeBlocks = true);
 
 /// Erase the unreachable blocks within the provided regions. Returns success
 /// if any blocks were erased, failure otherwise.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911b..e0d0acd122e26 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -875,8 +875,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
-          if (config.enableRegionSimplification)
-            continueRewrites |= succeeded(simplifyRegions(rewriter, region));
+          if (config.enableRegionSimplification !=
+              GreedySimplifyRegionLevel::Disabled) {
+            continueRewrites |= succeeded(simplifyRegions(
+                rewriter, region,
+                /*mergeBlocks=*/config.enableRegionSimplification ==
+                    GreedySimplifyRegionLevel::Aggressive));
+          }
         },
         {&region}, iteration);
   } while (continueRewrites);
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b5e641d39fc0a..4c0f15bafbaba 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -827,11 +827,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
-                                    MutableArrayRef<Region> regions) {
+                                    MutableArrayRef<Region> regions,
+                                    bool mergeBlocks) {
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
-  bool mergedIdenticalBlocks =
-      succeeded(mergeIdenticalBlocks(rewriter, regions));
+  bool mergedIdenticalBlocks = false;
+  if (mergeBlocks)
+    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 977d31a6bfe54..d07389d6822ce 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.AccessChain
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96..bf3ab2dae2ff8 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,8 +14,8 @@ func.func @bar() {
   external_resources: {
     mlir_reproducer: {
       verify_each: true,
-      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
-      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
+      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
+      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
       disable_threading: true
     }
   }
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index bf44973ab646c..3b8b1fce0575a 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
 
 // Check the simple case of single operation blocks with a return.
 
@@ -275,3 +275,18 @@ func.func @mismatch_dominance() -> i32 {
 ^bb4(%3: i32):
   return %3 : i32
 }
+
+// CHECK-LABEL: func @dead_dealloc_fold_multi_use
+func.func @dead_dealloc_fold_multi_use(%cond : i1) {
+  // CHECK-NEXT: return
+  %a = memref.alloc() : memref<4xf32>
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  memref.dealloc %a: memref<4xf32>
+  return
+
+^bb2:
+  memref.dealloc %a: memref<4xf32>
+  return
+}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 3048a7fed636b..ac034d567a26a 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s
 
 // Test case: Simple case of deleting a dead pure op.
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index d2c2c12d32389..6927189fc626f 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {
 
 // CHECK-LABEL: func @dead_dealloc_fold_multi_use
 func.func @dead_dealloc_fold_multi_use(%cond : i1) {
-  // CHECK-NEXT: return
+  // CHECK-NOT: alloc
   %a = memref.alloc() : memref<4xf32>
+  // CHECK: cond_br
   cf.cond_br %cond, ^bb1, ^bb2
 
 ^bb1:
+  // CHECK-NOT: alloc
   memref.dealloc %a: memref<4xf32>
+  // CHECK: return
   return
 
 ^bb2:
+  // CHECK-NOT: alloc
   memref.dealloc %a: memref<4xf32>
+  // CHECK: return
   return
 }
 
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf..0fc822b0a23ae 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {



More information about the flang-commits mailing list