[Mlir-commits] [mlir] [mlir] Add `SelectPass` (PR #130409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 8 04:00:10 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

`SelectPass` allows dynamically selecting the pass pipeline based on the value of an attribute attached to some top-level op.

One usage example is IR containing multiple `gpu.module`s for different vendors, each requiring a different lowering pipeline (see test).

---
Full diff: https://github.com/llvm/llvm-project/pull/130409.diff


5 Files Affected:

- (modified) mlir/include/mlir/Transforms/Passes.h (+8) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+19) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Transforms/SelectPass.cpp (+132) 
- (added) mlir/test/Transforms/select-pass.mlir (+24) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..e521705371b0b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_SELECTPASS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -139,6 +140,13 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     int maxIterations = 10);
 
+/// Creates a select pass, which allows running different sets of passes based
+/// on the value of an attribute on some top-level op.
+std::unique_ptr<Pass> createSelectPass(
+    std::string name, std::string selectCondName,
+    ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+        populateFuncs);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..846dffb89e8f7 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def SelectPass : Pass<"select-pass"> {
+  let summary = "Select pass";
+  let description = [{
+    Select pass allows to run multiple different set of passes based on
+    attribute value on some top-level op.
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"",
+           "Select pass display name">,
+    Option<"selectCondName", "select-cond-name", "std::string", "\"select\"",
+           "Attribute name used for condition">,
+    ListOption<"selectValues", "select-values", "std::string",
+               "Values used to check select condition">,
+    ListOption<"selectPipelines", "select-pipelines", "std::string",
+               "Pipelines, associated with corresponding select values">,
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088bccf299..b94e390b627d6 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
+  SelectPass.cpp
   SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
new file mode 100644
index 0000000000000..750a6617fe9b9
--- /dev/null
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -0,0 +1,132 @@
+//===- SelectPass.cpp - Select pass code ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// SelectPass allows running different sets of passes based on the value of an
+// attribute on some top-level op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SELECTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that runs one of several nested pass pipelines, picked by comparing an
+/// attribute on the operation being processed against a list of select values.
+///
+/// `selectValues[i]` is paired with `selectPipelines[i]`, whose parsed form is
+/// stored in `selectPassManagers[i]`.
+struct SelectPass final : public impl::SelectPassBase<SelectPass> {
+  using SelectPassBase::SelectPassBase;
+
+  /// Builds the pass programmatically. Each entry of `populateFuncs` is a
+  /// (select value, populate callback) pair: the callback runs immediately
+  /// on a fresh `OpPassManager`, and the textual form of the resulting
+  /// pipeline is stored in the `select-pipelines` option.
+  SelectPass(
+      std::string name_, std::string selectCondName_,
+      ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+          populateFuncs) {
+    name = std::move(name_);
+    selectCondName = std::move(selectCondName_);
+
+    SmallVector<std::string> selectVals;
+    SmallVector<std::string> selectPpls;
+    selectVals.reserve(populateFuncs.size());
+    selectPpls.reserve(populateFuncs.size());
+    selectPassManagers.reserve(populateFuncs.size());
+    // The binding is called `value` so it does not shadow the `name` option.
+    for (auto &&[value, populate] : populateFuncs) {
+      selectVals.emplace_back(value);
+
+      auto &pm = selectPassManagers.emplace_back();
+      populate(pm);
+
+      // Keep the textual options in sync with the populated pass manager.
+      llvm::raw_string_ostream os(selectPpls.emplace_back());
+      pm.printAsTextualPipeline(os);
+    }
+
+    selectValues = selectVals;
+    selectPipelines = selectPpls;
+  }
+
+  /// Parses the textual options, then validates them and parses every entry
+  /// of `select-pipelines` into its own pass manager. Fails via `errorHandler`
+  /// if the condition name is empty, if the value and pipeline lists differ
+  /// in length, or if a pipeline string cannot be parsed.
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    if (failed(SelectPassBase::initializeOptions(options, errorHandler)))
+      return failure();
+
+    if (selectCondName.empty())
+      return errorHandler("Invalid select-cond-name");
+
+    if (selectValues.size() != selectPipelines.size())
+      return errorHandler("Values and pipelines size mismatch");
+
+    selectPassManagers.resize(selectPipelines.size());
+
+    for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
+      if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
+        return errorHandler("Failed to parse pipeline");
+    }
+
+    return success();
+  }
+
+  /// Creates the `StringAttr`s for the condition attribute name and for each
+  /// select value in `context`.
+  LogicalResult initialize(MLIRContext *context) override {
+    condAttrName = StringAttr::get(context, selectCondName);
+
+    // Rebuild the list from scratch. If this hook runs more than once,
+    // appending would make `selectAttrs` longer than `selectPassManagers`
+    // and break the `zip_equal` in `runOnOperation`.
+    selectAttrs.clear();
+    selectAttrs.reserve(selectValues.size());
+    for (StringRef value : selectValues)
+      selectAttrs.emplace_back(StringAttr::get(context, value));
+
+    return success();
+  }
+
+  /// Collects the dependent dialects of every nested pipeline.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    for (const OpPassManager &pipeline : selectPassManagers)
+      pipeline.getDependentDialects(registry);
+  }
+
+  /// Runs the pipeline whose select value equals the condition attribute of
+  /// the current operation. Select values are `StringAttr`s, so only a
+  /// string attribute can match. Emits an error and fails the pass if the
+  /// attribute is missing or matches none of the select values.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    Attribute condAttrValue = op->getAttr(condAttrName);
+    if (!condAttrValue) {
+      op->emitError("Condition attribute not present: ") << condAttrName;
+      return signalPassFailure();
+    }
+
+    for (auto &&[value, pm] :
+         llvm::zip_equal(selectAttrs, selectPassManagers)) {
+      if (value != condAttrValue)
+        continue;
+
+      // The first matching value wins; later entries are not considered.
+      if (failed(runPipeline(pm, op)))
+        return signalPassFailure();
+
+      return;
+    }
+
+    op->emitError("Unhandled condition value: ") << condAttrValue;
+    return signalPassFailure();
+  }
+
+protected:
+  /// Reports the user-configurable `name` option as the pass name.
+  StringRef getName() const override { return name; }
+
+private:
+  /// Name of the attribute holding the condition, set in `initialize`.
+  StringAttr condAttrName;
+  /// Select values as attributes; entry `i` pairs with `selectPassManagers[i]`.
+  SmallVector<Attribute> selectAttrs;
+  /// One nested pipeline per select value.
+  SmallVector<OpPassManager> selectPassManagers;
+};
+} // namespace
+
+/// Public factory for `SelectPass`. `name` and `selectCondName` are moved into
+/// the pass; `populateFuncs` is handed to the constructor unchanged.
+std::unique_ptr<Pass> mlir::createSelectPass(
+    std::string name, std::string selectCondName,
+    ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+        populateFuncs) {
+  std::unique_ptr<Pass> pass = std::make_unique<SelectPass>(
+      std::move(name), std::move(selectCondName), populateFuncs);
+  return pass;
+}
diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
new file mode 100644
index 0000000000000..fb93486b94ed7
--- /dev/null
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \
+// RUN:     name=TestSelectPass \
+// RUN:     select-cond-name=test.attr \
+// RUN:     select-values=rocdl,nvvm \
+// RUN:     select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
+// RUN:     }))' -split-input-file | FileCheck %s
+// The module tagged "rocdl" should go through convert-gpu-to-rocdl.
+gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
+// CHECK-LABEL: func @foo()
+// CHECK: rocdl.workitem.id.x
+  func.func @foo() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}
+// The module tagged "nvvm" should go through convert-gpu-to-nvvm.
+gpu.module @nvvm_module attributes {test.attr = "nvvm"} {
+// CHECK-LABEL: func @bar()
+// CHECK: nvvm.read.ptx.sreg.tid.x
+  func.func @bar() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/130409


More information about the Mlir-commits mailing list