[Mlir-commits] [mlir] [mlir][pass] Add composite pass utility (PR #87166)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 30 11:15:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
A composite pass runs a sequence of passes in a loop until a fixed point or a maximum number of iterations is reached. The usual candidates are canonicalize+CSE, as canonicalize can open up more opportunities for CSE and vice versa.
---
Full diff: https://github.com/llvm/llvm-project/pull/87166.diff
7 Files Affected:
- (modified) mlir/include/mlir/Transforms/Passes.h (+7)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Transforms/CompositePass.cpp (+81)
- (added) mlir/test/Transforms/composite-pass.mlir (+25)
- (modified) mlir/test/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/test/lib/Transforms/TestCompositePass.cpp (+30)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();
+/// Creates a composite pass: a pass that repeatedly runs a nested pipeline on
+/// the operation until the operation stops changing (a fixed point) or the
+/// iteration limit is reached, in which case a warning is emitted.
+///
+/// \param name          Name reported by the resulting pass.
+/// \param argument      Command-line argument reported by the resulting pass.
+/// \param populateFunc  Callback that adds the passes to iterate to the nested
+///                      pipeline. It is invoked while the pass is being
+///                      constructed and is not retained afterwards.
+/// \param maxIterations Limit on the number of iterations of the nested
+///                      pipeline.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations = 10);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
add_mlir_library(MLIRTransforms
Canonicalizer.cpp
+ CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass runs a set of passes until a fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass that reruns a nested pipeline on the current operation until the
+/// operation's fingerprint stops changing or the pipeline has been run
+/// `maxIters` times, whichever comes first.
+struct CompositePass final
+    : public PassWrapper<CompositePass, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+  /// Stores the pass name, its command-line argument and the iteration limit,
+  /// and builds the nested pipeline by calling `populateFunc` once.
+  /// `populateFunc` is not retained after the constructor returns.
+  CompositePass(std::string name_, std::string argument_,
+                llvm::function_ref<void(OpPassManager &)> populateFunc,
+                unsigned maxIterations)
+      : name(std::move(name_)), argument(std::move(argument_)),
+        dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+    populateFunc(*dynamicPM);
+  }
+
+  /// Forwards the dependent-dialect query to the nested pipeline.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM->getDependentDialects(registry);
+  }
+
+  /// Runs the nested pipeline until the operation fingerprint is unchanged by
+  /// a run. Signals pass failure if a pipeline run fails, and emits a warning
+  /// on the operation if the IR is still changing after `maxIters` runs.
+  void runOnOperation() override {
+    auto op = getOperation();
+    OperationFingerPrint fp(op);
+
+    unsigned currentIter = 0;
+    while (true) {
+      if (failed(runPipeline(*dynamicPM, op)))
+        return signalPassFailure();
+
+      // Test for the fixed point before the iteration limit, so that a run
+      // which converges exactly at the limit is not reported as a failure.
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        break;
+
+      // The pipeline has now been run `currentIter` times; stop once the
+      // limit is reached rather than running one extra time.
+      if (++currentIter >= maxIters) {
+        op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                        "\" didn't converge in " + llvm::Twine(maxIters) +
+                        " iterations");
+        break;
+      }
+
+      fp = newFp;
+    }
+  }
+
+protected:
+  /// Returns the name supplied at construction.
+  llvm::StringRef getName() const override { return name; }
+
+  /// Returns the command-line argument supplied at construction.
+  llvm::StringRef getArgument() const override { return argument; }
+
+private:
+  // Name reported by getName().
+  std::string name;
+  // Command-line argument reported by getArgument().
+  std::string argument;
+  // Nested pipeline executed on every iteration.
+  std::shared_ptr<OpPassManager> dynamicPM;
+  // Maximum number of nested pipeline runs before giving up with a warning.
+  unsigned maxIters;
+};
+} // namespace
+
+/// Factory for the file-local CompositePass. `populateFunc` is consumed while
+/// the pass is being constructed, so it only needs to outlive this call.
+std::unique_ptr<Pass> mlir::createCompositePass(
+    std::string name, std::string argument,
+    llvm::function_ref<void(OpPassManager &)> populateFunc,
+    unsigned maxIterations) {
+  return std::unique_ptr<Pass>(std::make_unique<CompositePass>(
+      std::move(name), std::move(argument), populateFunc, maxIterations));
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+
+// Nothing to simplify here: a single run of the nested pipeline leaves the
+// function unchanged, so the composite pass is expected to stop after it.
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+ return
+}
+
+// -----
+
+// The first run of the nested pipeline changes the function, so a second run
+// is expected before the composite pass observes a fixed point and stops.
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+ %0 = arith.constant 1.5 : f32
+ return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
+ TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+/// Registers "test-composite-pass", a composite pass whose nested pipeline is
+/// the canonicalizer followed by CSE, using the default iteration limit.
+void registerTestCompositePass() {
+  registerPass([]() -> std::unique_ptr<Pass> {
+    // Pipeline builder handed to the composite pass factory.
+    auto populate = [](OpPassManager &pm) {
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+    };
+    return createCompositePass("TestCompositePass", "test-composite-pass",
+                               populate);
+  });
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
void registerVectorizerTestPass();
namespace test {
+void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();
+ mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
``````````
</details>
https://github.com/llvm/llvm-project/pull/87166
More information about the Mlir-commits
mailing list