[Mlir-commits] [mlir] [mlir][canonicalize] Add filter-dialects option (PR #193041)
Mehdi Amini
llvmlistbot at llvm.org
Wed Apr 22 08:18:01 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/193041
From ca8cdaf6dd4158958ce03c5ea6942382ac24db81 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 20 Apr 2026 10:04:05 -0700
Subject: [PATCH] [mlir][canonicalize] Add filter-dialects option
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add a new `filter-dialects` list option to the canonicalize pass. When
provided, only canonicalization patterns from the listed dialects are
collected, and the named dialects are force-loaded via
getDependentDialects.
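
The filtering rule can be pictured with a small standalone model. This is only a sketch under stated assumptions: `ToyPattern`, `collectPatterns`, and the pattern name `SomeFuncPattern` are hypothetical and are not part of MLIR; the sketch only mirrors the "empty filter means every dialect" rule described above.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a canonicalization pattern:
// (owning dialect namespace, pattern name).
using ToyPattern = std::pair<std::string, std::string>;

// Collect the names of the patterns whose dialect passes the filter.
// An empty filter keeps everything, matching the default behaviour of
// the pass when `filter-dialects` is not given.
std::vector<std::string>
collectPatterns(const std::vector<ToyPattern> &all,
                const std::set<std::string> &filterDialects) {
  std::vector<std::string> collected;
  for (const ToyPattern &pattern : all)
    if (filterDialects.empty() || filterDialects.count(pattern.first))
      collected.push_back(pattern.second);
  return collected;
}
```

Unlike this toy, the pass does not silently return an empty set for an unknown name: a dialect that cannot be loaded is reported as an error.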
Loading flow: the Canonicalizer's getDependentDialects override calls
`registry.addDialectToPreload(name)` for each filter-dialect name, which
records the name in a new `dialectsToPreload` list on DialectRegistry.
The PassManager's pipeline-init then calls
`dependentDialects.preloadSelectDialects(ctx, emitError)`, which loads
each preload entry via `context->getOrLoadDialect(name)` — the real
allocator is resolved from the context's own registry (registered by
the tool) and the dialect is loaded before multi-threaded execution
begins. If a requested dialect has no registration in the context, a
diagnostic `"can't load dialect '<name>': missing registration?"` is
emitted.
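
The loading flow above can be sketched as a self-contained toy model. It does not use the MLIR headers: `ToyContext` and `ToyRegistry` are hypothetical stand-ins for MLIRContext and DialectRegistry, and the error callback takes a plain string rather than returning an InFlightDiagnostic. Only the control flow (record names, then load them from the context's own registrations, stopping at the first failure) is meant to match the description.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for MLIRContext (assumption, not the real class).
struct ToyContext {
  std::set<std::string> registered; // dialects the tool registered
  std::set<std::string> loaded;     // dialects currently loaded

  // Plays the role of getOrLoadDialect(name): succeeds only when the
  // context itself has a registration for `name`.
  bool getOrLoadDialect(const std::string &name) {
    if (!registered.count(name))
      return false;
    loaded.insert(name);
    return true;
  }
};

// Toy stand-in for the preload part of DialectRegistry.
struct ToyRegistry {
  std::vector<std::string> dialectsToPreload;

  // Record the request only; nothing is loaded here.
  void addDialectToPreload(const std::string &name) {
    assert(!name.empty() && "dialect name must be non-empty");
    for (const std::string &existing : dialectsToPreload)
      if (existing == name)
        return;
    dialectsToPreload.push_back(name);
  }

  // Load every requested dialect; stop at the first name the context
  // cannot resolve and optionally report it.
  bool preloadSelectDialects(
      ToyContext &ctx,
      const std::function<void(const std::string &)> &emitError = {}) const {
    for (const std::string &name : dialectsToPreload) {
      if (!ctx.getOrLoadDialect(name)) {
        if (emitError)
          emitError("can't load dialect '" + name +
                    "': missing registration?");
        return false;
      }
    }
    return true;
  }
};
```

Keeping the request (the name) separate from the allocator is what lets a pass ask for a dialect it only knows by a string option.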
DialectRegistry changes:
- New `addDialectToPreload(StringRef)` method: records a dialect name
  that should be loaded into the MLIRContext but whose allocator lives
  in the context's own registry. The registry itself does not load the
  dialect; it just carries the request.
- New `preloadSelectDialects(MLIRContext *, function_ref<InFlightDiagnostic()> = {})`
  method: loads every preload-registered dialect into the context,
  returning failure (and optionally emitting a diagnostic) for the
  first name that cannot be resolved.
- `getDialectNames()` is split into two accessors:
  * `getRegisteredDialectNames()`: names of allocator-backed entries
    from the registry map.
  * `getDialectsToPreload()`: preload-only entries added via
    `addDialectToPreload(StringRef)`.
- `appendTo`/`isSubsetOf` updated to carry preload entries through.
- Class-level doc updated to describe the preload feature and make
  clear the registry only carries the request; callers must invoke
  `preloadSelectDialects` to actually load.
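
How preload entries interact with `appendTo` and `isSubsetOf` can be illustrated with another standalone sketch. `MiniRegistry` is a hypothetical model, not the MLIR class: the allocator is replaced by an `int` placeholder and the extension handling of the real registry is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical model of the dialect-related state of a registry.
struct MiniRegistry {
  std::map<std::string, int> registry;        // name -> allocator placeholder
  std::vector<std::string> dialectsToPreload; // preload-only names

  bool hasPreload(const std::string &name) const {
    return std::find(dialectsToPreload.begin(), dialectsToPreload.end(),
                     name) != dialectsToPreload.end();
  }

  // Skip names that already have an allocator or are already requested.
  void addDialectToPreload(const std::string &name) {
    if (registry.count(name) || hasPreload(name))
      return;
    dialectsToPreload.push_back(name);
  }

  // Copy allocator-backed entries and carry preload requests through.
  void appendTo(MiniRegistry &destination) const {
    for (const auto &entry : registry)
      destination.registry.insert(entry);
    for (const std::string &name : dialectsToPreload)
      destination.addDialectToPreload(name);
  }

  // A preload-only name counts as covered when `rhs` either has an
  // allocator for it or carries the same preload request.
  bool isSubsetOf(const MiniRegistry &rhs) const {
    for (const auto &entry : registry)
      if (!rhs.registry.count(entry.first))
        return false;
    for (const std::string &name : dialectsToPreload)
      if (!rhs.registry.count(name) && !rhs.hasPreload(name))
        return false;
    return true;
  }
};
```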
Assisted-by: Claude Code
---
mlir/include/mlir/IR/DialectRegistry.h | 55 +++++++++++++++++--
mlir/include/mlir/Transforms/Passes.td | 6 +-
.../Transforms/ShardingInterfaceImpl.cpp | 2 +-
mlir/lib/IR/Dialect.cpp | 35 +++++++++++-
mlir/lib/IR/MLIRContext.cpp | 2 +-
mlir/lib/Pass/Pass.cpp | 5 +-
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 4 +-
mlir/lib/Transforms/Canonicalizer.cpp | 30 +++++++++-
.../canonicalize-filter-dialects.mlir | 29 ++++++++++
9 files changed, 153 insertions(+), 15 deletions(-)
create mode 100644 mlir/test/Transforms/canonicalize-filter-dialects.mlir
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index b7d3e5d67e6d7..ea957b808260e 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -17,8 +17,11 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include <map>
+#include <string>
#include <tuple>
namespace mlir {
@@ -136,6 +139,14 @@ bool hasPromisedInterface(Dialect &dialect) {
/// "available" from the dialects loaded in the Context. The parser in
/// particular will lazily load dialects in the Context as operations are
/// encountered.
+///
+/// In addition to allocator-backed registrations, the registry can also carry
+/// a set of dialect *names* that some caller has asked to be preloaded into
+/// the context (see `addDialectToPreload(StringRef)`). The registry itself
+/// does not load those dialects — it merely records the request; the
+/// allocator is expected to live in the MLIRContext's own registry, and
+/// actually loading them is the caller's responsibility via
+/// `preloadSelectDialects(MLIRContext *)`.
class DialectRegistry {
using MapTy =
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
@@ -172,6 +183,13 @@ class DialectRegistry {
void insert(TypeID typeID, StringRef name,
const DialectAllocatorFunction &ctor);
+ /// Request that the dialect with the given name be preloaded into the
+ /// MLIRContext, without providing an allocator. Useful when a caller knows a
+ /// dialect is required but expects its allocator to be available in the
+ /// MLIRContext's own registry at load time (e.g. a pass learning dialect
+ /// names from string-valued options).
+ void addDialectToPreload(StringRef name);
+
/// Add a new dynamic dialect constructor in the registry. The constructor
/// provides as argument the created dynamic dialect, and is expected to
/// register the dialect types, attributes, and ops, using the
@@ -190,19 +208,41 @@ class DialectRegistry {
destination.insert(nameAndRegistrationIt.second.first,
nameAndRegistrationIt.first,
nameAndRegistrationIt.second.second);
+ for (const std::string &name : dialectsToPreload)
+ destination.addDialectToPreload(StringRef(name));
// Merge the extensions.
for (const auto &extension : extensions)
destination.extensions.try_emplace(extension.first,
extension.second->clone());
}
- /// Return the names of dialects known to this registry.
- auto getDialectNames() const {
- return llvm::map_range(
- registry,
- [](const MapTy::value_type &item) -> StringRef { return item.first; });
+ /// Return the names of dialects registered in this registry with an
+ /// allocator function. Does not include preload-only entries added via
+ /// `addDialectToPreload(StringRef)` — use `getDialectsToPreload()` for those.
+ SmallVector<StringRef> getRegisteredDialectNames() const {
+ SmallVector<StringRef> names;
+ names.reserve(registry.size());
+ for (const auto &item : registry)
+ names.push_back(item.first);
+ return names;
+ }
+
+ /// Return the names of dialects that should be preloaded into the context
+ /// but whose allocator is expected to be resolved from the context's own
+ /// registry (added via `addDialectToPreload(StringRef)`).
+ ArrayRef<std::string> getDialectsToPreload() const {
+ return dialectsToPreload;
}
+ /// Load into `ctx` every dialect previously added via
+ /// `addDialectToPreload(StringRef)`. The allocator is resolved from the
+ /// context's own registry. On failure, if `emitError` is provided, it is
+ /// invoked to produce a diagnostic naming the offending dialect; otherwise
+ /// the failure is silent.
+ LogicalResult preloadSelectDialects(
+ MLIRContext *ctx,
+ function_ref<InFlightDiagnostic()> emitError = {}) const;
+
/// Apply any held extensions that require the given dialect. Users are not
/// expected to call this directly.
void applyExtensions(Dialect *dialect) const;
@@ -261,6 +301,11 @@ class DialectRegistry {
private:
MapTy registry;
+ /// Names of dialects that should be preloaded into the MLIRContext but for
+ /// which no allocator has been registered here. The allocator is expected
+ /// to be resolved from the MLIRContext's own registry when the dialect is
+ /// loaded (e.g. via MLIRContext::getOrLoadDialect).
+ SmallVector<std::string> dialectsToPreload;
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
};
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3822d1d2a4156..1b08ec98baf06 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -54,7 +54,11 @@ def CanonicalizerPass : Pass<"canonicalize"> {
"Run full CSE between each pattern-application iteration. "
"CSE-driven changes trigger extra iterations, so this may push "
"the iteration count up to max-iterations and affect convergence "
- "under test-convergence.">
+ "under test-convergence.">,
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "If non-empty, only collect canonicalization patterns from the"
+ " dialects with the given namespaces. The listed dialects are"
+ " force-loaded into the context as dependent dialects.">
] # RewritePassUtils.options;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 938608afacc40..68d491443bde9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -334,7 +334,7 @@ void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
tensor::TensorDialect>();
ctx->appendDialectRegistry(registry);
- for (StringRef name : registry.getDialectNames())
+ for (StringRef name : registry.getRegisteredDialectNames())
ctx->getOrLoadDialect(name);
registerOne<linalg::GenericOp>(ctx);
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 952619b4477a7..14b8c9c1bb6e3 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -225,6 +225,29 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
}
}
+LogicalResult DialectRegistry::preloadSelectDialects(
+ MLIRContext *ctx, function_ref<InFlightDiagnostic()> emitError) const {
+ for (const std::string &name : dialectsToPreload) {
+ if (!ctx->getOrLoadDialect(name)) {
+ if (emitError)
+ emitError() << "can't load dialect '" << name
+ << "': missing registration?";
+ return failure();
+ }
+ }
+ return success();
+}
+
+void DialectRegistry::addDialectToPreload(StringRef name) {
+ // If we already have an allocator for this name, nothing to do: the existing
+ // registration will take care of loading the dialect.
+ if (registry.count(name))
+ return;
+ if (llvm::is_contained(dialectsToPreload, name))
+ return;
+ dialectsToPreload.emplace_back(name);
+}
+
void DialectRegistry::insertDynamic(
StringRef name, const DynamicDialectPopulationFunction &ctor) {
// This TypeID marks dynamic dialects. We cannot give a TypeID for the
@@ -326,6 +349,14 @@ bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
return false;
// Check that the current dialects fully overlap with the dialects in 'rhs'.
- return llvm::all_of(
- registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+ if (!llvm::all_of(registry, [&](const auto &it) {
+ return rhs.registry.count(it.first);
+ }))
+ return false;
+
+ // Check that all preload-only entries are known in 'rhs'.
+ return llvm::all_of(dialectsToPreload, [&](const std::string &name) {
+ return rhs.registry.count(name) ||
+ llvm::is_contained(rhs.dialectsToPreload, name);
+ });
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 21891db11aa65..7b666d11a4a89 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -444,7 +444,7 @@ std::vector<Dialect *> MLIRContext::getLoadedDialects() {
}
std::vector<StringRef> MLIRContext::getAvailableDialects() {
std::vector<StringRef> result;
- for (auto dialect : impl->dialectsRegistry.getDialectNames())
+ for (auto dialect : impl->dialectsRegistry.getRegisteredDialectNames())
result.push_back(dialect);
return result;
}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 124dda4740d5b..6162591ca5f74 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -1053,10 +1053,13 @@ LogicalResult PassManager::run(Operation *op) {
DialectRegistry dependentDialects;
getDependentDialects(dependentDialects);
context->appendDialectRegistry(dependentDialects);
- for (StringRef name : dependentDialects.getDialectNames()) {
+ for (StringRef name : dependentDialects.getRegisteredDialectNames()) {
LDBG(2) << "Loading dialect: " << name;
context->getOrLoadDialect(name);
}
+ if (failed(dependentDialects.preloadSelectDialects(
+ context, [&]() { return emitError(op->getLoc()); })))
+ return failure();
// Before running, make sure to finalize the pipeline pass list.
if (failed(getImpl().finalizePassList(context))) {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 56d2eb0f80185..8b49258e135d2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -707,7 +707,7 @@ std::string mlir::registerCLIOptions(llvm::StringRef toolName,
std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
{
llvm::raw_string_ostream os(helpHeader);
- interleaveComma(registry.getDialectNames(), os,
+ interleaveComma(registry.getRegisteredDialectNames(), os,
[&](auto name) { os << name; });
}
return helpHeader;
@@ -735,7 +735,7 @@ mlir::registerAndParseCLIOptions(int argc, char **argv,
static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
llvm::outs() << "Available Dialects: ";
- interleave(registry.getDialectNames(), llvm::outs(), ",");
+ interleave(registry.getRegisteredDialectNames(), llvm::outs(), ",");
llvm::outs() << "\n";
return success();
}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index aa3b1152f1181..0bc82ec2025ca 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,8 +13,10 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
namespace mlir {
#define GEN_PASS_DEF_CANONICALIZERPASS
@@ -40,6 +42,13 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
this->enabledPatterns = enabledPatterns;
}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ // Force-load any dialects named via the `filter-dialects` option. The
+ // allocator is resolved later from the MLIRContext's own registry.
+ for (const std::string &name : filterDialects)
+ registry.addDialectToPreload(StringRef(name));
+ }
+
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
@@ -50,11 +59,28 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
config.setMaxNumRewrites(maxNumRewrites);
config.enableCSEBetweenIterations(cseBetweenIterations);
+ llvm::DenseSet<TypeID> allowedDialects;
+ for (const std::string &name : filterDialects) {
+ Dialect *dialect = context->getLoadedDialect(name);
+ if (!dialect) {
+ return emitError(UnknownLoc::get(context))
+ << "canonicalize filter-dialects: dialect '" << name
+ << "' is not loaded in the context";
+ }
+ allowedDialects.insert(dialect->getTypeID());
+ }
+ auto isAllowed = [&](Dialect *dialect) {
+ return allowedDialects.empty() ||
+ allowedDialects.contains(dialect->getTypeID());
+ };
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
- dialect->getCanonicalizationPatterns(owningPatterns);
+ if (isAllowed(dialect))
+ dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
- op.getCanonicalizationPatterns(owningPatterns, context);
+ if (isAllowed(&op.getDialect()))
+ op.getCanonicalizationPatterns(owningPatterns, context);
patterns = std::make_shared<FrozenRewritePatternSet>(
std::move(owningPatterns), disabledPatterns, enabledPatterns);
diff --git a/mlir/test/Transforms/canonicalize-filter-dialects.mlir b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
new file mode 100644
index 0000000000000..c84c98c21e505
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=arith}))' | FileCheck %s --check-prefix=ARITH
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=func}))' | FileCheck %s --check-prefix=FUNC
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s --check-prefix=ALL
+// RUN: not mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=does_not_exist}))' 2>&1 | FileCheck %s --check-prefix=ERR
+
+// The `SubIRHSAddConstant` arith pattern rewrites `subi(addi(x, c0), c1)` into
+// `addi(x, c0 - c1)`. The pattern only fires when arith canonicalizations are
+// loaded.
+
+// ARITH-LABEL: func @pattern_test
+// ARITH-NOT: arith.subi
+// ARITH: arith.addi %{{.*}}, %[[C:.*]]
+
+// FUNC-LABEL: func @pattern_test
+// FUNC: arith.addi
+// FUNC: arith.subi
+
+// ALL-LABEL: func @pattern_test
+// ALL-NOT: arith.subi
+// ALL: arith.addi %{{.*}}, %[[C:.*]]
+
+// ERR: can't load dialect 'does_not_exist': missing registration?
+func.func @pattern_test(%a: i32) -> i32 {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %add = arith.addi %a, %c1 : i32
+ %sub = arith.subi %add, %c2 : i32
+ return %sub : i32
+}