[Mlir-commits] [mlir] [mlir][pass] Add composite pass utility (PR #87166)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 30 11:15:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Composite pass allows running a sequence of passes in a loop until a fixed point or the maximum number of iterations is reached. The usual candidates are canonicalize+CSE, as canonicalization can open more opportunities for CSE and vice versa.

---
Full diff: https://github.com/llvm/llvm-project/pull/87166.diff


7 Files Affected:

- (modified) mlir/include/mlir/Transforms/Passes.h (+7) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Transforms/CompositePass.cpp (+81) 
- (added) mlir/test/Transforms/composite-pass.mlir (+25) 
- (modified) mlir/test/lib/Transforms/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Transforms/TestCompositePass.cpp (+30) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 /// their producers.
 std::unique_ptr<Pass> createTopologicalSortPass();
 
+/// Creates a pass named `name` (pipeline argument `argument`) that reruns the
+/// passes from `populateFunc` until a fixed point or the `maxIterations` cap.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+                    llvm::function_ref<void(OpPassManager &)> populateFunc,
+                    unsigned maxIterations = 10);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass reruns a set of passes until a fixed point or an iteration cap.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass that reruns a nested pipeline on its operation until one run leaves
+/// the IR fingerprint unchanged, or until `maxIters` runs have been made.
+struct CompositePass final
+    : public PassWrapper<CompositePass, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+  /// `populateFunc` is called once, here, to fill the nested pipeline.
+  CompositePass(std::string name_, std::string argument_,
+                llvm::function_ref<void(OpPassManager &)> populateFunc,
+                unsigned maxIterations)
+      : name(std::move(name_)), argument(std::move(argument_)),
+        maxIters(maxIterations) {
+    populateFunc(dynamicPM);
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM.getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    OperationFingerPrint fp(op);
+    // Run at most `maxIters` times; a run that leaves the fingerprint as it
+    // was means a fixed point was reached, so no warning is emitted.
+    for (unsigned iter = 0; iter < maxIters; ++iter) {
+      if (failed(runPipeline(dynamicPM, op)))
+        return signalPassFailure();
+
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        return;
+      fp = newFp;
+    }
+    op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                    "\" didn't converge in " + llvm::Twine(maxIters) +
+                    " iterations");
+  }
+
+protected:
+  llvm::StringRef getName() const override { return name; }
+  llvm::StringRef getArgument() const override { return argument; }
+
+private:
+  std::string name;
+  std::string argument;
+  // Held by value: PassWrapper::clonePass() copies this pass, and copying an
+  // OpPassManager clones its passes, so thread-local clones never initialize
+  // or run the same nested pass instances concurrently.
+  OpPassManager dynamicPM;
+  unsigned maxIters;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositePass(
+    std::string name, std::string argument,
+    llvm::function_ref<void(OpPassManager &)> populateFunc,
+    unsigned maxIterations) {
+  // `populateFunc` is invoked by the constructor only; it is not retained.
+  return std::unique_ptr<Pass>(std::make_unique<CompositePass>(
+      std::move(name), std::move(argument), populateFunc, maxIterations));
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+// Nothing to simplify: the pipeline runs once and the pass stops.
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+  // Canonicalize erases this unused constant, so a second iteration runs.
+  %0 = arith.constant 1.5 : f32
+  return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
+  TestCompositePass.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir::test {
+/// Registers `--test-composite-pass` (canonicalize + CSE as a composite pass).
+void registerTestCompositePass() {
+  registerPass([]() -> std::unique_ptr<Pass> {
+    auto populate = [](OpPassManager &pm) {
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+    };
+    return createCompositePass("TestCompositePass", "test-composite-pass",
+                               populate);
+  });
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
+void registerTestCompositePass();
 void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerTestCompositePass();
   mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();

``````````

</details>


https://github.com/llvm/llvm-project/pull/87166


More information about the Mlir-commits mailing list