[Mlir-commits] [mlir] 2b9ad86 - [MLIR] Support dynamic traits in `DynamicDialect` (#177735)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 24 00:14:12 PST 2026
Author: Twice
Date: 2026-01-24T16:14:07+08:00
New Revision: 2b9ad865f734e8c8e7e76031ee392afdb560d665
URL: https://github.com/llvm/llvm-project/commit/2b9ad865f734e8c8e7e76031ee392afdb560d665
DIFF: https://github.com/llvm/llvm-project/commit/2b9ad865f734e8c8e7e76031ee392afdb560d665.diff
LOG: [MLIR] Support dynamic traits in `DynamicDialect` (#177735)
Unlike Interfaces, Traits in MLIR are static: they are defined via CRTP
templates and used as base classes of an `Op`, which makes them
difficult to attach to an op dynamically.
However, in IRDL and the Python bindings, we define operations
dynamically through `DynamicDialect`, which means the traditional static
traits cannot be applied to them. Traits are important: for example,
they are how MLIR marks an op as a terminator or a non-terminator.
If `DynamicDialect` does not support traits, then even though we can
define an op with regions, we cannot define new terminators or mark an
op as a non-terminator. This makes `DynamicDialect` very limited in
region-related scenarios.
In this PR, we introduce a `DynamicOpTrait` type that “dynamizes”
`OpTrait`, enabling traits to be attached to ops in `DynamicDialect`.
The key design goal is that existing checks in the MLIR codebase such as
`op->hasTrait<XXX>()` work seamlessly on ops defined by
`DynamicOpDefinition`, without requiring any changes.
Note that only two traits, `IsTerminator` and `NoTerminator`, are
currently supported in this PR.
This PR aims to lay the groundwork for adding trait support to IRDL
and the Python bindings (and possibly other bindings) in the future.
Related to #158066.
Added:
Modified:
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/test/IR/dynamic.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 955faaad9408b..1021bd870d809 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,69 @@ class DynamicType
void print(AsmPrinter &printer);
};
+/// Base class of traits for dynamically-defined operations.
+///
+/// A `DynamicOpTrait` is the runtime counterpart of a static `OpTrait`: it is
+/// attached to an op through `DynamicOpDefinition::addTrait` and reports the
+/// TypeID of the static trait it stands for. Subclasses override only the
+/// verification hooks they need; the defaults accept every op.
+class DynamicOpTrait {
+public:
+  virtual ~DynamicOpTrait() = default;
+
+  /// Verifies the invariants this trait imposes on `op`. Invoked from the
+  /// owning op's `verifyInvariants`. The default implementation succeeds.
+  virtual LogicalResult verifyTrait(Operation *op) const { return success(); }
+
+  /// Verifies the region invariants this trait imposes on `op`. Invoked from
+  /// the owning op's `verifyRegionInvariants`. The default implementation
+  /// succeeds.
+  virtual LogicalResult verifyRegionTrait(Operation *op) const {
+    return success();
+  }
+
+  /// Returns the TypeID of the trait.
+  /// It must be equal to the TypeID of the corresponding static trait, since
+  /// that is the value `hasTrait(TypeID)` is queried with.
+  virtual TypeID getTypeID() const = 0;
+};
+
+/// This class holds the traits attached to a dynamically-defined operation.
+/// At most one trait per TypeID is kept.
+class DynamicOpTraitList {
+public:
+  /// Takes ownership of `trait` and attaches it. Returns false, and discards
+  /// `trait`, if a trait with the same TypeID is already attached.
+  bool insert(std::unique_ptr<DynamicOpTrait> trait) {
+    assert(trait && "expected a non-null dynamic op trait");
+    if (!traitsByID.try_emplace(trait->getTypeID(), trait.get()).second)
+      return false;
+    traits.push_back(std::move(trait));
+    return true;
+  }
+
+  /// Returns true if a trait with TypeID `id` is attached.
+  bool contains(TypeID id) const { return traitsByID.contains(id); }
+
+  /// Runs `verifyTrait` of every attached trait, in insertion order, and
+  /// stops at the first failure.
+  LogicalResult verifyTraits(Operation *op) const {
+    for (const std::unique_ptr<DynamicOpTrait> &trait : traits)
+      if (failed(trait->verifyTrait(op)))
+        return failure();
+    return success();
+  }
+
+  /// Runs `verifyRegionTrait` of every attached trait, in insertion order,
+  /// and stops at the first failure.
+  LogicalResult verifyRegionTraits(Operation *op) const {
+    for (const std::unique_ptr<DynamicOpTrait> &trait : traits)
+      if (failed(trait->verifyRegionTrait(op)))
+        return failure();
+    return success();
+  }
+
+private:
+  /// Owned traits in insertion order. Verification walks this vector rather
+  /// than a hash map so that the order of checks, and therefore which
+  /// diagnostic is reported first, does not depend on TypeID hash values.
+  SmallVector<std::unique_ptr<DynamicOpTrait>> traits;
+  /// Index from TypeID to the owned trait, for constant-time `contains`.
+  DenseMap<TypeID, DynamicOpTrait *> traitsByID;
+};
+
+/// Convenience base for a dynamic trait mirroring the static trait `Trait`.
+/// It reports `Trait`'s TypeID, so a `hasTrait` query for the static trait
+/// matches an op that carries this dynamic one.
+template <template <typename> class Trait>
+class DynamicOpTraitImpl : public DynamicOpTrait {
+public:
+  TypeID getTypeID() const override {
+    TypeID staticTraitID = TypeID::get<Trait>();
+    return staticTraitID;
+  }
+};
+
+namespace DynamicOpTraits {
+
+/// Dynamic counterpart of `OpTrait::NoTerminator`. It adds no verification;
+/// attaching it only makes `hasTrait` report `OpTrait::NoTerminator`.
+class NoTerminator : public DynamicOpTraitImpl<OpTrait::NoTerminator> {};
+
+/// Dynamic counterpart of `OpTrait::IsTerminator`. Attaching it makes
+/// `hasTrait` report `OpTrait::IsTerminator`, and op verification delegates
+/// to `OpTrait::impl::verifyIsTerminator`.
+class IsTerminator : public DynamicOpTraitImpl<OpTrait::IsTerminator> {
+public:
+  LogicalResult verifyTrait(Operation *op) const override {
+    LogicalResult status = OpTrait::impl::verifyIsTerminator(op);
+    return status;
+  }
+};
+
+} // namespace DynamicOpTraits
+
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
@@ -437,6 +500,11 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+  /// Attaches `trait` to this dynamically-defined op, taking ownership.
+  /// Returns false if a trait with the same TypeID was already attached; in
+  /// that case `trait` is discarded and the existing trait is kept.
+  bool addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+    const bool wasInserted = traits.insert(std::move(trait));
+    return wasInserted;
+  }
+
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
SmallVectorImpl<OpFoldResult> &results) final {
return foldHookFn(op, attrs, results);
@@ -445,7 +513,7 @@ class DynamicOpDefinition : public OperationName::Impl {
MLIRContext *context) final {
getCanonicalizationPatternsFn(set, context);
}
- bool hasTrait(TypeID id) final { return false; }
+ /// A dynamic op has exactly the traits attached to it through `addTrait`.
+ bool hasTrait(TypeID id) final { return traits.contains(id); }
OperationName::ParseAssemblyFn getParseAssemblyFn() final {
return [&](OpAsmParser &parser, OperationState &state) {
return parseFn(parser, state);
@@ -459,9 +527,12 @@ class DynamicOpDefinition : public OperationName::Impl {
StringRef name) final {
printFn(op, printer, name);
}
- LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+  /// Runs the attached traits' verifiers first, then the user-provided
+  /// verifier. The user verifier is skipped if any trait fails.
+  LogicalResult verifyInvariants(Operation *op) final {
+    if (failed(traits.verifyTraits(op)))
+      return failure();
+    return verifyFn(op);
+  }
+  /// Runs the attached traits' region verifiers first, then the
+  /// user-provided region verifier, which is skipped if any trait fails.
LogicalResult verifyRegionInvariants(Operation *op) final {
- return verifyRegionFn(op);
+    if (failed(traits.verifyRegionTraits(op)))
+      return failure();
+    return verifyRegionFn(op);
}
/// Implementation for properties (unsupported right now here).
@@ -494,7 +565,9 @@ class DynamicOpDefinition : public OperationName::Impl {
}
Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
- bool compareProperties(OpaqueProperties, OpaqueProperties) final { return false; }
+ // Properties are unsupported for dynamic ops, so two property storages are
+ // never reported as equal.
+ bool compareProperties(OpaqueProperties, OpaqueProperties) final {
+ return false;
+ }
llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
private:
@@ -518,6 +591,7 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::FoldHookFn foldHookFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
+ /// Traits attached through `addTrait`; queried by `hasTrait` and run by
+ /// `verifyInvariants` / `verifyRegionInvariants`.
+ DynamicOpTraitList traits;
friend ExtensibleDialect;
};
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index cf03414c89a35..3d261e6ca8efd 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -125,6 +125,50 @@ func.func @customOpParserPrinter() {
return
}
+// -----
+
+// An op without the `NoTerminator` trait must end each region block with a
+// terminator, so an empty block is rejected.
+func.func @failedDynamicGenericOpNoTerminator() {
+  // expected-error@+1 {{empty block: expect at least a terminator}}
+  "test.dynamic_generic"() ({
+  ^bb1:
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// A dynamic op carrying the `IsTerminator` trait is accepted as the
+// terminator of a region block.
+func.func @dynamicTerminatorOp() {
+  // CHECK: "test.dynamic_generic"()
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // CHECK: "test.dynamic_terminator"()
+    "test.dynamic_terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// The dynamic `IsTerminator` trait runs the terminator verifier, which rejects
+// a terminator that is not the last operation of its block.
+func.func @failedDynamicTerminatorOp() {
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // expected-error@+1 {{'test.dynamic_terminator' op must be the last operation in the parent block}}
+    "test.dynamic_terminator"() : () -> ()
+    "test.dynamic_generic"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// `test.dynamic_noterminator` carries the dynamic `NoTerminator` trait, so an
+// empty block in its region verifies successfully.
+func.func @dynamicNoTerminatorOp() {
+ // CHECK: "test.dynamic_noterminator"()
+ "test.dynamic_noterminator"() ({
+ ^bb1:
+ }) : () -> ()
+ return
+}
+
//===----------------------------------------------------------------------===//
// Dynamic dialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 43392d78b37d2..7c1db3884be10 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -42,6 +42,7 @@
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
+#include <memory>
#include <numeric>
#include <optional>
@@ -242,6 +243,24 @@ getDynamicGenericOp(TestDialect *dialect) {
[](Operation *op) { return success(); });
}
+/// Builds a dynamic op called `name` whose verifiers always succeed and
+/// attaches `trait` to it. Shared by the trait-carrying test ops below.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicOpWithTrait(TestDialect *dialect, StringRef name,
+                      std::unique_ptr<DynamicOpTrait> trait) {
+  auto def = DynamicOpDefinition::get(
+      name, dialect, [](Operation *op) { return success(); },
+      [](Operation *op) { return success(); });
+  // A freshly created definition has no traits, so insertion must succeed;
+  // assert it instead of silently dropping the return value.
+  [[maybe_unused]] bool inserted = def->addTrait(std::move(trait));
+  assert(inserted && "trait unexpectedly already attached to a new op");
+  return def;
+}
+
+/// Returns the definition of `test.dynamic_terminator`, a dynamic op that
+/// carries the `IsTerminator` trait.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicTerminatorOp(TestDialect *dialect) {
+  return getDynamicOpWithTrait(
+      dialect, "dynamic_terminator",
+      std::make_unique<DynamicOpTraits::IsTerminator>());
+}
+
+/// Returns the definition of `test.dynamic_noterminator`, a dynamic op that
+/// carries the `NoTerminator` trait.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicNoTerminatorOp(TestDialect *dialect) {
+  return getDynamicOpWithTrait(
+      dialect, "dynamic_noterminator",
+      std::make_unique<DynamicOpTraits::NoTerminator>());
+}
+
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
@@ -329,6 +348,8 @@ void TestDialect::initialize() {
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
+ registerDynamicOp(getDynamicTerminatorOp(this));
+ registerDynamicOp(getDynamicNoTerminatorOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
More information about the Mlir-commits
mailing list