[Mlir-commits] [mlir] [mlir][pass] Add composite pass utility (PR #87166)
Ivan Butygin
llvmlistbot at llvm.org
Sun Mar 31 17:44:52 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/87166
>From 5acbe822f8e21c68cae9d3df952be6be024f0711 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 30 Mar 2024 19:03:37 +0100
Subject: [PATCH 1/9] [mlir][pass] Add composite pass utility
The composite pass runs a sequence of passes in a loop until a fixed point or the maximum number of iterations is reached.
The usual candidates are canonicalize+CSE, as canonicalize can open up more opportunities for CSE and vice versa.
---
mlir/include/mlir/Transforms/Passes.h | 7 ++
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/CompositePass.cpp | 81 +++++++++++++++++++
mlir/test/Transforms/composite-pass.mlir | 25 ++++++
mlir/test/lib/Transforms/CMakeLists.txt | 1 +
.../test/lib/Transforms/TestCompositePass.cpp | 30 +++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 147 insertions(+)
create mode 100644 mlir/lib/Transforms/CompositePass.cpp
create mode 100644 mlir/test/Transforms/composite-pass.mlir
create mode 100644 mlir/test/lib/Transforms/TestCompositePass.cpp
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();
+/// Create composite pass, which runs selected set of passes until fixed point
+/// or maximum number of iterations reached.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations = 10);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
add_mlir_library(MLIRTransforms
Canonicalizer.cpp
+ CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass allows to run set of passes until fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+struct CompositePass final
+ : public PassWrapper<CompositePass, OperationPass<void>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+ CompositePass(std::string name_, std::string argument_,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations)
+ : name(std::move(name_)), argument(std::move(argument_)),
+ dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+ populateFunc(*dynamicPM);
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ dynamicPM->getDependentDialects(registry);
+ }
+
+ void runOnOperation() override {
+ auto op = getOperation();
+ OperationFingerPrint fp(op);
+
+ unsigned currentIter = 0;
+ while (true) {
+ if (failed(runPipeline(*dynamicPM, op)))
+ return signalPassFailure();
+
+ if (currentIter++ >= maxIters) {
+ op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+ "\"+ didn't converge in " + llvm::Twine(maxIters) +
+ " iterations");
+ break;
+ }
+
+ OperationFingerPrint newFp(op);
+ if (newFp == fp)
+ break;
+
+ fp = newFp;
+ }
+ }
+
+protected:
+ llvm::StringRef getName() const override { return name; }
+
+ llvm::StringRef getArgument() const override { return argument; }
+
+private:
+ std::string name;
+ std::string argument;
+ std::shared_ptr<OpPassManager> dynamicPM;
+ unsigned maxIters;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositePass(
+ std::string name, std::string argument,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations) {
+
+ return std::make_unique<CompositePass>(std::move(name), std::move(argument),
+ populateFunc, maxIterations);
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+ %0 = arith.constant 1.5 : f32
+ return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
+ TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestCompositePass() {
+ registerPass([]() -> std::unique_ptr<Pass> {
+ return createCompositePass("TestCompositePass", "test-composite-pass",
+ [](OpPassManager &p) {
+ p.addPass(createCanonicalizerPass());
+ p.addPass(createCSEPass());
+ });
+ });
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
void registerVectorizerTestPass();
namespace test {
+void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();
+ mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
>From da797648fc6663affb3b9078a7520e2112fb387e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:43:03 +0200
Subject: [PATCH 2/9] use pipeline string
---
mlir/include/mlir/Transforms/Passes.h | 3 +-
mlir/include/mlir/Transforms/Passes.td | 16 ++++++
mlir/lib/Transforms/CompositePass.cpp | 52 +++++++++++++------
mlir/test/Transforms/composite-pass.mlir | 1 +
.../test/lib/Transforms/TestCompositePass.cpp | 22 +++++---
5 files changed, 69 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 0cf45d8d40a93d..a9a303c769d21b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLDCE
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
+#define GEN_PASS_DECL_COMPOSITEPASS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -133,7 +134,7 @@ std::unique_ptr<Pass> createTopologicalSortPass();
/// Create composite pass, which runs selected set of passes until fixed point
/// or maximum number of iterations reached.
std::unique_ptr<Pass>
-createCompositePass(std::string name, std::string argument,
+createCompositePass(std::string name,
llvm::function_ref<void(OpPassManager &)> populateFunc,
unsigned maxIterations = 10);
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 51b2a27da639d6..d7e9cf97516e54 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,4 +552,20 @@ def TopologicalSort : Pass<"topological-sort"> {
let constructor = "mlir::createTopologicalSortPass()";
}
+def CompositePass : Pass<"composite-pass"> {
+ let summary = "TBD";
+ let description = [{
+ TBD
+ }];
+
+ let options = [
+ Option<"name", "name", "std::string", /*default=*/"\"CompositePass\"",
+ "Composite pass display name">,
+ Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
+ "Composite pass inner pipeline">,
+ Option<"maxIter", "max-iterations", "unsigned", /*default=*/"10",
+ "Maximum number of iterations if inner pipeline">,
+ ];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 3b9700f1f05176..b7f2e73c847b20 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -15,19 +15,42 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+namespace mlir {
+#define GEN_PASS_DEF_COMPOSITEPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
using namespace mlir;
namespace {
-struct CompositePass final
- : public PassWrapper<CompositePass, OperationPass<void>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+struct CompositePass final : public impl::CompositePassBase<CompositePass> {
+ using CompositePassBase::CompositePassBase;
- CompositePass(std::string name_, std::string argument_,
+ CompositePass(std::string name_,
llvm::function_ref<void(OpPassManager &)> populateFunc,
unsigned maxIterations)
- : name(std::move(name_)), argument(std::move(argument_)),
- dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+ : dynamicPM(std::make_shared<OpPassManager>()) {
+ name = std::move(name_);
+ maxIter = maxIterations;
populateFunc(*dynamicPM);
+ std::string pipeline;
+ llvm::raw_string_ostream os(pipeline);
+ dynamicPM->printAsTextualPipeline(os);
+ os.flush();
+ pipelineStr = pipeline;
+ }
+
+ LogicalResult initializeOptions(StringRef options) override {
+ if (failed(CompositePassBase::initializeOptions(options)))
+ return failure();
+
+ dynamicPM = std::make_shared<OpPassManager>();
+ if (failed(parsePassPipeline(pipelineStr, *dynamicPM))) {
+ llvm::errs() << "Failed to parse composite pass pipeline\n";
+ return failure();
+ }
+
+ return success();
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -39,13 +62,14 @@ struct CompositePass final
OperationFingerPrint fp(op);
unsigned currentIter = 0;
+ unsigned maxIterVal = maxIter;
while (true) {
if (failed(runPipeline(*dynamicPM, op)))
return signalPassFailure();
- if (currentIter++ >= maxIters) {
+ if (currentIter++ >= maxIterVal) {
op->emitWarning("Composite pass \"" + llvm::Twine(name) +
- "\"+ didn't converge in " + llvm::Twine(maxIters) +
+ "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
" iterations");
break;
}
@@ -61,21 +85,15 @@ struct CompositePass final
protected:
llvm::StringRef getName() const override { return name; }
- llvm::StringRef getArgument() const override { return argument; }
-
private:
- std::string name;
- std::string argument;
std::shared_ptr<OpPassManager> dynamicPM;
- unsigned maxIters;
};
} // namespace
std::unique_ptr<Pass> mlir::createCompositePass(
- std::string name, std::string argument,
- llvm::function_ref<void(OpPassManager &)> populateFunc,
+ std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
unsigned maxIterations) {
- return std::make_unique<CompositePass>(std::move(name), std::move(argument),
- populateFunc, maxIterations);
+ return std::make_unique<CompositePass>(std::move(name), populateFunc,
+ maxIterations);
}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 4bf83d3a79754a..034d266424b9ea 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `Canonicalizer`
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
index 64299685b3286e..47d328c5c7b7af 100644
--- a/mlir/test/lib/Transforms/TestCompositePass.cpp
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -18,13 +18,21 @@
namespace mlir {
namespace test {
void registerTestCompositePass() {
- registerPass([]() -> std::unique_ptr<Pass> {
- return createCompositePass("TestCompositePass", "test-composite-pass",
- [](OpPassManager &p) {
- p.addPass(createCanonicalizerPass());
- p.addPass(createCSEPass());
- });
- });
+ registerPassPipeline(
+ "test-composite-pass", "Test composite pass",
+ [](OpPassManager &pm, StringRef optionsStr,
+ function_ref<LogicalResult(const Twine &)> errorHandler) {
+ if (!optionsStr.empty())
+ return failure();
+
+ pm.addPass(
+ createCompositePass("TestCompositePass", [](OpPassManager &p) {
+ p.addPass(createCanonicalizerPass());
+ p.addPass(createCSEPass());
+ }));
+ return success();
+ },
+ [](function_ref<void(const detail::PassOptions &)>) {});
}
} // namespace test
} // namespace mlir
>From 011e2882b6d8b5976208eda85f0f6d2868d65e3f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:50:25 +0200
Subject: [PATCH 3/9] renamings
---
mlir/include/mlir/Transforms/Passes.h | 9 ++++----
mlir/include/mlir/Transforms/Passes.td | 2 +-
mlir/lib/Transforms/CompositePass.cpp | 21 ++++++++++---------
mlir/test/Transforms/composite-pass.mlir | 4 ++--
.../test/lib/Transforms/TestCompositePass.cpp | 6 +++---
5 files changed, 21 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index a9a303c769d21b..b89257d4accef5 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,7 +43,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLDCE
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
-#define GEN_PASS_DECL_COMPOSITEPASS
+#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -133,10 +133,9 @@ std::unique_ptr<Pass> createTopologicalSortPass();
/// Create composite pass, which runs selected set of passes until fixed point
/// or maximum number of iterations reached.
-std::unique_ptr<Pass>
-createCompositePass(std::string name,
- llvm::function_ref<void(OpPassManager &)> populateFunc,
- unsigned maxIterations = 10);
+std::unique_ptr<Pass> createCompositeFixedPointPass(
+ std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations = 10);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index d7e9cf97516e54..4fc5edd32ab52a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,7 +552,7 @@ def TopologicalSort : Pass<"topological-sort"> {
let constructor = "mlir::createTopologicalSortPass()";
}
-def CompositePass : Pass<"composite-pass"> {
+def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
let summary = "TBD";
let description = [{
TBD
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index b7f2e73c847b20..5455a768383614 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -16,19 +16,20 @@
#include "mlir/Pass/PassManager.h"
namespace mlir {
-#define GEN_PASS_DEF_COMPOSITEPASS
+#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
-struct CompositePass final : public impl::CompositePassBase<CompositePass> {
- using CompositePassBase::CompositePassBase;
+struct CompositeFixedPointPass final
+ : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
+ using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
- CompositePass(std::string name_,
- llvm::function_ref<void(OpPassManager &)> populateFunc,
- unsigned maxIterations)
+ CompositeFixedPointPass(
+ std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations)
: dynamicPM(std::make_shared<OpPassManager>()) {
name = std::move(name_);
maxIter = maxIterations;
@@ -41,7 +42,7 @@ struct CompositePass final : public impl::CompositePassBase<CompositePass> {
}
LogicalResult initializeOptions(StringRef options) override {
- if (failed(CompositePassBase::initializeOptions(options)))
+ if (failed(CompositeFixedPointPassBase::initializeOptions(options)))
return failure();
dynamicPM = std::make_shared<OpPassManager>();
@@ -90,10 +91,10 @@ struct CompositePass final : public impl::CompositePassBase<CompositePass> {
};
} // namespace
-std::unique_ptr<Pass> mlir::createCompositePass(
+std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
unsigned maxIterations) {
- return std::make_unique<CompositePass>(std::move(name), populateFunc,
- maxIterations);
+ return std::make_unique<CompositeFixedPointPass>(std::move(name),
+ populateFunc, maxIterations);
}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
index 034d266424b9ea..829470c2c9aa64 100644
--- a/mlir/test/Transforms/composite-pass.mlir
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
-// RUN: mlir-opt %s --log-actions-to=- --composite-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `Canonicalizer`
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
index 47d328c5c7b7af..5c0d93cc0d64ec 100644
--- a/mlir/test/lib/Transforms/TestCompositePass.cpp
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -19,14 +19,14 @@ namespace mlir {
namespace test {
void registerTestCompositePass() {
registerPassPipeline(
- "test-composite-pass", "Test composite pass",
+ "test-composite-fixed-point-pass", "Test composite pass",
[](OpPassManager &pm, StringRef optionsStr,
function_ref<LogicalResult(const Twine &)> errorHandler) {
if (!optionsStr.empty())
return failure();
- pm.addPass(
- createCompositePass("TestCompositePass", [](OpPassManager &p) {
+ pm.addPass(createCompositeFixedPointPass(
+ "TestCompositePass", [](OpPassManager &p) {
p.addPass(createCanonicalizerPass());
p.addPass(createCSEPass());
}));
>From 8df25ea53b838e82a981d9354e070ff83605e8d5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 01:57:25 +0200
Subject: [PATCH 4/9] Docs
---
mlir/include/mlir/Transforms/Passes.h | 2 +-
mlir/include/mlir/Transforms/Passes.td | 7 ++++---
2 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index b89257d4accef5..1c6e158fa4592f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -131,7 +131,7 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();
-/// Create composite pass, which runs selected set of passes until fixed point
+/// Create composite pass, which runs provided set of passes until fixed point
/// or maximum number of iterations reached.
std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 4fc5edd32ab52a..9a38aaf6985c0e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -553,13 +553,14 @@ def TopologicalSort : Pass<"topological-sort"> {
}
def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
- let summary = "TBD";
+ let summary = "Composite fixed point pass";
let description = [{
- TBD
+ Composite pass runs provided set of passes until fixed point or maximum
+ number of iterations reached.
}];
let options = [
- Option<"name", "name", "std::string", /*default=*/"\"CompositePass\"",
+ Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
"Composite pass display name">,
Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
"Composite pass inner pipeline">,
>From 332ec5ac8688c4eacc784a86787c0a7253b8f028 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:21:45 +0200
Subject: [PATCH 5/9] copy PM
---
mlir/lib/Transforms/CompositePass.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 5455a768383614..5c01482684bf75 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -29,14 +29,13 @@ struct CompositeFixedPointPass final
CompositeFixedPointPass(
std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
- unsigned maxIterations)
- : dynamicPM(std::make_shared<OpPassManager>()) {
+ unsigned maxIterations) {
name = std::move(name_);
maxIter = maxIterations;
- populateFunc(*dynamicPM);
+ populateFunc(dynamicPM);
std::string pipeline;
llvm::raw_string_ostream os(pipeline);
- dynamicPM->printAsTextualPipeline(os);
+ dynamicPM.printAsTextualPipeline(os);
os.flush();
pipelineStr = pipeline;
}
@@ -45,8 +44,7 @@ struct CompositeFixedPointPass final
if (failed(CompositeFixedPointPassBase::initializeOptions(options)))
return failure();
- dynamicPM = std::make_shared<OpPassManager>();
- if (failed(parsePassPipeline(pipelineStr, *dynamicPM))) {
+ if (failed(parsePassPipeline(pipelineStr, dynamicPM))) {
llvm::errs() << "Failed to parse composite pass pipeline\n";
return failure();
}
@@ -55,7 +53,7 @@ struct CompositeFixedPointPass final
}
void getDependentDialects(DialectRegistry &registry) const override {
- dynamicPM->getDependentDialects(registry);
+ dynamicPM.getDependentDialects(registry);
}
void runOnOperation() override {
@@ -65,7 +63,7 @@ struct CompositeFixedPointPass final
unsigned currentIter = 0;
unsigned maxIterVal = maxIter;
while (true) {
- if (failed(runPipeline(*dynamicPM, op)))
+ if (failed(runPipeline(dynamicPM, op)))
return signalPassFailure();
if (currentIter++ >= maxIterVal) {
@@ -87,7 +85,7 @@ struct CompositeFixedPointPass final
llvm::StringRef getName() const override { return name; }
private:
- std::shared_ptr<OpPassManager> dynamicPM;
+ OpPassManager dynamicPM;
};
} // namespace
>From 93a57e22e91b0d8fdc0bb31f5f18787c9c8070be Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:26:42 +0200
Subject: [PATCH 6/9] remove temp string
---
mlir/lib/Transforms/CompositePass.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 5c01482684bf75..9ce386782bd55f 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -33,11 +33,10 @@ struct CompositeFixedPointPass final
name = std::move(name_);
maxIter = maxIterations;
populateFunc(dynamicPM);
- std::string pipeline;
- llvm::raw_string_ostream os(pipeline);
+
+ llvm::raw_string_ostream os(pipelineStr);
dynamicPM.printAsTextualPipeline(os);
os.flush();
- pipelineStr = pipeline;
}
LogicalResult initializeOptions(StringRef options) override {
>From 53526aff12c650f6d7bc5e4e22b061d85fee6e06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:31:22 +0200
Subject: [PATCH 7/9] use int
---
mlir/include/mlir/Transforms/Passes.h | 2 +-
mlir/include/mlir/Transforms/Passes.td | 2 +-
mlir/lib/Transforms/CompositePass.cpp | 11 ++++++++++-
3 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 1c6e158fa4592f..58bd61b2ae8b88 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -135,7 +135,7 @@ std::unique_ptr<Pass> createTopologicalSortPass();
/// or maximum number of iterations reached.
std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
- unsigned maxIterations = 10);
+ int maxIterations = 10);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 9a38aaf6985c0e..1b40a87c63f27e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -564,7 +564,7 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
"Composite pass display name">,
Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
"Composite pass inner pipeline">,
- Option<"maxIter", "max-iterations", "unsigned", /*default=*/"10",
+ Option<"maxIter", "max-iterations", "int", /*default=*/"10",
"Maximum number of iterations if inner pipeline">,
];
}
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index 9ce386782bd55f..abfe093f2e5e04 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -51,6 +51,15 @@ struct CompositeFixedPointPass final
return success();
}
+ LogicalResult initialize(MLIRContext * /*context*/) override {
+ if (maxIter <= 0) {
+ llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n";
+ return failure();
+ }
+
+ return success();
+ }
+
void getDependentDialects(DialectRegistry &registry) const override {
dynamicPM.getDependentDialects(registry);
}
@@ -90,7 +99,7 @@ struct CompositeFixedPointPass final
std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
- unsigned maxIterations) {
+ int maxIterations) {
return std::make_unique<CompositeFixedPointPass>(std::move(name),
populateFunc, maxIterations);
>From 7ade139d40e12e714ee7b645259cea6fd4c9e53c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:39:05 +0200
Subject: [PATCH 8/9] remove flush
---
mlir/lib/Transforms/CompositePass.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index abfe093f2e5e04..a40e5300f0d193 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -36,7 +36,6 @@ struct CompositeFixedPointPass final
llvm::raw_string_ostream os(pipelineStr);
dynamicPM.printAsTextualPipeline(os);
- os.flush();
}
LogicalResult initializeOptions(StringRef options) override {
>From 6aee80900264d45d2e1a8f20e9cd26124514fd04 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 1 Apr 2024 02:43:26 +0200
Subject: [PATCH 9/9] use emitError
---
mlir/lib/Transforms/CompositePass.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
index a40e5300f0d193..e961451e4d550e 100644
--- a/mlir/lib/Transforms/CompositePass.cpp
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -50,11 +50,10 @@ struct CompositeFixedPointPass final
return success();
}
- LogicalResult initialize(MLIRContext * /*context*/) override {
- if (maxIter <= 0) {
- llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n";
- return failure();
- }
+ LogicalResult initialize(MLIRContext *context) override {
+ if (maxIter <= 0)
+ return emitError(UnknownLoc::get(context))
+ << "Invalid maxIterations value: " << maxIter << "\n";
return success();
}
More information about the Mlir-commits
mailing list