[Mlir-commits] [mlir] 5b66b6a - [mlir][pass] Add composite pass utility (#87166)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 2 03:30:48 PDT 2024


Author: Ivan Butygin
Date: 2024-04-02T13:30:45+03:00
New Revision: 5b66b6a32ad89562732ad6a81c84783486b6187a

URL: https://github.com/llvm/llvm-project/commit/5b66b6a32ad89562732ad6a81c84783486b6187a
DIFF: https://github.com/llvm/llvm-project/commit/5b66b6a32ad89562732ad6a81c84783486b6187a.diff

LOG: [mlir][pass] Add composite pass utility (#87166)

A composite pass runs a sequence of passes in a loop until a fixed point
or the maximum number of iterations is reached. The usual candidates
are canonicalize+CSE, as canonicalization can expose more opportunities
for CSE and vice versa.

Added: 
    mlir/lib/Transforms/CompositePass.cpp
    mlir/test/Transforms/composite-pass.mlir
    mlir/test/lib/Transforms/TestCompositePass.cpp

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..58bd61b2ae8b88 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLDCE
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
+#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -130,6 +131,12 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 /// their producers.
 std::unique_ptr<Pass> createTopologicalSortPass();
 
+/// Creates a composite pass that runs the provided set of passes until a fixed
+/// point or the maximum number of iterations is reached.
+std::unique_ptr<Pass> createCompositeFixedPointPass(
+    std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+    int maxIterations = 10);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 51b2a27da639d6..1b40a87c63f27e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,4 +552,21 @@ def TopologicalSort : Pass<"topological-sort"> {
   let constructor = "mlir::createTopologicalSortPass()";
 }
 
+def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
+  let summary = "Composite fixed point pass";
+  let description = [{
+    Composite pass runs the provided set of passes until a fixed point or the
+    maximum number of iterations is reached.
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
+      "Composite pass display name">,
+    Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
+      "Composite pass inner pipeline">,
+    Option<"maxIter", "max-iterations", "int", /*default=*/"10",
+      "Maximum number of iterations of the inner pipeline (must be > 0)">,
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp

diff  --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..b388a28da6424f
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,105 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass runs a set of passes repeatedly until a fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Runs a nested pass pipeline repeatedly until the IR fingerprint stops
+/// changing or `maxIter` iterations have executed, whichever comes first.
+struct CompositeFixedPointPass final
+    : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
+  using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
+
+  CompositeFixedPointPass(
+      std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
+      int maxIterations) {
+    name = std::move(name_);
+    maxIter = maxIterations;
+    populateFunc(dynamicPM);
+    // Record the populated pipeline in the textual `pipeline` option as well.
+    llvm::raw_string_ostream os(pipelineStr);
+    dynamicPM.printAsTextualPipeline(os);
+  }
+
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    if (failed(CompositeFixedPointPassBase::initializeOptions(options,
+                                                              errorHandler)))
+      return failure();
+    // Build the inner pass manager from the `pipeline=` option string.
+    if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
+      return errorHandler("Failed to parse composite pass pipeline");
+
+    return success();
+  }
+
+  LogicalResult initialize(MLIRContext *context) override {
+    if (maxIter <= 0)
+      return emitError(UnknownLoc::get(context))
+             << "Invalid maxIterations value: " << maxIter;
+
+    return success();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM.getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    OperationFingerPrint fp(op);
+    // Run the inner pipeline at most `maxIter` times, stopping as soon as an
+    // iteration leaves the IR fingerprint unchanged (fixed point reached).
+    int maxIterVal = maxIter;
+    for (int currentIter = 0; currentIter < maxIterVal; ++currentIter) {
+      if (failed(runPipeline(dynamicPM, op)))
+        return signalPassFailure();
+
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        return;
+      fp = newFp;
+    }
+
+    // Every permitted iteration still changed the IR.
+    op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                    "\" didn't converge in " + llvm::Twine(maxIterVal) +
+                    " iterations");
+  }
+
+protected:
+  llvm::StringRef getName() const override { return name; }
+
+private:
+  /// Inner pipeline executed once per iteration.
+  OpPassManager dynamicPM;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
+    std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+    int maxIterations) {
+  std::unique_ptr<Pass> pass = std::make_unique<CompositeFixedPointPass>(
+      std::move(name), populateFunc, maxIterations);
+  return pass; // The caller takes ownership of the new pass.
+}

diff  --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..829470c2c9aa64
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() { // Already minimal: the first iteration changes nothing.
+  return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+  // Dead constant: canonicalize erases it, changing the IR and forcing a rerun.
+  %0 = arith.constant 1.5 : f32
+  return
+}

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
+  TestCompositePass.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp

diff  --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..5c0d93cc0d64ec
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,38 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestCompositePass() {
+  registerPassPipeline(
+      "test-composite-fixed-point-pass", "Test composite pass",
+      [](OpPassManager &pm, StringRef optionsStr,
+         function_ref<LogicalResult(const Twine &)> errorHandler) {
+        if (!optionsStr.empty())
+          return errorHandler("unexpected pass options");
+        // Wrap canonicalize+CSE in a fixed-point loop (default iteration cap).
+        pm.addPass(createCompositeFixedPointPass(
+            "TestCompositePass", [](OpPassManager &p) {
+              p.addPass(createCanonicalizerPass());
+              p.addPass(createCSEPass());
+            }));
+        return success();
+      },
+      [](function_ref<void(const detail::PassOptions &)>) {});
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
+void registerTestCompositePass();
 void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerTestCompositePass();
   mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();


        


More information about the Mlir-commits mailing list