[Mlir-commits] [mlir] [MLIR] Support dynamic traits in `DynamicDialect` (PR #177735)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 21:03:48 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/177735
Unlike interfaces, traits in MLIR are static: they are defined as CRTP templates and used as base classes of an op, which makes them difficult to attach to an op dynamically.
However, in IRDL and the Python bindings, we define operations dynamically through `DynamicDialect`, which means the traditional static traits cannot be applied to them. Traits are important: for example, they are how MLIR marks an op as a terminator (`IsTerminator`), or marks an op whose regions do not require a terminator (`NoTerminator`).
If `DynamicDialect` does not support traits, then even though we can define an op with regions, we cannot define new terminators or allow an op's regions to omit a terminator. This makes `DynamicDialect` very limited in region-related scenarios.
In this PR, we introduce a `DynamicOpTrait` type that “dynamizes” `OpTrait`, enabling traits to be attached to ops in `DynamicDialect`. The key design goal is that existing checks in the MLIR codebase such as `op->hasTrait<XXX>()` work seamlessly on ops defined by `DynamicOpDefinition`, without requiring any changes.
>From 7684a00ac1fc7466583e5f26b5527fc10cec6ee4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 12:49:47 +0800
Subject: [PATCH] [MLIR] Support dynamic traits in DynamicDialect
---
mlir/include/mlir/IR/ExtensibleDialect.h | 76 ++++++++++++++++++++--
mlir/test/IR/dynamic.mlir | 44 +++++++++++++
mlir/test/lib/Dialect/Test/TestDialect.cpp | 21 ++++++
3 files changed, 137 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 955faaad9408b..37eea4b7fdd1b 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,64 @@ class DynamicType
void print(AsmPrinter &printer);
};
+/// Base class for a trait that can be attached to a `DynamicOpDefinition` at
+/// runtime. It stands in for a statically defined `OpTrait`: the hooks below
+/// play the role of the static trait's `verifyTrait` / `verifyRegionTrait`,
+/// and `getTypeID` returns the TypeID that `hasTrait` queries are matched
+/// against.
+class DynamicOpTrait {
+public:
+  virtual ~DynamicOpTrait() = default;
+
+  /// Verify this trait's invariants on `op`. Succeeds unless overridden.
+  virtual LogicalResult verifyTrait(Operation *op) const { return success(); }
+
+  /// Verify this trait's region invariants on `op`. Succeeds unless
+  /// overridden.
+  virtual LogicalResult verifyRegionTrait(Operation *op) const {
+    return success();
+  }
+
+  /// Return the TypeID identifying this trait.
+  virtual TypeID getTypeID() const = 0;
+};
+
+/// An insertion-ordered set of dynamic traits, keyed by trait TypeID.
+///
+/// Traits are kept in a vector instead of a `DenseMap`: `DenseMap` iteration
+/// order is unspecified, which would make the order in which trait verifiers
+/// run (and therefore which diagnostic is reported first) nondeterministic.
+/// Each entry caches its TypeID so `contains` needs no virtual calls; a
+/// linear scan assumes an op carries only a handful of traits.
+class DynamicOpTraitList {
+public:
+  /// Add `trait`. If a trait with the same TypeID is already present, the
+  /// existing one is kept and `trait` is discarded.
+  void insert(std::unique_ptr<DynamicOpTrait> trait) {
+    assert(trait && "expected a non-null dynamic op trait");
+    TypeID id = trait->getTypeID();
+    if (!contains(id))
+      traits.emplace_back(id, std::move(trait));
+  }
+
+  /// Return true if a trait with TypeID `id` has been added.
+  bool contains(TypeID id) const {
+    return llvm::any_of(
+        traits, [&](const auto &entry) { return entry.first == id; });
+  }
+
+  /// Run every trait's `verifyTrait` in insertion order, stopping at the
+  /// first failure.
+  LogicalResult verifyTraits(Operation *op) const {
+    for (const auto &entry : traits)
+      if (failed(entry.second->verifyTrait(op)))
+        return failure();
+    return success();
+  }
+
+  /// Run every trait's `verifyRegionTrait` in insertion order, stopping at
+  /// the first failure.
+  LogicalResult verifyRegionTraits(Operation *op) const {
+    for (const auto &entry : traits)
+      if (failed(entry.second->verifyRegionTrait(op)))
+        return failure();
+    return success();
+  }
+
+private:
+  SmallVector<std::pair<TypeID, std::unique_ptr<DynamicOpTrait>>> traits;
+};
+
+/// Convenience base for a dynamic trait that stands in for the static trait
+/// template `Trait`: it reports `TypeID::get<Trait>()` as its TypeID, so that
+/// `hasTrait` lookups for `Trait` find it.
+template <template <typename ConcreteOp> class Trait>
+class DynamicOpTraitImpl : public DynamicOpTrait {
+public:
+  TypeID getTypeID() const override {
+    TypeID staticTraitID = TypeID::get<Trait>();
+    return staticTraitID;
+  }
+};
+
+namespace DynamicOpTraits {
+
+/// Dynamic counterpart of `OpTrait::IsTerminator`. It reuses the static
+/// trait's verifier, `OpTrait::impl::verifyIsTerminator`.
+struct IsTerminator : DynamicOpTraitImpl<OpTrait::IsTerminator> {
+  LogicalResult verifyTrait(Operation *op) const override {
+    return OpTrait::impl::verifyIsTerminator(op);
+  }
+};
+
+/// Dynamic counterpart of `OpTrait::NoTerminator`. It adds no verification of
+/// its own; attaching it only makes the op report the `NoTerminator` trait.
+struct NoTerminator : DynamicOpTraitImpl<OpTrait::NoTerminator> {};
+
+} // namespace DynamicOpTraits
+
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
@@ -437,6 +495,10 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+  /// Attach `trait` to this operation definition. A trait whose TypeID is
+  /// already attached is not added a second time.
+  void addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+    traits.insert(std::move(trait));
+  }
+
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
SmallVectorImpl<OpFoldResult> &results) final {
return foldHookFn(op, attrs, results);
@@ -445,7 +507,7 @@ class DynamicOpDefinition : public OperationName::Impl {
MLIRContext *context) final {
getCanonicalizationPatternsFn(set, context);
}
- bool hasTrait(TypeID id) final { return false; }
+  bool hasTrait(TypeID id) final {
+    // Only traits attached through addTrait() are reported.
+    return traits.contains(id);
+  }
OperationName::ParseAssemblyFn getParseAssemblyFn() final {
return [&](OpAsmParser &parser, OperationState &state) {
return parseFn(parser, state);
@@ -459,9 +521,12 @@ class DynamicOpDefinition : public OperationName::Impl {
StringRef name) final {
printFn(op, printer, name);
}
- LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+  LogicalResult verifyInvariants(Operation *op) final {
+    // Attached trait verifiers run first; the op's own verifier is only
+    // invoked once every trait has succeeded.
+    if (failed(traits.verifyTraits(op)))
+      return failure();
+    return verifyFn(op);
+  }
LogicalResult verifyRegionInvariants(Operation *op) final {
- return verifyRegionFn(op);
+    // Attached trait region verifiers run first; the custom region verifier
+    // is only invoked once every trait has succeeded.
+    if (failed(traits.verifyRegionTraits(op)))
+      return failure();
+    return verifyRegionFn(op);
}
/// Implementation for properties (unsupported right now here).
@@ -494,7 +559,9 @@ class DynamicOpDefinition : public OperationName::Impl {
}
Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
- bool compareProperties(OpaqueProperties, OpaqueProperties) final { return false; }
+  bool compareProperties(OpaqueProperties, OpaqueProperties) final {
+    // Properties are unsupported for dynamic ops here, so never report a
+    // match.
+    return false;
+  }
llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
private:
@@ -518,6 +585,7 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::FoldHookFn foldHookFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
+  /// Dynamic traits attached through addTrait(); consulted by hasTrait() and
+  /// the invariant verifiers.
+  DynamicOpTraitList traits;
friend ExtensibleDialect;
};
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index cf03414c89a35..3d261e6ca8efd 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -125,6 +125,50 @@ func.func @customOpParserPrinter() {
return
}
+// -----
+
+func.func @failedDynamicGenericOpNoTerminator() {
+  // `test.dynamic_generic` is not given the NoTerminator trait, so an empty
+  // block in its region is rejected.
+  // expected-error@+1 {{empty block: expect at least a terminator}}
+  "test.dynamic_generic"() ({
+  ^bb1:
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dynamicTerminatorOp
+func.func @dynamicTerminatorOp() {
+  // A dynamic op carrying the IsTerminator trait may end a block.
+  // CHECK: "test.dynamic_generic"()
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // CHECK: "test.dynamic_terminator"()
+    "test.dynamic_terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedDynamicTerminatorOp() {
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // A terminator followed by another op in the same block is rejected by
+    // the IsTerminator trait verifier.
+    // expected-error@+1 {{'test.dynamic_terminator' op must be the last operation in the parent block}}
+    "test.dynamic_terminator"() : () -> ()
+    "test.dynamic_generic"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @dynamicNoTerminatorOp() {
+  // With the NoTerminator trait attached, an empty block in the op's region
+  // verifies successfully.
+  // CHECK: "test.dynamic_noterminator"()
+  "test.dynamic_noterminator"() ({
+  ^bb0:
+  }) : () -> ()
+  return
+}
+
//===----------------------------------------------------------------------===//
// Dynamic dialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 43392d78b37d2..7c1db3884be10 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -42,6 +42,7 @@
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
+#include <memory>
#include <numeric>
#include <optional>
@@ -242,6 +243,24 @@ getDynamicGenericOp(TestDialect *dialect) {
[](Operation *op) { return success(); });
}
+/// Build `test.dynamic_terminator`: a dynamic op whose only checks come from
+/// the attached IsTerminator trait.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicTerminatorOp(TestDialect *dialect) {
+  auto alwaysSucceeds = [](Operation *) { return success(); };
+  std::unique_ptr<DynamicOpDefinition> opDef = DynamicOpDefinition::get(
+      "dynamic_terminator", dialect, alwaysSucceeds, alwaysSucceeds);
+  opDef->addTrait(std::make_unique<DynamicOpTraits::IsTerminator>());
+  return opDef;
+}
+
+/// Build `test.dynamic_noterminator`: a dynamic op that accepts everything and
+/// carries the NoTerminator trait, so its region blocks need no terminator.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicNoTerminatorOp(TestDialect *dialect) {
+  auto alwaysSucceeds = [](Operation *) { return success(); };
+  std::unique_ptr<DynamicOpDefinition> opDef = DynamicOpDefinition::get(
+      "dynamic_noterminator", dialect, alwaysSucceeds, alwaysSucceeds);
+  opDef->addTrait(std::make_unique<DynamicOpTraits::NoTerminator>());
+  return opDef;
+}
+
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
@@ -329,6 +348,8 @@ void TestDialect::initialize() {
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
+ registerDynamicOp(getDynamicTerminatorOp(this));
+ registerDynamicOp(getDynamicNoTerminatorOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
More information about the Mlir-commits
mailing list