[Mlir-commits] [mlir] [mlir][canonicalize] Add filter-dialects option (PR #193041)

Mehdi Amini llvmlistbot at llvm.org
Wed Apr 22 08:18:01 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/193041

From ca8cdaf6dd4158958ce03c5ea6942382ac24db81 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 20 Apr 2026 10:04:05 -0700
Subject: [PATCH] [mlir][canonicalize] Add filter-dialects option

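Add a new `filter-dialects` list option to the canonicalize pass. When the
list is non-empty, the pass collects only canonicalization patterns that
belong to the listed dialects (both dialect-level patterns and the
patterns of operations registered by those dialects), and the named
dialects are force-loaded via `getDependentDialects`. An empty list keeps
the existing behavior of collecting patterns from every loaded dialect.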
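
The filtering rule can be illustrated with a standalone C++ sketch (an illustration only, not code from the patch). `FakeOp` and `collectPatternSources` are hypothetical names, and dialects are identified by namespace strings for simplicity; the patch itself resolves each name with `getLoadedDialect` and compares dialect TypeIDs.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a registered operation: only the owning
// dialect's namespace matters for filtering.
struct FakeOp {
  std::string name;      // e.g. "arith.subi"
  std::string dialectNs; // e.g. "arith"
};

// Same shape as the `isAllowed` lambda in Canonicalizer::initialize: an
// empty filter admits every dialect, otherwise only the listed ones.
bool isAllowed(const std::set<std::string> &filter, const std::string &ns) {
  return filter.empty() || filter.count(ns) != 0;
}

// Returns the names of ops whose canonicalization patterns would be
// collected under the given filter.
std::vector<std::string>
collectPatternSources(const std::vector<FakeOp> &ops,
                      const std::set<std::string> &filter) {
  std::vector<std::string> collected;
  for (const FakeOp &op : ops)
    if (isAllowed(filter, op.dialectNs))
      collected.push_back(op.name);
  return collected;
}
```

Passing an empty set keeps every op, which corresponds to running canonicalize without the option.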

Loading flow: the Canonicalizer's `getDependentDialects` override calls
`registry.addDialectToPreload(name)` for each filter-dialect name, which
records the name in a new `dialectsToPreload` list on DialectRegistry.
`PassManager::run` then calls
`dependentDialects.preloadSelectDialects(ctx, emitError)`, which loads
each preload entry via `context->getOrLoadDialect(name)`. The actual
allocator is resolved from the context's own registry (populated by the
tool), and the dialect is loaded before multi-threaded execution begins.
If a requested dialect has no registration in the context, the
diagnostic `"can't load dialect '<name>': missing registration?"` is
emitted and the pass manager run fails.
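
The two-phase flow (record names while collecting dependent dialects, then load them in the pass manager before any pass runs) can be sketched as a self-contained C++ model. `ToyContext`, `ToyRegistry`, and `runPipeline` are hypothetical; the real code goes through `DialectRegistry`, `MLIRContext::getOrLoadDialect`, and an `InFlightDiagnostic` rather than returning a string.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical model of the context: `known` plays the role of the
// context's own registry (allocators registered by the tool), `loaded`
// the set of loaded dialects.
struct ToyContext {
  std::set<std::string> known;
  std::set<std::string> loaded;

  // Analogue of getOrLoadDialect(name): loads on demand if an allocator
  // is known, otherwise reports failure.
  bool getOrLoadDialect(const std::string &name) {
    if (!known.count(name))
      return false;
    loaded.insert(name);
    return true;
  }
};

// Hypothetical model of the preload list carried by the registry.
struct ToyRegistry {
  std::vector<std::string> dialectsToPreload;

  void addDialectToPreload(const std::string &name) {
    if (std::find(dialectsToPreload.begin(), dialectsToPreload.end(), name) ==
        dialectsToPreload.end())
      dialectsToPreload.push_back(name);
  }

  // Stops at the first unresolvable name. Returns the diagnostic text on
  // failure and std::nullopt on success.
  std::optional<std::string> preloadSelectDialects(ToyContext &ctx) const {
    for (const std::string &name : dialectsToPreload)
      if (!ctx.getOrLoadDialect(name))
        return "can't load dialect '" + name + "': missing registration?";
    return std::nullopt;
  }
};

// Phase 1 (pass): only record names. Phase 2 (pass manager): load them.
std::optional<std::string>
runPipeline(ToyContext &ctx, const std::vector<std::string> &filterDialects) {
  ToyRegistry dependentDialects;
  for (const std::string &name : filterDialects)
    dependentDialects.addDialectToPreload(name);
  return dependentDialects.preloadSelectDialects(ctx);
}
```

Because loading stops at the first unresolvable name, a later valid name in the same list is not loaded in that run.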

DialectRegistry changes:
- New `addDialectToPreload(StringRef)` method: records a dialect name
  that should be loaded into the MLIRContext but whose allocator lives
  in the context's own registry. The registry itself does not load the
  dialect; it only carries the request. Names that already have an
  allocator in this registry, or that were already requested, are
  ignored.
- New `preloadSelectDialects(MLIRContext *, function_ref<InFlightDiagnostic()> = {})`
  method: loads every preload-registered dialect into the context,
  returning failure (and optionally emitting a diagnostic) for the
  first name that cannot be resolved.
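- `getDialectNames()` is split into two accessors:
  * `getRegisteredDialectNames()`: names of allocator-backed entries
    from the registry map, now returned as a `SmallVector<StringRef>`.
  * `getDialectsToPreload()`: preload-only entries added via
    `addDialectToPreload(StringRef)`.
  Existing callers of `getDialectNames()` are switched to
  `getRegisteredDialectNames()`.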
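- `appendTo` now copies preload entries into the destination registry,
  and `isSubsetOf` also requires every preload entry to be known to the
  other registry (either allocator-backed or as a preload entry).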
- Class-level doc updated to describe the preload feature and to make
  clear that the registry only carries the request; callers must invoke
  `preloadSelectDialects` to actually load the dialects.
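
The registry bookkeeping described in the list above can be modelled without MLIR types. This C++ sketch is hypothetical (`ToyDialectRegistry` is not an MLIR class, the map value is a placeholder for the allocator, and the extension checks done by the real `isSubsetOf` are omitted); it only shows the dedup rules of `addDialectToPreload` and how `appendTo` and `isSubsetOf` treat preload-only entries.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, allocator-free model of DialectRegistry bookkeeping.
class ToyDialectRegistry {
public:
  // Stands in for an allocator-backed registration.
  void insert(const std::string &name) { registry[name] = 0; }

  void addDialectToPreload(const std::string &name) {
    // An allocator-backed entry already covers loading this dialect.
    if (registry.count(name))
      return;
    // Keep the preload list free of duplicates.
    if (isPreload(name))
      return;
    dialectsToPreload.push_back(name);
  }

  // Copies both kinds of entries into `destination`.
  void appendTo(ToyDialectRegistry &destination) const {
    for (const auto &entry : registry)
      destination.insert(entry.first);
    for (const std::string &name : dialectsToPreload)
      destination.addDialectToPreload(name);
  }

  // Allocator-backed names must be allocator-backed in `rhs`; preload-only
  // names may be known to `rhs` in either form.
  bool isSubsetOf(const ToyDialectRegistry &rhs) const {
    for (const auto &entry : registry)
      if (!rhs.registry.count(entry.first))
        return false;
    return std::all_of(dialectsToPreload.begin(), dialectsToPreload.end(),
                       [&](const std::string &name) {
                         return rhs.registry.count(name) != 0 ||
                                rhs.isPreload(name);
                       });
  }

  const std::vector<std::string> &getDialectsToPreload() const {
    return dialectsToPreload;
  }

private:
  bool isPreload(const std::string &name) const {
    return std::find(dialectsToPreload.begin(), dialectsToPreload.end(),
                     name) != dialectsToPreload.end();
  }

  std::map<std::string, int> registry; // name -> placeholder "allocator"
  std::vector<std::string> dialectsToPreload;
};
```

In this model a registry is not a subset of an empty one until `appendTo` has copied its preload entries across.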

Assisted-by: Claude Code
---
 mlir/include/mlir/IR/DialectRegistry.h        | 55 +++++++++++++++++--
 mlir/include/mlir/Transforms/Passes.td        |  6 +-
 .../Transforms/ShardingInterfaceImpl.cpp      |  2 +-
 mlir/lib/IR/Dialect.cpp                       | 35 +++++++++++-
 mlir/lib/IR/MLIRContext.cpp                   |  2 +-
 mlir/lib/Pass/Pass.cpp                        |  5 +-
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |  4 +-
 mlir/lib/Transforms/Canonicalizer.cpp         | 30 +++++++++-
 .../canonicalize-filter-dialects.mlir         | 29 ++++++++++
 9 files changed, 153 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/Transforms/canonicalize-filter-dialects.mlir

diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index b7d3e5d67e6d7..ea957b808260e 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -17,8 +17,11 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 #include <map>
+#include <string>
 #include <tuple>
 
 namespace mlir {
@@ -136,6 +139,14 @@ bool hasPromisedInterface(Dialect &dialect) {
 /// "available" from the dialects loaded in the Context. The parser in
 /// particular will lazily load dialects in the Context as operations are
 /// encountered.
+///
+/// In addition to allocator-backed registrations, the registry can also carry
+/// a set of dialect *names* that some caller has asked to be preloaded into
+/// the context (see `addDialectToPreload(StringRef)`). The registry itself
+/// does not load those dialects — it merely records the request; the
+/// allocator is expected to live in the MLIRContext's own registry, and
+/// actually loading them is the caller's responsibility via
+/// `preloadSelectDialects(MLIRContext *)`.
 class DialectRegistry {
   using MapTy =
       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
@@ -172,6 +183,13 @@ class DialectRegistry {
   void insert(TypeID typeID, StringRef name,
               const DialectAllocatorFunction &ctor);
 
+  /// Request that the dialect with the given name be preloaded into the
+  /// MLIRContext, without providing an allocator. Useful when a caller knows a
+  /// dialect is required but expects its allocator to be available in the
+  /// MLIRContext's own registry at load time (e.g. a pass learning dialect
+  /// names from string-valued options).
+  void addDialectToPreload(StringRef name);
+
   /// Add a new dynamic dialect constructor in the registry. The constructor
   /// provides as argument the created dynamic dialect, and is expected to
   /// register the dialect types, attributes, and ops, using the
@@ -190,19 +208,41 @@ class DialectRegistry {
       destination.insert(nameAndRegistrationIt.second.first,
                          nameAndRegistrationIt.first,
                          nameAndRegistrationIt.second.second);
+    for (const std::string &name : dialectsToPreload)
+      destination.addDialectToPreload(StringRef(name));
     // Merge the extensions.
     for (const auto &extension : extensions)
       destination.extensions.try_emplace(extension.first,
                                          extension.second->clone());
   }
 
-  /// Return the names of dialects known to this registry.
-  auto getDialectNames() const {
-    return llvm::map_range(
-        registry,
-        [](const MapTy::value_type &item) -> StringRef { return item.first; });
+  /// Return the names of dialects registered in this registry with an
+  /// allocator function. Does not include preload-only entries added via
+  /// `addDialectToPreload(StringRef)` — use `getDialectsToPreload()` for those.
+  SmallVector<StringRef> getRegisteredDialectNames() const {
+    SmallVector<StringRef> names;
+    names.reserve(registry.size());
+    for (const auto &item : registry)
+      names.push_back(item.first);
+    return names;
+  }
+
+  /// Return the names of dialects that should be preloaded into the context
+  /// but whose allocator is expected to be resolved from the context's own
+  /// registry (added via `addDialectToPreload(StringRef)`).
+  ArrayRef<std::string> getDialectsToPreload() const {
+    return dialectsToPreload;
   }
 
+  /// Load into `ctx` every dialect previously added via
+  /// `addDialectToPreload(StringRef)`. The allocator is resolved from the
+  /// context's own registry. On failure, if `emitError` is provided, it is
+  /// invoked to produce a diagnostic naming the offending dialect; otherwise
+  /// the failure is silent.
+  LogicalResult preloadSelectDialects(
+      MLIRContext *ctx,
+      function_ref<InFlightDiagnostic()> emitError = {}) const;
+
   /// Apply any held extensions that require the given dialect. Users are not
   /// expected to call this directly.
   void applyExtensions(Dialect *dialect) const;
@@ -261,6 +301,11 @@ class DialectRegistry {
 
 private:
   MapTy registry;
+  /// Names of dialects that should be preloaded into the MLIRContext but for
+  /// which no allocator has been registered here. The allocator is expected
+  /// to be resolved from the MLIRContext's own registry when the dialect is
+  /// loaded (e.g. via MLIRContext::getOrLoadDialect).
+  SmallVector<std::string> dialectsToPreload;
   llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
 };
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3822d1d2a4156..1b08ec98baf06 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -54,7 +54,11 @@ def CanonicalizerPass : Pass<"canonicalize"> {
            "Run full CSE between each pattern-application iteration. "
            "CSE-driven changes trigger extra iterations, so this may push "
            "the iteration count up to max-iterations and affect convergence "
-           "under test-convergence.">
+           "under test-convergence.">,
+    ListOption<"filterDialects", "filter-dialects", "std::string",
+               "If non-empty, only collect canonicalization patterns from the"
+               " dialects with the given namespaces. The listed dialects are"
+               " force-loaded into the context as dependent dialects.">
   ] # RewritePassUtils.options;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 938608afacc40..68d491443bde9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -334,7 +334,7 @@ void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
     registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
                     tensor::TensorDialect>();
     ctx->appendDialectRegistry(registry);
-    for (StringRef name : registry.getDialectNames())
+    for (StringRef name : registry.getRegisteredDialectNames())
       ctx->getOrLoadDialect(name);
 
     registerOne<linalg::GenericOp>(ctx);
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 952619b4477a7..14b8c9c1bb6e3 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -225,6 +225,29 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
   }
 }
 
+LogicalResult DialectRegistry::preloadSelectDialects(
+    MLIRContext *ctx, function_ref<InFlightDiagnostic()> emitError) const {
+  for (const std::string &name : dialectsToPreload) {
+    if (!ctx->getOrLoadDialect(name)) {
+      if (emitError)
+        emitError() << "can't load dialect '" << name
+                    << "': missing registration?";
+      return failure();
+    }
+  }
+  return success();
+}
+
+void DialectRegistry::addDialectToPreload(StringRef name) {
+  // If we already have an allocator for this name, nothing to do: the existing
+  // registration will take care of loading the dialect.
+  if (registry.count(name))
+    return;
+  if (llvm::is_contained(dialectsToPreload, name))
+    return;
+  dialectsToPreload.emplace_back(name);
+}
+
 void DialectRegistry::insertDynamic(
     StringRef name, const DynamicDialectPopulationFunction &ctor) {
   // This TypeID marks dynamic dialects. We cannot give a TypeID for the
@@ -326,6 +349,14 @@ bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
     return false;
 
   // Check that the current dialects fully overlap with the dialects in 'rhs'.
-  return llvm::all_of(
-      registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+  if (!llvm::all_of(registry, [&](const auto &it) {
+        return rhs.registry.count(it.first);
+      }))
+    return false;
+
+  // Check that all preload-only entries are known in 'rhs'.
+  return llvm::all_of(dialectsToPreload, [&](const std::string &name) {
+    return rhs.registry.count(name) ||
+           llvm::is_contained(rhs.dialectsToPreload, name);
+  });
 }
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 21891db11aa65..7b666d11a4a89 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -444,7 +444,7 @@ std::vector<Dialect *> MLIRContext::getLoadedDialects() {
 }
 std::vector<StringRef> MLIRContext::getAvailableDialects() {
   std::vector<StringRef> result;
-  for (auto dialect : impl->dialectsRegistry.getDialectNames())
+  for (auto dialect : impl->dialectsRegistry.getRegisteredDialectNames())
     result.push_back(dialect);
   return result;
 }
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 124dda4740d5b..6162591ca5f74 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -1053,10 +1053,13 @@ LogicalResult PassManager::run(Operation *op) {
   DialectRegistry dependentDialects;
   getDependentDialects(dependentDialects);
   context->appendDialectRegistry(dependentDialects);
-  for (StringRef name : dependentDialects.getDialectNames()) {
+  for (StringRef name : dependentDialects.getRegisteredDialectNames()) {
     LDBG(2) << "Loading dialect: " << name;
     context->getOrLoadDialect(name);
   }
+  if (failed(dependentDialects.preloadSelectDialects(
+          context, [&]() { return emitError(op->getLoc()); })))
+    return failure();
 
   // Before running, make sure to finalize the pipeline pass list.
   if (failed(getImpl().finalizePassList(context))) {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 56d2eb0f80185..8b49258e135d2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -707,7 +707,7 @@ std::string mlir::registerCLIOptions(llvm::StringRef toolName,
   std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
   {
     llvm::raw_string_ostream os(helpHeader);
-    interleaveComma(registry.getDialectNames(), os,
+    interleaveComma(registry.getRegisteredDialectNames(), os,
                     [&](auto name) { os << name; });
   }
   return helpHeader;
@@ -735,7 +735,7 @@ mlir::registerAndParseCLIOptions(int argc, char **argv,
 
 static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
   llvm::outs() << "Available Dialects: ";
-  interleave(registry.getDialectNames(), llvm::outs(), ",");
+  interleave(registry.getRegisteredDialectNames(), llvm::outs(), ",");
   llvm::outs() << "\n";
   return success();
 }
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index aa3b1152f1181..0bc82ec2025ca 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,8 +13,10 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/IR/DialectRegistry.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CANONICALIZERPASS
@@ -40,6 +42,13 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
     this->enabledPatterns = enabledPatterns;
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Force-load any dialects named via the `filter-dialects` option. The
+    // allocator is resolved later from the MLIRContext's own registry.
+    for (const std::string &name : filterDialects)
+      registry.addDialectToPreload(StringRef(name));
+  }
+
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
@@ -50,11 +59,28 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
     config.setMaxNumRewrites(maxNumRewrites);
     config.enableCSEBetweenIterations(cseBetweenIterations);
 
+    llvm::DenseSet<TypeID> allowedDialects;
+    for (const std::string &name : filterDialects) {
+      Dialect *dialect = context->getLoadedDialect(name);
+      if (!dialect) {
+        return emitError(UnknownLoc::get(context))
+               << "canonicalize filter-dialects: dialect '" << name
+               << "' is not loaded in the context";
+      }
+      allowedDialects.insert(dialect->getTypeID());
+    }
+    auto isAllowed = [&](Dialect *dialect) {
+      return allowedDialects.empty() ||
+             allowedDialects.contains(dialect->getTypeID());
+    };
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
-      dialect->getCanonicalizationPatterns(owningPatterns);
+      if (isAllowed(dialect))
+        dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
-      op.getCanonicalizationPatterns(owningPatterns, context);
+      if (isAllowed(&op.getDialect()))
+        op.getCanonicalizationPatterns(owningPatterns, context);
 
     patterns = std::make_shared<FrozenRewritePatternSet>(
         std::move(owningPatterns), disabledPatterns, enabledPatterns);
diff --git a/mlir/test/Transforms/canonicalize-filter-dialects.mlir b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
new file mode 100644
index 0000000000000..c84c98c21e505
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=arith}))' | FileCheck %s --check-prefix=ARITH
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=func}))' | FileCheck %s --check-prefix=FUNC
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s --check-prefix=ALL
+// RUN: not mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=does_not_exist}))' 2>&1 | FileCheck %s --check-prefix=ERR
+
+// The `SubIRHSAddConstant` arith pattern rewrites `subi(addi(x, c0), c1)` into
+// `addi(x, c0 - c1)`. The pattern only fires when arith canonicalization
+// patterns are collected, i.e. when `filter-dialects` is empty or includes arith.
+
+// ARITH-LABEL: func @pattern_test
+// ARITH-NOT:     arith.subi
+// ARITH:         arith.addi %{{.*}}, %[[C:.*]]
+
+// FUNC-LABEL: func @pattern_test
+// FUNC:         arith.addi
+// FUNC:         arith.subi
+
+// ALL-LABEL: func @pattern_test
+// ALL-NOT:     arith.subi
+// ALL:         arith.addi %{{.*}}, %[[C:.*]]
+
+// ERR: can't load dialect 'does_not_exist': missing registration?
+func.func @pattern_test(%a: i32) -> i32 {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2 : i32
+  %add = arith.addi %a, %c1 : i32
+  %sub = arith.subi %add, %c2 : i32
+  return %sub : i32
+}
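
The new test relies on `SubIRHSAddConstant` being an algebraic identity under wrapping 32-bit arithmetic. The following standalone C++ check is only an illustration (the helper names are hypothetical and it does not use MLIR); it models `arith.addi`/`arith.subi` on `uint32_t` so that overflow wraps, and compares the original and rewritten forms.

```cpp
#include <cassert>
#include <cstdint>

// i32 add/sub modelled on uint32_t so that overflow wraps (as
// arith.addi/arith.subi do) instead of being undefined behavior.
uint32_t addi(uint32_t a, uint32_t b) { return a + b; }
uint32_t subi(uint32_t a, uint32_t b) { return a - b; }

// Original form in @pattern_test: subi(addi(x, c0), c1).
uint32_t before(uint32_t x, uint32_t c0, uint32_t c1) {
  return subi(addi(x, c0), c1);
}

// Rewritten form: addi(x, c0 - c1), where c0 - c1 becomes one constant.
uint32_t after(uint32_t x, uint32_t c0, uint32_t c1) {
  return addi(x, subi(c0, c1));
}
```

For the constants used in `@pattern_test` (`c0 = 1`, `c1 = 2`), the rewritten form is a single `arith.addi` of `%a` with the constant `1 - 2`, which is what the `ARITH` and `ALL` prefixes look for.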


