[Mlir-commits] [mlir] [mlir] Add `SelectPass` (PR #130409)
Ivan Butygin
llvmlistbot at llvm.org
Sat Mar 8 12:14:57 PST 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/130409
>From 0303b3c511c4e85e012412142702c4466d3ab01d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 12:54:27 +0100
Subject: [PATCH 1/6] [mlir] Add `SelectPass`
`SelectPass` allows dynamically selecting the pass pipeline based on the value of an attribute attached to some top-level op.
---
mlir/include/mlir/Transforms/Passes.h | 8 ++
mlir/include/mlir/Transforms/Passes.td | 19 ++++
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/SelectPass.cpp | 132 +++++++++++++++++++++++++
mlir/test/Transforms/select-pass.mlir | 24 +++++
5 files changed, 184 insertions(+)
create mode 100644 mlir/lib/Transforms/SelectPass.cpp
create mode 100644 mlir/test/Transforms/select-pass.mlir
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..e521705371b0b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_SELECTPASS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -139,6 +140,13 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
+/// Creates select pass which allows to run multiple different set of passes
+/// based on attribute value on some top-level op.
+std::unique_ptr<Pass> createSelectPass(
+ std::string name, std::string selectCondName,
+ ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+ populateFuncs);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..846dffb89e8f7 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def SelectPass : Pass<"select-pass"> {
+ let summary = "Select pass";
+ let description = [{
+ Select pass allows to run multiple different set of passes based on
+ attribute value on some top-level op.
+ }];
+
+ let options = [
+ Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"",
+ "Select pass display name">,
+ Option<"selectCondName", "select-cond-name", "std::string", "\"select\"",
+ "Attribute name used for condition">,
+ ListOption<"selectValues", "select-values", "std::string",
+ "Values used to check select condition">,
+ ListOption<"selectPipelines", "select-pipelines", "std::string",
+ "Pipelines, assotiated with corresponding select values">,
+ ];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088bccf299..b94e390b627d6 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
+ SelectPass.cpp
SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
new file mode 100644
index 0000000000000..750a6617fe9b9
--- /dev/null
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -0,0 +1,132 @@
+//===- SelectPass.cpp - Select pass code ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// SelectPass allows to run multiple different set of passes based on attribute
+// value on some top-level op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SELECTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct SelectPass final : public impl::SelectPassBase<SelectPass> {
+ using SelectPassBase::SelectPassBase;
+
+ SelectPass(
+ std::string name_, std::string selectCondName_,
+ ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+ populateFuncs) {
+ name = std::move(name_);
+ selectCondName = std::move(selectCondName_);
+
+ SmallVector<std::string> selectVals;
+ SmallVector<std::string> selectPpls;
+ selectVals.reserve(populateFuncs.size());
+ selectPpls.reserve(populateFuncs.size());
+ selectPassManagers.reserve(populateFuncs.size());
+ for (auto &&[name, populate] : populateFuncs) {
+ selectVals.emplace_back(name);
+
+ auto &pm = selectPassManagers.emplace_back();
+ populate(pm);
+
+ llvm::raw_string_ostream os(selectPpls.emplace_back());
+ pm.printAsTextualPipeline(os);
+ }
+
+ selectValues = selectVals;
+ selectPipelines = selectPpls;
+ }
+
+ LogicalResult initializeOptions(
+ StringRef options,
+ function_ref<LogicalResult(const Twine &)> errorHandler) override {
+ if (failed(SelectPassBase::initializeOptions(options, errorHandler)))
+ return failure();
+
+ if (selectCondName.empty())
+ return errorHandler("Invalid select-cond-name");
+
+ if (selectValues.size() != selectPipelines.size())
+ return errorHandler("Values and pipelines size mismatch");
+
+ selectPassManagers.resize(selectPipelines.size());
+
+ for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
+ if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
+ return errorHandler("Failed to parse pipeline");
+ }
+
+ return success();
+ }
+
+ LogicalResult initialize(MLIRContext *context) override {
+ condAttrName = StringAttr::get(context, selectCondName);
+
+ selectAttrs.reserve(selectAttrs.size());
+ for (StringRef value : selectValues)
+ selectAttrs.emplace_back(StringAttr::get(context, value));
+
+ return success();
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ for (const OpPassManager &pipeline : selectPassManagers)
+ pipeline.getDependentDialects(registry);
+ }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ Attribute condAttrValue = op->getAttr(condAttrName);
+ if (!condAttrValue) {
+ op->emitError("Condition attribute not present: ") << condAttrName;
+ return signalPassFailure();
+ }
+
+ for (auto &&[value, pm] :
+ llvm::zip_equal(selectAttrs, selectPassManagers)) {
+ if (value != condAttrValue)
+ continue;
+
+ if (failed(runPipeline(pm, op)))
+ return signalPassFailure();
+
+ return;
+ }
+
+ op->emitError("Unhandled condition value: ") << condAttrValue;
+ return signalPassFailure();
+ }
+
+protected:
+ StringRef getName() const override { return name; }
+
+private:
+ StringAttr condAttrName;
+ SmallVector<Attribute> selectAttrs;
+ SmallVector<OpPassManager> selectPassManagers;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createSelectPass(
+ std::string name, std::string selectCondName,
+ ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+ populateFuncs) {
+ return std::make_unique<SelectPass>(std::move(name),
+ std::move(selectCondName), populateFuncs);
+}
diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
new file mode 100644
index 0000000000000..fb93486b94ed7
--- /dev/null
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \
+// RUN: name=TestSelectPass \
+// RUN: select-cond-name=test.attr \
+// RUN: select-values=rocdl,nvvm \
+// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
+// RUN: }))' -split-input-file | FileCheck %s
+
+gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
+// CHECK-LABEL: func @foo()
+// CHECK: rocdl.workitem.id.x
+ func.func @foo() -> index {
+ %0 = gpu.thread_id x
+ return %0 : index
+ }
+}
+
+gpu.module @nvvm_module attributes {test.attr = "nvvm"} {
+// CHECK-LABEL: func @bar()
+// CHECK: nvvm.read.ptx.sreg.tid.x
+ func.func @bar() -> index {
+ %0 = gpu.thread_id x
+ return %0 : index
+ }
+}
>From 6b62cad4fb4929190a9be0df5eba8fae3ae66f43 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:46:32 +0100
Subject: [PATCH 2/6] clarify doc
---
mlir/include/mlir/Transforms/Passes.h | 4 ++--
mlir/include/mlir/Transforms/Passes.td | 4 ++--
mlir/lib/Transforms/SelectPass.cpp | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index e521705371b0b..c808e1bedc8de 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -140,8 +140,8 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
-/// Creates select pass which allows to run multiple different set of passes
-/// based on attribute value on some top-level op.
+/// Creates select pass, which dynamically selects pass pipeline to run based on
+/// root op attribute.
std::unique_ptr<Pass> createSelectPass(
std::string name, std::string selectCondName,
ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 846dffb89e8f7..b707e1edf3d6f 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -589,8 +589,8 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
def SelectPass : Pass<"select-pass"> {
let summary = "Select pass";
let description = [{
- Select pass allows to run multiple different set of passes based on
- attribute value on some top-level op.
+ Select pass dynamically selects pass pipeline to run based on root op
+ attribute.
}];
let options = [
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index 750a6617fe9b9..e538a944c97c9 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// SelectPass allows to run multiple different set of passes based on attribute
-// value on some top-level op.
+// SelectPass dynamically selects pass pipeline to run based on root op
+// attribute.
//
//===----------------------------------------------------------------------===//
>From c81b8f2b5f7c142f1e37113cdad3f59f55d13dde Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:47:36 +0100
Subject: [PATCH 3/6] remove -split-input-file
---
mlir/test/Transforms/select-pass.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
index fb93486b94ed7..42340f957b4fc 100644
--- a/mlir/test/Transforms/select-pass.mlir
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -3,7 +3,7 @@
// RUN: select-cond-name=test.attr \
// RUN: select-values=rocdl,nvvm \
// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
-// RUN: }))' -split-input-file | FileCheck %s
+// RUN: }))' | FileCheck %s
gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
// CHECK-LABEL: func @foo()
>From 85586b135f0861e8dab826c0032da7837d18a437 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:53:13 +0100
Subject: [PATCH 4/6] update err messages
---
mlir/lib/Transforms/SelectPass.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index e538a944c97c9..c895fb9d0cdd5 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -60,16 +60,16 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
return failure();
if (selectCondName.empty())
- return errorHandler("Invalid select-cond-name");
+ return errorHandler("invalid select-cond-name");
if (selectValues.size() != selectPipelines.size())
- return errorHandler("Values and pipelines size mismatch");
+ return errorHandler("values and pipelines size mismatch");
selectPassManagers.resize(selectPipelines.size());
for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
- return errorHandler("Failed to parse pipeline");
+ return errorHandler("failed to parse pipeline");
}
return success();
@@ -94,7 +94,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
Operation *op = getOperation();
Attribute condAttrValue = op->getAttr(condAttrName);
if (!condAttrValue) {
- op->emitError("Condition attribute not present: ") << condAttrName;
+ op->emitError("condition attribute not present: ") << condAttrName;
return signalPassFailure();
}
@@ -109,7 +109,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
return;
}
- op->emitError("Unhandled condition value: ") << condAttrValue;
+ op->emitError("unhandled condition value: ") << condAttrValue;
return signalPassFailure();
}
>From 0d531d591d559c1b3653036cd6fe030586fefd20 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 20:38:31 +0100
Subject: [PATCH 5/6] TODO
---
mlir/lib/Transforms/SelectPass.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index c895fb9d0cdd5..a025fd8847877 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -109,6 +109,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
return;
}
+ // TODO: add a default pipeline option.
op->emitError("unhandled condition value: ") << condAttrValue;
return signalPassFailure();
}
>From 1af8b3eea109cf4f6004944050b7946a3f8a7eff Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 21:14:37 +0100
Subject: [PATCH 6/6] improve err msg
---
mlir/lib/Transforms/SelectPass.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index a025fd8847877..60b956422a9d0 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -60,7 +60,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
return failure();
if (selectCondName.empty())
- return errorHandler("invalid select-cond-name");
+ return errorHandler("select-cond-name is empty");
if (selectValues.size() != selectPipelines.size())
return errorHandler("values and pipelines size mismatch");
More information about the Mlir-commits
mailing list