[Mlir-commits] [mlir] [mlir] Add `SelectPass` (PR #130409)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 8 12:14:57 PST 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/130409

>From 0303b3c511c4e85e012412142702c4466d3ab01d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 12:54:27 +0100
Subject: [PATCH 1/6] [mlir] Add `SelectPass`

`SelectPass` dynamically selects the pass pipeline to run based on the value of an attribute attached to the root (top-level) op.
---
 mlir/include/mlir/Transforms/Passes.h  |   8 ++
 mlir/include/mlir/Transforms/Passes.td |  19 ++++
 mlir/lib/Transforms/CMakeLists.txt     |   1 +
 mlir/lib/Transforms/SelectPass.cpp     | 132 +++++++++++++++++++++++++
 mlir/test/Transforms/select-pass.mlir  |  24 +++++
 5 files changed, 184 insertions(+)
 create mode 100644 mlir/lib/Transforms/SelectPass.cpp
 create mode 100644 mlir/test/Transforms/select-pass.mlir

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..e521705371b0b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_SELECTPASS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
@@ -139,6 +140,13 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     int maxIterations = 10);
 
+/// Creates select pass which allows to run multiple different set of passes
+/// based on attribute value on some top-level op.
+std::unique_ptr<Pass> createSelectPass(
+    std::string name, std::string selectCondName,
+    ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+        populateFuncs);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..846dffb89e8f7 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def SelectPass : Pass<"select-pass"> {
+  let summary = "Select pass";
+  let description = [{
+    Select pass allows to run multiple different set of passes based on
+    attribute value on some top-level op.
+  }];
+
+  let options = [
+    Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"",
+           "Select pass display name">,
+    Option<"selectCondName", "select-cond-name", "std::string", "\"select\"",
+           "Name of the root op attribute that selects the pipeline">,
+    ListOption<"selectValues", "select-values", "std::string",
+               "Condition values, one per entry in select-pipelines">,
+    ListOption<"selectPipelines", "select-pipelines", "std::string",
+               "Pipelines associated with the corresponding select values">,
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088bccf299..b94e390b627d6 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
+  SelectPass.cpp
   SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
new file mode 100644
index 0000000000000..750a6617fe9b9
--- /dev/null
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -0,0 +1,132 @@
+//===- SelectPass.cpp - Select pass code ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// SelectPass allows to run multiple different set of passes based on attribute
+// value on some top-level op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SELECTPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Runs one of several pass pipelines on the root op, chosen by matching the
+/// root op's `selectCondName` attribute against the `selectValues` list.
+struct SelectPass final : public impl::SelectPassBase<SelectPass> {
+  using SelectPassBase::SelectPassBase;
+
+  SelectPass(
+      std::string name_, std::string selectCondName_,
+      ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+          populateFuncs) {
+    name = std::move(name_);
+    selectCondName = std::move(selectCondName_);
+    SmallVector<std::string> selectVals;
+    SmallVector<std::string> selectPpls;
+    selectVals.reserve(populateFuncs.size());
+    selectPpls.reserve(populateFuncs.size());
+    selectPassManagers.reserve(populateFuncs.size());
+    for (auto &&[value, populate] : populateFuncs) {
+      selectVals.emplace_back(value);
+      OpPassManager &pm = selectPassManagers.emplace_back();
+      populate(pm);
+      // Mirror the pipeline into the textual `select-pipelines` option.
+      llvm::raw_string_ostream os(selectPpls.emplace_back());
+      pm.printAsTextualPipeline(os);
+    }
+    selectValues = selectVals;
+    selectPipelines = selectPpls;
+  }
+
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    selectPassManagers.clear(); // Don't append to previously parsed ones.
+    if (failed(SelectPassBase::initializeOptions(options, errorHandler)))
+      return failure();
+
+    if (selectCondName.empty())
+      return errorHandler("Invalid select-cond-name");
+
+    if (selectValues.size() != selectPipelines.size())
+      return errorHandler("Values and pipelines size mismatch");
+
+    selectPassManagers.resize(selectPipelines.size());
+
+    for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
+      if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
+        return errorHandler("Failed to parse pipeline");
+    }
+
+    return success();
+  }
+
+  LogicalResult initialize(MLIRContext *context) override {
+    condAttrName = StringAttr::get(context, selectCondName);
+    // Rebuild from scratch in case initialize() runs more than once.
+    selectAttrs.clear();
+    selectAttrs.reserve(selectValues.size());
+    for (StringRef value : selectValues)
+      selectAttrs.emplace_back(StringAttr::get(context, value));
+    return success();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    for (const OpPassManager &pipeline : selectPassManagers)
+      pipeline.getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    Attribute condAttrValue = op->getAttr(condAttrName);
+    if (!condAttrValue) {
+      op->emitError("Condition attribute not present: ") << condAttrName;
+      return signalPassFailure();
+    }
+
+    for (auto &&[value, pm] :
+         llvm::zip_equal(selectAttrs, selectPassManagers)) {
+      if (value != condAttrValue)
+        continue;
+
+      if (failed(runPipeline(pm, op)))
+        return signalPassFailure();
+
+      return;
+    }
+
+    op->emitError("Unhandled condition value: ") << condAttrValue;
+    return signalPassFailure();
+  }
+
+protected:
+  StringRef getName() const override { return name; }
+
+private:
+  StringAttr condAttrName; // Interned `selectCondName`.
+  SmallVector<Attribute> selectAttrs; // Interned `selectValues`, in order.
+  SmallVector<OpPassManager> selectPassManagers; // One per select value.
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createSelectPass(
+    std::string name, std::string selectCondName,
+    ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
+        populateFuncs) {
+  return std::make_unique<SelectPass>( // Caller takes ownership.
+      std::move(name), std::move(selectCondName), populateFuncs);
+}
diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
new file mode 100644
index 0000000000000..fb93486b94ed7
--- /dev/null
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \
+// RUN:     name=TestSelectPass \
+// RUN:     select-cond-name=test.attr \
+// RUN:     select-values=rocdl,nvvm \
+// RUN:     select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
+// RUN:     }))' -split-input-file | FileCheck %s
+
+gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
+// CHECK-LABEL: func @foo()
+// CHECK: rocdl.workitem.id.x
+  func.func @foo() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}
+
+gpu.module @nvvm_module attributes {test.attr = "nvvm"} {
+// CHECK-LABEL: func @bar()
+// CHECK: nvvm.read.ptx.sreg.tid.x
+  func.func @bar() -> index {
+    %0 = gpu.thread_id x
+    return %0 : index
+  }
+}

>From 6b62cad4fb4929190a9be0df5eba8fae3ae66f43 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:46:32 +0100
Subject: [PATCH 2/6] clarify doc

---
 mlir/include/mlir/Transforms/Passes.h  | 4 ++--
 mlir/include/mlir/Transforms/Passes.td | 4 ++--
 mlir/lib/Transforms/SelectPass.cpp     | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index e521705371b0b..c808e1bedc8de 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -140,8 +140,8 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
     std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
     int maxIterations = 10);
 
-/// Creates select pass which allows to run multiple different set of passes
-/// based on attribute value on some top-level op.
+/// Creates a select pass, which dynamically selects the pass pipeline to run
+/// based on the value of the root op attribute named by `selectCondName`.
 std::unique_ptr<Pass> createSelectPass(
     std::string name, std::string selectCondName,
     ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 846dffb89e8f7..b707e1edf3d6f 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -589,8 +589,8 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
 def SelectPass : Pass<"select-pass"> {
   let summary = "Select pass";
   let description = [{
-    Select pass allows to run multiple different set of passes based on
-    attribute value on some top-level op.
+    Select pass dynamically selects the pass pipeline to run based on the
+    value of the root op attribute named by `select-cond-name`.
   }];
 
   let options = [
diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index 750a6617fe9b9..e538a944c97c9 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// SelectPass allows to run multiple different set of passes based on attribute
-// value on some top-level op.
+// SelectPass dynamically selects the pass pipeline to run based on the value
+// of an attribute on the root op.
 //
 //===----------------------------------------------------------------------===//
 

>From c81b8f2b5f7c142f1e37113cdad3f59f55d13dde Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:47:36 +0100
Subject: [PATCH 3/6] remove -split-input-file

---
 mlir/test/Transforms/select-pass.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir
index fb93486b94ed7..42340f957b4fc 100644
--- a/mlir/test/Transforms/select-pass.mlir
+++ b/mlir/test/Transforms/select-pass.mlir
@@ -3,7 +3,7 @@
 // RUN:     select-cond-name=test.attr \
 // RUN:     select-values=rocdl,nvvm \
 // RUN:     select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
-// RUN:     }))' -split-input-file | FileCheck %s
+// RUN:     }))' | FileCheck %s
 
 gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
 // CHECK-LABEL: func @foo()

>From 85586b135f0861e8dab826c0032da7837d18a437 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 19:53:13 +0100
Subject: [PATCH 4/6] update err messages

---
 mlir/lib/Transforms/SelectPass.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index e538a944c97c9..c895fb9d0cdd5 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -60,16 +60,16 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
       return failure();
 
     if (selectCondName.empty())
-      return errorHandler("Invalid select-cond-name");
+      return errorHandler("invalid select-cond-name");
 
     if (selectValues.size() != selectPipelines.size())
-      return errorHandler("Values and pipelines size mismatch");
+      return errorHandler("values and pipelines size mismatch");
 
     selectPassManagers.resize(selectPipelines.size());
 
     for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
       if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
-        return errorHandler("Failed to parse pipeline");
+        return errorHandler("failed to parse pipeline");
     }
 
     return success();
@@ -94,7 +94,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
     Operation *op = getOperation();
     Attribute condAttrValue = op->getAttr(condAttrName);
     if (!condAttrValue) {
-      op->emitError("Condition attribute not present: ") << condAttrName;
+      op->emitError("condition attribute not present: ") << condAttrName;
       return signalPassFailure();
     }
 
@@ -109,7 +109,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
       return;
     }
 
-    op->emitError("Unhandled condition value: ") << condAttrValue;
+    op->emitError("unhandled condition value: ") << condAttrValue;
     return signalPassFailure();
   }
 

>From 0d531d591d559c1b3653036cd6fe030586fefd20 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 20:38:31 +0100
Subject: [PATCH 5/6] TODO

---
 mlir/lib/Transforms/SelectPass.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index c895fb9d0cdd5..a025fd8847877 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -109,6 +109,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
       return;
     }
 
+    // TODO: add a default pipeline option.
     op->emitError("unhandled condition value: ") << condAttrValue;
     return signalPassFailure();
   }

>From 1af8b3eea109cf4f6004944050b7946a3f8a7eff Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 8 Mar 2025 21:14:37 +0100
Subject: [PATCH 6/6] improve err msg

---
 mlir/lib/Transforms/SelectPass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp
index a025fd8847877..60b956422a9d0 100644
--- a/mlir/lib/Transforms/SelectPass.cpp
+++ b/mlir/lib/Transforms/SelectPass.cpp
@@ -60,7 +60,7 @@ struct SelectPass final : public impl::SelectPassBase<SelectPass> {
       return failure();
 
     if (selectCondName.empty())
-      return errorHandler("invalid select-cond-name");
+      return errorHandler("select-cond-name is empty");
 
     if (selectValues.size() != selectPipelines.size())
       return errorHandler("values and pipelines size mismatch");



More information about the Mlir-commits mailing list