[Mlir-commits] [mlir] [mlir] Add `SelectPass` (PR #130409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 8 04:00:10 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
`SelectPass` dynamically selects the pass pipeline to run based on the value of an attribute attached to a top-level op.
One example use is IR containing multiple `gpu.module`s for different vendors, each requiring a different lowering pipeline (see the test).
---
Full diff: https://github.com/llvm/llvm-project/pull/130409.diff
5 Files Affected:
- (modified) mlir/include/mlir/Transforms/Passes.h (+8)
- (modified) mlir/include/mlir/Transforms/Passes.td (+19)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Transforms/SelectPass.cpp (+132)
- (added) mlir/test/Transforms/select-pass.mlir (+24)
``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..e521705371b0b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_SELECTPASS
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -139,6 +140,13 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);
+/// Creates a select pass, which runs one of several pass pipelines on the
+/// operation it is anchored on, chosen by the value of the attribute named
+/// `selectCondName` on that operation. Each entry of `populateFuncs` pairs an
+/// expected attribute value with a callback that populates the pipeline to run
+/// for that value; `name` is the pass's display name. The pass fails if the
+/// attribute is missing or its value matches none of the entries.
+std::unique_ptr<Pass> createSelectPass(
+ std::string name, std::string selectCondName,
+ ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+ populateFuncs);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..846dffb89e8f7 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
];
}
+def SelectPass : Pass<"select-pass"> {
+  let summary = "Select pass";
+  let description = [{
+    Runs one of several pass pipelines on the operation the pass is anchored
+    on, chosen by the value of an attribute attached to that operation.
+
+    The attribute named by `select-cond-name` is compared against each entry
+    of `select-values`, and the pipeline at the same index in
+    `select-pipelines` is run. The pass fails if the attribute is missing or
+    its value matches none of `select-values`.
+
+    Example:
+
+    ```
+    gpu.module(select-pass{select-cond-name=test.attr
+                           select-values=rocdl,nvvm
+                           select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm})
+    ```
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"",
+           "Select pass display name">,
+    Option<"selectCondName", "select-cond-name", "std::string",
+           /*default=*/"\"select\"",
+           "Attribute name used for condition">,
+    ListOption<"selectValues", "select-values", "std::string",
+               "Values used to check select condition">,
+    ListOption<"selectPipelines", "select-pipelines", "std::string",
+               "Pipelines associated with the corresponding select values">,
+  ];
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088bccf299..b94e390b627d6 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
+ SelectPass.cpp
SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
new file mode 100644
index 0000000000000..750a6617fe9b9
--- /dev/null
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -0,0 +1,132 @@
+//===- SelectPass.cpp - Select pass code ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// SelectPass runs one of several pass pipelines on an operation, chosen by the
+// value of an attribute attached to that operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SELECTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Runs one of several nested pass pipelines on the operation the pass is
+/// anchored on. The value of the attribute named `selectCondName` on that
+/// operation is compared against each entry of `selectValues`, and the
+/// pipeline at the matching index is run. The pass fails if the attribute is
+/// absent or no entry matches.
+struct SelectPass final : public impl::SelectPassBase<SelectPass> {
+  using SelectPassBase::SelectPassBase;
+
+  /// Programmatic constructor. Each entry of `populateFuncs` pairs an expected
+  /// attribute value with a callback that fills in the pipeline to run for
+  /// that value. The values and the textual form of each built pipeline are
+  /// also stored in the `selectValues` / `selectPipelines` options.
+  SelectPass(
+      std::string name_, std::string selectCondName_,
+      ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+          populateFuncs) {
+    name = std::move(name_);
+    selectCondName = std::move(selectCondName_);
+
+    SmallVector<std::string> selectVals;
+    SmallVector<std::string> selectPpls;
+    selectVals.reserve(populateFuncs.size());
+    selectPpls.reserve(populateFuncs.size());
+    selectPassManagers.reserve(populateFuncs.size());
+    // The loop variable is `value`, not `name`, so it does not shadow the
+    // `name` option assigned above.
+    for (auto &&[value, populate] : populateFuncs) {
+      selectVals.emplace_back(value);
+
+      OpPassManager &pm = selectPassManagers.emplace_back();
+      populate(pm);
+
+      llvm::raw_string_ostream os(selectPpls.emplace_back());
+      pm.printAsTextualPipeline(os);
+    }
+
+    selectValues = selectVals;
+    selectPipelines = selectPpls;
+  }
+
+  /// Parses the textual options and builds one pass manager per entry of
+  /// `selectPipelines`. Reports an error through `errorHandler` if the
+  /// condition attribute name is empty, the value and pipeline lists differ in
+  /// length, or a pipeline fails to parse.
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    if (failed(SelectPassBase::initializeOptions(options, errorHandler)))
+      return failure();
+
+    if (selectCondName.empty())
+      return errorHandler("Invalid select-cond-name");
+
+    if (selectValues.size() != selectPipelines.size())
+      return errorHandler("Values and pipelines size mismatch");
+
+    // Start from empty pass managers. parsePassPipeline appends to the pass
+    // manager it is given, so reusing managers populated earlier (by the
+    // programmatic constructor or a previous call) would duplicate passes.
+    selectPassManagers.clear();
+    selectPassManagers.resize(selectPipelines.size());
+
+    for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
+      if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
+        return errorHandler("Failed to parse pipeline");
+    }
+
+    return success();
+  }
+
+  /// Interns the condition attribute name and the expected values as
+  /// StringAttrs so runOnOperation can compare attributes directly.
+  LogicalResult initialize(MLIRContext *context) override {
+    condAttrName = StringAttr::get(context, selectCondName);
+
+    // Rebuild from scratch in case initialize() runs more than once on this
+    // instance. Appending would leave selectAttrs longer than
+    // selectPassManagers and trip the zip_equal in runOnOperation.
+    selectAttrs.clear();
+    selectAttrs.reserve(selectValues.size());
+    for (StringRef value : selectValues)
+      selectAttrs.emplace_back(StringAttr::get(context, value));
+
+    return success();
+  }
+
+  /// Registers the dialects required by every candidate pipeline, since any
+  /// of them may be selected at run time.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    for (const OpPassManager &pipeline : selectPassManagers)
+      pipeline.getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    Attribute condAttrValue = op->getAttr(condAttrName);
+    if (!condAttrValue) {
+      op->emitError("Condition attribute not present: ") << condAttrName;
+      return signalPassFailure();
+    }
+
+    // Run the first pipeline whose select value equals the attribute value.
+    for (auto &&[value, pm] :
+         llvm::zip_equal(selectAttrs, selectPassManagers)) {
+      if (value != condAttrValue)
+        continue;
+
+      if (failed(runPipeline(pm, op)))
+        return signalPassFailure();
+
+      return;
+    }
+
+    op->emitError("Unhandled condition value: ") << condAttrValue;
+    return signalPassFailure();
+  }
+
+protected:
+  StringRef getName() const override { return name; }
+
+private:
+  /// Interned name of the condition attribute (from `selectCondName`).
+  StringAttr condAttrName;
+  /// Expected attribute values, parallel to `selectPassManagers`.
+  SmallVector<Attribute> selectAttrs;
+  /// One pipeline per select value.
+  SmallVector<OpPassManager> selectPassManagers;
+};
+} // namespace
+
+/// Returns a SelectPass named `name` that reads the attribute
+/// `selectCondName` and, for each (value, populate) entry of `populateFuncs`,
+/// runs the pipeline built by `populate` on operations whose attribute equals
+/// `value`.
+std::unique_ptr<Pass> mlir::createSelectPass(
+    std::string name, std::string selectCondName,
+    ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+        populateFuncs) {
+  std::unique_ptr<Pass> selectPass = std::make_unique<SelectPass>(
+      std::move(name), std::move(selectCondName), populateFuncs);
+  return selectPass;
+}
diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
new file mode 100644
index 0000000000000..fb93486b94ed7
--- /dev/null
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \
+// RUN: name=TestSelectPass \
+// RUN: select-cond-name=test.attr \
+// RUN: select-values=rocdl,nvvm \
+// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
+// RUN: }))' -split-input-file | FileCheck %s
+
+// Each gpu.module must be lowered only by the pipeline selected by its
+// `test.attr` value; the CHECK-NOT lines guard against the other pipeline
+// also being applied.
+
+gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
+// CHECK-LABEL: func @foo()
+// CHECK: rocdl.workitem.id.x
+// CHECK-NOT: nvvm.read.ptx.sreg
+  func.func @foo() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}
+
+gpu.module @nvvm_module attributes {test.attr = "nvvm"} {
+// CHECK-LABEL: func @bar()
+// CHECK: nvvm.read.ptx.sreg.tid.x
+// CHECK-NOT: rocdl.workitem
+  func.func @bar() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/130409
More information about the Mlir-commits
mailing list