[Mlir-commits] [mlir] 5b66b6a - [mlir][pass] Add composite pass utility (#87166)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 2 03:30:48 PDT 2024
Author: Ivan Butygin
Date: 2024-04-02T13:30:45+03:00
New Revision: 5b66b6a32ad89562732ad6a81c84783486b6187a
URL: https://github.com/llvm/llvm-project/commit/5b66b6a32ad89562732ad6a81c84783486b6187a
DIFF: https://github.com/llvm/llvm-project/commit/5b66b6a32ad89562732ad6a81c84783486b6187a.diff
LOG: [mlir][pass] Add composite pass utility (#87166)
A composite pass runs a sequence of passes in a loop until a fixed point
or the maximum number of iterations is reached. The usual candidates
are canonicalize+CSE, as canonicalize can open more opportunities for CSE
and vice versa.
Added:
mlir/lib/Transforms/CompositePass.cpp
mlir/test/Transforms/composite-pass.mlir
mlir/test/lib/Transforms/TestCompositePass.cpp
Modified:
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..58bd61b2ae8b88 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLDCE
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
+#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -130,6 +131,12 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();
+/// Creates a composite pass that runs the passes added by `populateFunc` until
+/// a fixed point or the maximum number of iterations is reached.
+std::unique_ptr<Pass> createCompositeFixedPointPass(
+ std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+ int maxIterations = 10);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 51b2a27da639d6..1b40a87c63f27e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -552,4 +552,21 @@ def TopologicalSort : Pass<"topological-sort"> {
let constructor = "mlir::createTopologicalSortPass()";
}
+def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
+  let summary = "Composite fixed point pass";
+  let description = [{
+    Composite pass runs the provided set of passes until a fixed point or the
+    maximum number of iterations is reached.
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
+      "Composite pass display name">,
+    Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
+      "Composite pass inner pipeline">,
+    Option<"maxIter", "max-iterations", "int", /*default=*/"10",
+      "Maximum number of iterations of the inner pipeline">,
+  ];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
add_mlir_library(MLIRTransforms
Canonicalizer.cpp
+ CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..b388a28da6424f
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,121 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass runs a set of passes repeatedly until a fixed point is reached
+// or the maximum number of iterations is exhausted.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Runs an inner pass pipeline on the current operation repeatedly until an
+/// iteration leaves the operation's fingerprint unchanged (a fixed point) or
+/// the pipeline has run `maxIter` times.
+struct CompositeFixedPointPass final
+    : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
+  using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
+
+  /// Builds the pass programmatically: `name_` is the display name,
+  /// `populateFunc` adds the inner passes and `maxIterations` bounds the
+  /// number of pipeline runs. The inner pipeline is also printed into the
+  /// `pipeline` option string.
+  CompositeFixedPointPass(
+      std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
+      int maxIterations) {
+    name = std::move(name_);
+    maxIter = maxIterations;
+    populateFunc(dynamicPM);
+
+    llvm::raw_string_ostream os(pipelineStr);
+    dynamicPM.printAsTextualPipeline(os);
+  }
+
+  /// Parses the pass options, then builds the inner pipeline from the
+  /// `pipeline` option string.
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    if (failed(CompositeFixedPointPassBase::initializeOptions(options,
+                                                              errorHandler)))
+      return failure();
+
+    // Drop any previously populated pipeline so that re-initializing the
+    // options does not append duplicate passes.
+    dynamicPM.clear();
+    if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
+      return errorHandler("Failed to parse composite pass pipeline");
+
+    return success();
+  }
+
+  /// Rejects non-positive iteration limits before the pass runs.
+  LogicalResult initialize(MLIRContext *context) override {
+    if (maxIter <= 0)
+      return emitError(UnknownLoc::get(context))
+             << "Invalid maxIterations value: " << maxIter;
+
+    return success();
+  }
+
+  /// Reports the dialects required by the inner pipeline as dependencies of
+  /// this pass.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM.getDependentDialects(registry);
+  }
+
+  /// Runs the inner pipeline at most `maxIter` times, stopping as soon as an
+  /// iteration leaves the operation fingerprint unchanged. Signals pass
+  /// failure if the inner pipeline fails and warns if no fixed point is
+  /// reached within the iteration limit.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    OperationFingerPrint fp(op);
+
+    const int maxIterVal = maxIter;
+    for (int iter = 0; iter < maxIterVal; ++iter) {
+      if (failed(runPipeline(dynamicPM, op)))
+        return signalPassFailure();
+
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        return; // Fixed point reached.
+
+      fp = newFp;
+    }
+
+    op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                    "\" didn't converge in " + llvm::Twine(maxIterVal) +
+                    " iterations");
+  }
+
+protected:
+  llvm::StringRef getName() const override { return name; }
+
+private:
+  /// Inner pipeline executed on every iteration.
+  OpPassManager dynamicPM;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
+    std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
+    int maxIterations) {
+  return std::make_unique<CompositeFixedPointPass>(std::move(name),
+                                                   populateFunc, maxIterations);
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..829470c2c9aa64
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+ %0 = arith.constant 1.5 : f32
+ return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
+ TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..5c0d93cc0d64ec
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,43 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+/// Registers the `test-composite-fixed-point-pass` pipeline, which wraps
+/// canonicalize + CSE in a composite fixed-point pass named
+/// "TestCompositePass" using the default iteration limit.
+void registerTestCompositePass() {
+  registerPassPipeline(
+      "test-composite-fixed-point-pass", "Test composite pass",
+      [](OpPassManager &pm, StringRef optionsStr,
+         function_ref<LogicalResult(const Twine &)> errorHandler) {
+        // This pipeline takes no options; report any that are given instead
+        // of failing silently.
+        if (!optionsStr.empty())
+          return errorHandler("unexpected options: " + optionsStr);
+
+        pm.addPass(createCompositeFixedPointPass(
+            "TestCompositePass", [](OpPassManager &p) {
+              p.addPass(createCanonicalizerPass());
+              p.addPass(createCSEPass());
+            }));
+        return success();
+      },
+      [](function_ref<void(const detail::PassOptions &)>) {});
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
void registerVectorizerTestPass();
namespace test {
+void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();
+ mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
More information about the Mlir-commits
mailing list