[Mlir-commits] [mlir] [mlir][pass] Add composite pass utility (PR #87166)

Ivan Butygin llvmlistbot at llvm.org
Sun Mar 31 17:44:52 PDT 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/87166

>From 5acbe822f8e21c68cae9d3df952be6be024f0711 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 30 Mar 2024 19:03:37 +0100
Subject: [PATCH 1/9] [mlir][pass] Add composite pass utility

The composite pass runs a sequence of passes in a loop until a fixed point or the maximum number of iterations is reached.
The usual candidates are canonicalize+CSE, as canonicalization can open up more opportunities for CSE and vice versa.
---
 mlir/include/mlir/Transforms/Passes.h         |  7 ++
 mlir/lib/Transforms/CMakeLists.txt            |  1 +
 mlir/lib/Transforms/CompositePass.cpp         | 81 +++++++++++++++++++
 mlir/test/Transforms/composite-pass.mlir      | 25 ++++++
 mlir/test/lib/Transforms/CMakeLists.txt       |  1 +
 .../test/lib/Transforms/TestCompositePass.cpp | 30 +++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 7 files changed, 147 insertions(+)
 create mode 100644 mlir/lib/Transforms/CompositePass.cpp
 create mode 100644 mlir/test/Transforms/composite-pass.mlir
 create mode 100644 mlir/test/lib/Transforms/TestCompositePass.cpp

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 /// their producers.
 std::unique_ptr<Pass> createTopologicalSortPass();
 
+/// Create composite pass, which runs selected set of passes until fixed point
+/// or maximum number of iterations reached.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+                    llvm::function_ref<void(OpPassManager &)> populateFunc,
+                    unsigned maxIterations = 10);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass runs a set of passes until a fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+struct CompositePass final
+    : public PassWrapper<CompositePass, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+  CompositePass(std::string name_, std::string argument_,
+                llvm::function_ref<void(OpPassManager &)> populateFunc,
+                unsigned maxIterations)
+      : name(std::move(name_)), argument(std::move(argument_)),
+        dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+    populateFunc(*dynamicPM);
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM->getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    auto op = getOperation();
+    OperationFingerPrint fp(op);
+
+    unsigned currentIter = 0;
+    while (true) {
+      if (failed(runPipeline(*dynamicPM, op)))
+        return signalPassFailure();
+
+      if (currentIter++ >= maxIters) {
+        op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                        "\"+ didn't converge in " + llvm::Twine(maxIters) +
+                        " iterations");
+        break;
+      }
+
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        break;
+
+      fp = newFp;
+    }
+  }
+
+protected:
+  llvm::StringRef getName() const override { return name; }
+
+  llvm::StringRef getArgument() const override { return argument; }
+
+private:
+  std::string name;
+  std::string argument;
+  std::shared_ptr<OpPassManager> dynamicPM;
+  unsigned maxIters;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositePass(
+    std::string name, std::string argument,
+    llvm::function_ref<void(OpPassManager &)> populateFunc,
+    unsigned maxIterations) {
+
+  return std::make_unique<CompositePass>(std::move(name), std::move(argument),
+                                         populateFunc, maxIterations);
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+  %0 = arith.constant 1.5 : f32
+  return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
+  TestCompositePass.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestCompositePass() {
+  registerPass([]() -> std::unique_ptr<Pass> {
+    return createCompositePass("TestCompositePass", "test-composite-pass",
+                               [](OpPassManager &p) {
+                                 p.addPass(createCanonicalizerPass());
+                                 p.addPass(createCSEPass());
+                               });
+  });
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
+void registerTestCompositePass();
 void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerTestCompositePass();
   mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();

>From da797648fc6663affb3b9078a7520e2112fb387e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:43:03 +0200
Subject: [PATCH 2/9] use pipeline string

---
 mlir/include/mlir/Transforms/Passes.h         |  3 +-
 mlir/include/mlir/Transforms/Passes.td        | 16 ++++++
 mlir/lib/Transforms/CompositePass.cpp         | 52 +++++++++++++------
 mlir/test/Transforms/composite-pass.mlir      |  1 +
 .../test/lib/Transforms/TestCompositePass.cpp | 22 +++++---
 5 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 0cf45d8d40a93d..a9a303c769d21b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLDCE
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
+#define GEN_PASS_DECL_COMPOSITEPASS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -133,7 +134,7 @@ std::unique_ptr<Pass> createTopologicalSortPass();
 /// Create composite pass, which runs selected set of passes until fixed point
 /// or maximum number of iterations reached.
 std::unique_ptr<Pass>
-createCompositePass(std::string name, std::string argument,
+createCompositePass(std::string name,
                     llvm::function_ref<void(OpPassManager &)> populateFunc,
                     unsigned maxIterations = 10);
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 51b2a27da639d6..d7e9cf97516e54 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,4 +552,20 @@ def TopologicalSort : Pass<"topological-sort"> {
   let constructor = "mlir::createTopologicalSortPass()";
 }
 
+def CompositePass : Pass<"composite-pass"> {
+  let summary = "TBD";
+  let description = [{
+    TBD
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"CompositePass\"",
+      "Composite pass display name">,
+    Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
+      "Composite pass inner pipeline">,
+    Option<"maxIter", "max-iterations", "unsigned", /*default=*/"10",
+      "Maximum number of iterations if inner pipeline">,
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 3b9700f1f05176..b7f2e73c847b20 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -15,19 +15,42 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_COMPOSITEPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
 using namespace mlir;
 
 namespace {
-struct CompositePass final
-    : public PassWrapper<CompositePass, OperationPass<void>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+struct CompositePass final : public impl::CompositePassBase<CompositePass> {
+  using CompositePassBase::CompositePassBase;
 
-  CompositePass(std::string name_, std::string argument_,
+  CompositePass(std::string name_,
                 llvm::function_ref<void(OpPassManager &)> populateFunc,
                 unsigned maxIterations)
-      : name(std::move(name_)), argument(std::move(argument_)),
-        dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+      : dynamicPM(std::make_shared<OpPassManager>()) {
+    name = std::move(name_);
+    maxIter = maxIterations;
     populateFunc(*dynamicPM);
+    std::string pipeline;
+    llvm::raw_string_ostream os(pipeline);
+    dynamicPM->printAsTextualPipeline(os);
+    os.flush();
+    pipelineStr = pipeline;
+  }
+
+  LogicalResult initializeOptions(StringRef options) override {
+    if (failed(CompositePassBase::initializeOptions(options)))
+      return failure();
+
+    dynamicPM = std::make_shared<OpPassManager>();
+    if (failed(parsePassPipeline(pipelineStr, *dynamicPM))) {
+      llvm::errs() << "Failed to parse composite pass pipeline\n";
+      return failure();
+    }
+
+    return success();
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -39,13 +62,14 @@ struct CompositePass final
     OperationFingerPrint fp(op);
 
     unsigned currentIter = 0;
+    unsigned maxIterVal = maxIter;
     while (true) {
       if (failed(runPipeline(*dynamicPM, op)))
         return signalPassFailure();
 
-      if (currentIter++ >= maxIters) {
+      if (currentIter++ >= maxIterVal) {
         op->emitWarning("Composite pass \"" + llvm::Twine(name) +
-                        "\"+ didn't converge in " + llvm::Twine(maxIters) +
+                        "\" didn't converge in " + llvm::Twine(maxIterVal) +
                         " iterations");
         break;
       }
@@ -61,21 +85,15 @@ struct CompositePass final
 protected:
   llvm::StringRef getName() const override { return name; }
 
-  llvm::StringRef getArgument() const override { return argument; }
-
 private:
-  std::string name;
-  std::string argument;
   std::shared_ptr<OpPassManager> dynamicPM;
-  unsigned maxIters;
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::createCompositePass(
-    std::string name, std::string argument,
-    llvm::function_ref<void(OpPassManager &)> populateFunc,
+    std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     unsigned maxIterations) {
 
-  return std::make_unique<CompositePass>(std::move(name), std::move(argument),
-                                         populateFunc, maxIterations);
+  return std::make_unique<CompositePass>(std::move(name), populateFunc,
+                                         maxIterations);
 }
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 4bf83d3a79754a..034d266424b9ea 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
 
 // CHECK-LABEL: running `TestCompositePass`
 //       CHECK: running `Canonicalizer`
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
index 64299685b3286e..47d328c5c7b7af 100644
--- a/mlir/test/lib/Transforms/TestCompositePass.cpp
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -18,13 +18,21 @@
 namespace mlir {
 namespace test {
 void registerTestCompositePass() {
-  registerPass([]() -> std::unique_ptr<Pass> {
-    return createCompositePass("TestCompositePass", "test-composite-pass",
-                               [](OpPassManager &p) {
-                                 p.addPass(createCanonicalizerPass());
-                                 p.addPass(createCSEPass());
-                               });
-  });
+  registerPassPipeline(
+      "test-composite-pass", "Test composite pass",
+      [](OpPassManager &pm, StringRef optionsStr,
+         function_ref<LogicalResult(const Twine &)> errorHandler) {
+        if (!optionsStr.empty())
+          return failure();
+
+        pm.addPass(
+            createCompositePass("TestCompositePass", [](OpPassManager &p) {
+              p.addPass(createCanonicalizerPass());
+              p.addPass(createCSEPass());
+            }));
+        return success();
+      },
+      [](function_ref<void(const detail::PassOptions &)>) {});
 }
 } // namespace test
 } // namespace mlir

>From 011e2882b6d8b5976208eda85f0f6d2868d65e3f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:50:25 +0200
Subject: [PATCH 3/9] renamings

---
 mlir/include/mlir/Transforms/Passes.h         |  9 ++++----
 mlir/include/mlir/Transforms/Passes.td        |  2 +-
 mlir/lib/Transforms/CompositePass.cpp         | 21 ++++++++++---------
 mlir/test/Transforms/composite-pass.mlir      |  4 ++--
 .../test/lib/Transforms/TestCompositePass.cpp |  6 +++---
 5 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index a9a303c769d21b..b89257d4accef5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,7 +43,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLDCE
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
-#define GEN_PASS_DECL_COMPOSITEPASS
+#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -133,10 +133,9 @@ std::unique_ptr<Pass> createTopologicalSortPass();
 
 /// Create composite pass, which runs selected set of passes until fixed point
 /// or maximum number of iterations reached.
-std::unique_ptr<Pass>
-createCompositePass(std::string name,
-                    llvm::function_ref<void(OpPassManager &)> populateFunc,
-                    unsigned maxIterations = 10);
+std::unique_ptr<Pass> createCompositeFixedPointPass(
+    std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+    unsigned maxIterations = 10);
 
 //===----------------------------------------------------------------------===//
 // Registration
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index d7e9cf97516e54..4fc5edd32ab52a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,7 +552,7 @@ def TopologicalSort : Pass<"topological-sort"> {
   let constructor = "mlir::createTopologicalSortPass()";
 }
 
-def CompositePass : Pass<"composite-pass"> {
+def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   let summary = "TBD";
   let description = [{
     TBD
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index b7f2e73c847b20..5455a768383614 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -16,19 +16,20 @@
 #include "mlir/Pass/PassManager.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_COMPOSITEPASS
+#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
 
 namespace {
-struct CompositePass final : public impl::CompositePassBase<CompositePass> {
-  using CompositePassBase::CompositePassBase;
+struct CompositeFixedPointPass final
+    : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
+  using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
 
-  CompositePass(std::string name_,
-                llvm::function_ref<void(OpPassManager &)> populateFunc,
-                unsigned maxIterations)
+  CompositeFixedPointPass(
+      std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
+      unsigned maxIterations)
       : dynamicPM(std::make_shared<OpPassManager>()) {
     name = std::move(name_);
     maxIter = maxIterations;
@@ -41,7 +42,7 @@ struct CompositePass final : public impl::CompositePassBase<CompositePass> {
   }
 
   LogicalResult initializeOptions(StringRef options) override {
-    if (failed(CompositePassBase::initializeOptions(options)))
+    if (failed(CompositeFixedPointPassBase::initializeOptions(options)))
       return failure();
 
     dynamicPM = std::make_shared<OpPassManager>();
@@ -90,10 +91,10 @@ struct CompositePass final : public impl::CompositePassBase<CompositePass> {
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::createCompositePass(
+std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     unsigned maxIterations) {
 
-  return std::make_unique<CompositePass>(std::move(name), populateFunc,
-                                         maxIterations);
+  return std::make_unique<CompositeFixedPointPass>(std::move(name),
+                                                   populateFunc, maxIterations);
 }
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 034d266424b9ea..829470c2c9aa64 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
-// RUN: mlir-opt %s --log-actions-to=- --composite-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
 
 // CHECK-LABEL: running `TestCompositePass`
 //       CHECK: running `Canonicalizer`
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
index 47d328c5c7b7af..5c0d93cc0d64ec 100644
--- a/mlir/test/lib/Transforms/TestCompositePass.cpp
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -19,14 +19,14 @@ namespace mlir {
 namespace test {
 void registerTestCompositePass() {
   registerPassPipeline(
-      "test-composite-pass", "Test composite pass",
+      "test-composite-fixed-point-pass", "Test composite pass",
       [](OpPassManager &pm, StringRef optionsStr,
          function_ref<LogicalResult(const Twine &)> errorHandler) {
         if (!optionsStr.empty())
           return failure();
 
-        pm.addPass(
-            createCompositePass("TestCompositePass", [](OpPassManager &p) {
+        pm.addPass(createCompositeFixedPointPass(
+            "TestCompositePass", [](OpPassManager &p) {
               p.addPass(createCanonicalizerPass());
               p.addPass(createCSEPass());
             }));

>From 8df25ea53b838e82a981d9354e070ff83605e8d5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:57:25 +0200
Subject: [PATCH 4/9] Docs

---
 mlir/include/mlir/Transforms/Passes.h  | 2 +-
 mlir/include/mlir/Transforms/Passes.td | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index b89257d4accef5..1c6e158fa4592f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -131,7 +131,7 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 /// their producers.
 std::unique_ptr<Pass> createTopologicalSortPass();
 
-/// Create composite pass, which runs selected set of passes until fixed point
+/// Create composite pass, which runs provided set of passes until fixed point
 /// or maximum number of iterations reached.
 std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 4fc5edd32ab52a..9a38aaf6985c0e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -553,13 +553,14 @@ def TopologicalSort : Pass<"topological-sort"> {
 }
 
 def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
-  let summary = "TBD";
+  let summary = "Composite fixed point pass";
   let description = [{
-    TBD
+    Composite pass runs the provided set of passes until a fixed point or the
+    maximum number of iterations is reached.
   }];
 
   let options = [
-    Option<"name", "name", "std::string", /*default=*/"\"CompositePass\"",
+    Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
       "Composite pass display name">,
     Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
       "Composite pass inner pipeline">,

>From 332ec5ac8688c4eacc784a86787c0a7253b8f028 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:21:45 +0200
Subject: [PATCH 5/9] copy PM

---
 mlir/lib/Transforms/CompositePass.cpp | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 5455a768383614..5c01482684bf75 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -29,14 +29,13 @@ struct CompositeFixedPointPass final
 
   CompositeFixedPointPass(
       std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
-      unsigned maxIterations)
-      : dynamicPM(std::make_shared<OpPassManager>()) {
+      unsigned maxIterations) {
     name = std::move(name_);
     maxIter = maxIterations;
-    populateFunc(*dynamicPM);
+    populateFunc(dynamicPM);
     std::string pipeline;
     llvm::raw_string_ostream os(pipeline);
-    dynamicPM->printAsTextualPipeline(os);
+    dynamicPM.printAsTextualPipeline(os);
     os.flush();
     pipelineStr = pipeline;
   }
@@ -45,8 +44,7 @@ struct CompositeFixedPointPass final
     if (failed(CompositeFixedPointPassBase::initializeOptions(options)))
       return failure();
 
-    dynamicPM = std::make_shared<OpPassManager>();
-    if (failed(parsePassPipeline(pipelineStr, *dynamicPM))) {
+    if (failed(parsePassPipeline(pipelineStr, dynamicPM))) {
       llvm::errs() << "Failed to parse composite pass pipeline\n";
       return failure();
     }
@@ -55,7 +53,7 @@ struct CompositeFixedPointPass final
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    dynamicPM->getDependentDialects(registry);
+    dynamicPM.getDependentDialects(registry);
   }
 
   void runOnOperation() override {
@@ -65,7 +63,7 @@ struct CompositeFixedPointPass final
     unsigned currentIter = 0;
     unsigned maxIterVal = maxIter;
     while (true) {
-      if (failed(runPipeline(*dynamicPM, op)))
+      if (failed(runPipeline(dynamicPM, op)))
         return signalPassFailure();
 
       if (currentIter++ >= maxIterVal) {
@@ -87,7 +85,7 @@ struct CompositeFixedPointPass final
   llvm::StringRef getName() const override { return name; }
 
 private:
-  std::shared_ptr<OpPassManager> dynamicPM;
+  OpPassManager dynamicPM;
 };
 } // namespace
 

>From 93a57e22e91b0d8fdc0bb31f5f18787c9c8070be Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:26:42 +0200
Subject: [PATCH 6/9] remove temp string

---
 mlir/lib/Transforms/CompositePass.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 5c01482684bf75..9ce386782bd55f 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -33,11 +33,10 @@ struct CompositeFixedPointPass final
     name = std::move(name_);
     maxIter = maxIterations;
     populateFunc(dynamicPM);
-    std::string pipeline;
-    llvm::raw_string_ostream os(pipeline);
+
+    llvm::raw_string_ostream os(pipelineStr);
     dynamicPM.printAsTextualPipeline(os);
     os.flush();
-    pipelineStr = pipeline;
   }
 
   LogicalResult initializeOptions(StringRef options) override {

>From 53526aff12c650f6d7bc5e4e22b061d85fee6e06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:31:22 +0200
Subject: [PATCH 7/9] use int

---
 mlir/include/mlir/Transforms/Passes.h  |  2 +-
 mlir/include/mlir/Transforms/Passes.td |  2 +-
 mlir/lib/Transforms/CompositePass.cpp  | 11 ++++++++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 1c6e158fa4592f..58bd61b2ae8b88 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -135,7 +135,7 @@ std::unique_ptr<Pass> createTopologicalSortPass();
 /// or maximum number of iterations reached.
 std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
-    unsigned maxIterations = 10);
+    int maxIterations = 10);
 
 //===----------------------------------------------------------------------===//
 // Registration
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 9a38aaf6985c0e..1b40a87c63f27e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -564,7 +564,7 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
       "Composite pass display name">,
     Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
       "Composite pass inner pipeline">,
-    Option<"maxIter", "max-iterations", "unsigned", /*default=*/"10",
+    Option<"maxIter", "max-iterations", "int", /*default=*/"10",
       "Maximum number of iterations if inner pipeline">,
   ];
 }
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 9ce386782bd55f..abfe093f2e5e04 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -51,6 +51,15 @@ struct CompositeFixedPointPass final
     return success();
   }
 
+  LogicalResult initialize(MLIRContext * /*context*/) override {
+    if (maxIter <= 0) {
+      llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n";
+      return failure();
+    }
+
+    return success();
+  }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     dynamicPM.getDependentDialects(registry);
   }
@@ -90,7 +99,7 @@ struct CompositeFixedPointPass final
 
 std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
-    unsigned maxIterations) {
+    int maxIterations) {
 
   return std::make_unique<CompositeFixedPointPass>(std::move(name),
                                                    populateFunc, maxIterations);

>From 7ade139d40e12e714ee7b645259cea6fd4c9e53c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:39:05 +0200
Subject: [PATCH 8/9] remove flush

---
 mlir/lib/Transforms/CompositePass.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index abfe093f2e5e04..a40e5300f0d193 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -36,7 +36,6 @@ struct CompositeFixedPointPass final
 
     llvm::raw_string_ostream os(pipelineStr);
     dynamicPM.printAsTextualPipeline(os);
-    os.flush();
   }
 
   LogicalResult initializeOptions(StringRef options) override {

>From 6aee80900264d45d2e1a8f20e9cd26124514fd04 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:43:26 +0200
Subject: [PATCH 9/9] use emitError

---
 mlir/lib/Transforms/CompositePass.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index a40e5300f0d193..e961451e4d550e 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -50,11 +50,10 @@ struct CompositeFixedPointPass final
     return success();
   }
 
-  LogicalResult initialize(MLIRContext * /*context*/) override {
-    if (maxIter <= 0) {
-      llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n";
-      return failure();
-    }
+  LogicalResult initialize(MLIRContext *context) override {
+    if (maxIter <= 0) // Require a positive iteration bound.
+      return emitError(UnknownLoc::get(context))
+             << "invalid maxIterations value: " << maxIter;

     return success();
   }



More information about the Mlir-commits mailing list