[Mlir-commits] [mlir] [MLIR] Support dynamic traits in `DynamicDialect` (PR #177735)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 21:04:24 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
Unlike Interfaces, Traits in MLIR are static: they are defined via CRTP templates and used as base classes of an Op, which makes them difficult to attach to an op dynamically.
However, in IRDL and the Python bindings, we define operations dynamically through `DynamicDialect`, which means the traditional static traits cannot be applied to them. Traits are important: for example, they are how MLIR marks an op as a terminator (`IsTerminator`) or declares that an op's regions do not need a terminator (`NoTerminator`).
If `DynamicDialect` does not support traits, then even though we can define an op with regions, we can neither define new terminators nor mark an op whose regions need no terminator. This makes `DynamicDialect` very limited in region-related scenarios.
In this PR, we introduce a `DynamicOpTrait` type that “dynamizes” `OpTrait`, enabling traits to be attached to ops in `DynamicDialect`. The key design goal is that existing checks in the MLIR codebase such as `op->hasTrait<XXX>()` work seamlessly on ops defined by `DynamicOpDefinition`, without requiring any changes.
---
Full diff: https://github.com/llvm/llvm-project/pull/177735.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/ExtensibleDialect.h (+72-4)
- (modified) mlir/test/IR/dynamic.mlir (+44)
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+21)
``````````diff
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 955faaad9408b..37eea4b7fdd1b 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,64 @@ class DynamicType
void print(AsmPrinter &printer);
};
+/// Base class for traits that can be attached at runtime to operations
+/// defined through `DynamicOpDefinition`. A dynamic trait is identified by the
+/// TypeID returned from `getTypeID`, and may take part in operation
+/// verification by overriding `verifyTrait` and/or `verifyRegionTrait`.
+class DynamicOpTrait {
+public:
+  virtual ~DynamicOpTrait() = default;
+
+  /// Verifies the invariants this trait imposes on `op`. The default
+  /// implementation accepts every operation.
+  virtual LogicalResult verifyTrait(Operation *op) const { return success(); }
+
+  /// Verifies the region invariants this trait imposes on `op`. The default
+  /// implementation accepts every operation.
+  virtual LogicalResult verifyRegionTrait(Operation *op) const {
+    return success();
+  }
+
+  /// Returns the TypeID identifying this trait. `hasTrait` queries on a
+  /// dynamic operation are answered by comparing against this value.
+  virtual TypeID getTypeID() const = 0;
+};
+
+/// A set of dynamic op traits, unique by TypeID. Traits are stored in
+/// insertion order so that verification visits them deterministically;
+/// iterating a DenseMap keyed by TypeID would instead depend on the TypeID
+/// pointer values and could change which diagnostic is reported first.
+class DynamicOpTraitList {
+public:
+  /// Takes ownership of `trait` and adds it to the list. If a trait with the
+  /// same TypeID is already present, `trait` is discarded and false is
+  /// returned; otherwise returns true.
+  bool insert(std::unique_ptr<DynamicOpTrait> trait) {
+    assert(trait && "expected a non-null dynamic trait");
+    DynamicOpTrait *rawTrait = trait.get();
+    if (!traitsByID.try_emplace(rawTrait->getTypeID(), rawTrait).second)
+      return false;
+    traits.push_back(std::move(trait));
+    return true;
+  }
+
+  /// Returns true if a trait with TypeID `id` has been inserted.
+  bool contains(TypeID id) const { return traitsByID.contains(id); }
+
+  /// Runs `verifyTrait` on `op` for each trait, in insertion order, stopping
+  /// at the first failure.
+  LogicalResult verifyTraits(Operation *op) const {
+    for (const std::unique_ptr<DynamicOpTrait> &trait : traits) {
+      if (failed(trait->verifyTrait(op)))
+        return failure();
+    }
+    return success();
+  }
+
+  /// Runs `verifyRegionTrait` on `op` for each trait, in insertion order,
+  /// stopping at the first failure.
+  LogicalResult verifyRegionTraits(Operation *op) const {
+    for (const std::unique_ptr<DynamicOpTrait> &trait : traits) {
+      if (failed(trait->verifyRegionTrait(op)))
+        return failure();
+    }
+    return success();
+  }
+
+private:
+  /// Owned traits, in insertion order.
+  SmallVector<std::unique_ptr<DynamicOpTrait>> traits;
+  /// Non-owning index from trait TypeID to the trait stored in `traits`.
+  DenseMap<TypeID, DynamicOpTrait *> traitsByID;
+};
+
+/// Convenience base for a dynamic trait that mirrors the static trait
+/// template `Trait`: its TypeID is the one MLIR assigns to `Trait` itself,
+/// intended to match what `op->hasTrait<Trait>()` queries.
+template <template <typename> class Trait>
+class DynamicOpTraitImpl : public DynamicOpTrait {
+public:
+  TypeID getTypeID() const override {
+    return TypeID::get<Trait>();
+  }
+};
+
+namespace DynamicOpTraits {
+
+/// Dynamic counterpart of `OpTrait::NoTerminator`. It adds no verification of
+/// its own; attaching it only makes the op report `OpTrait::NoTerminator`
+/// through `hasTrait`.
+class NoTerminator : public DynamicOpTraitImpl<OpTrait::NoTerminator> {
+};
+
+/// Dynamic counterpart of `OpTrait::IsTerminator`. Besides reporting the
+/// trait, it checks the op with `OpTrait::impl::verifyIsTerminator`.
+class IsTerminator : public DynamicOpTraitImpl<OpTrait::IsTerminator> {
+public:
+  LogicalResult verifyTrait(Operation *op) const override {
+    return OpTrait::impl::verifyIsTerminator(op);
+  }
+};
+
+} // namespace DynamicOpTraits
+
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
@@ -437,6 +495,10 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+  /// Attaches `newTrait` to this operation definition, taking ownership of
+  /// it. If a trait with the same TypeID is already attached, the new one is
+  /// not stored. Attached traits are reported by `hasTrait` and checked by
+  /// `verifyInvariants` / `verifyRegionInvariants`.
+  void addTrait(std::unique_ptr<DynamicOpTrait> newTrait) {
+    traits.insert(std::move(newTrait));
+  }
+
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
SmallVectorImpl<OpFoldResult> &results) final {
return foldHookFn(op, attrs, results);
@@ -445,7 +507,7 @@ class DynamicOpDefinition : public OperationName::Impl {
MLIRContext *context) final {
getCanonicalizationPatternsFn(set, context);
}
- bool hasTrait(TypeID id) final { return false; }
+  /// Returns true iff a trait whose TypeID is `id` was attached with
+  /// `addTrait`.
+  bool hasTrait(TypeID id) final {
+    return traits.contains(id);
+  }
OperationName::ParseAssemblyFn getParseAssemblyFn() final {
return [&](OpAsmParser &parser, OperationState &state) {
return parseFn(parser, state);
@@ -459,9 +521,12 @@ class DynamicOpDefinition : public OperationName::Impl {
StringRef name) final {
printFn(op, printer, name);
}
- LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+  /// Checks every attached trait first; the user-provided verifier only runs
+  /// when all trait verifiers succeed.
+  LogicalResult verifyInvariants(Operation *op) final {
+    if (failed(traits.verifyTraits(op)))
+      return failure();
+    return verifyFn(op);
+  }
LogicalResult verifyRegionInvariants(Operation *op) final {
- return verifyRegionFn(op);
+    // Trait region checks come first; a failure there skips the
+    // user-provided region verifier.
+    if (failed(traits.verifyRegionTraits(op)))
+      return failure();
+    return verifyRegionFn(op);
}
/// Implementation for properties (unsupported right now here).
@@ -494,7 +559,9 @@ class DynamicOpDefinition : public OperationName::Impl {
}
Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
- bool compareProperties(OpaqueProperties, OpaqueProperties) final { return false; }
+  /// Dynamic operations do not support properties; this always returns false.
+  /// NOTE(review): returning false means two dynamic ops never compare as
+  /// having equal properties; confirm this is intended for equivalence-based
+  /// transforms (e.g. CSE).
+  bool compareProperties(OpaqueProperties /*lhs*/,
+                         OpaqueProperties /*rhs*/) final {
+    return false;
+  }
llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
private:
@@ -518,6 +585,7 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::FoldHookFn foldHookFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
+ /// Traits attached with `addTrait`; consulted by `hasTrait` and by the
+ /// invariant verifiers.
+ DynamicOpTraitList traits;
friend ExtensibleDialect;
};
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index cf03414c89a35..3d261e6ca8efd 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -125,6 +125,50 @@ func.func @customOpParserPrinter() {
return
}
+// -----
+
+// Without the NoTerminator trait, each block in the op's region must end with
+// a terminator, so an empty block is rejected.
+func.func @failedDynamicGenericOpNoTerminator() {
+  // expected-error@+1 {{empty block: expect at least a terminator}}
+  "test.dynamic_generic"() ({
+  ^bb1:
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// A dynamic op carrying the IsTerminator trait is accepted as the terminator
+// of a region block.
+func.func @dynamicTerminatorOp() {
+  // CHECK: "test.dynamic_generic"()
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // CHECK: "test.dynamic_terminator"()
+    "test.dynamic_terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// The IsTerminator trait's verifier rejects a terminator that is not the last
+// operation in its block.
+func.func @failedDynamicTerminatorOp() {
+  "test.dynamic_generic"() ({
+  ^bb1:
+    // expected-error@+1 {{'test.dynamic_terminator' op must be the last operation in the parent block}}
+    "test.dynamic_terminator"() : () -> ()
+    "test.dynamic_generic"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// With the NoTerminator trait attached, an empty block in the op's region
+// verifies (contrast with @failedDynamicGenericOpNoTerminator).
+func.func @dynamicNoTerminatorOp() {
+ // CHECK: "test.dynamic_noterminator"()
+ "test.dynamic_noterminator"() ({
+ ^bb1:
+ }) : () -> ()
+ return
+}
+
//===----------------------------------------------------------------------===//
// Dynamic dialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 43392d78b37d2..7c1db3884be10 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -42,6 +42,7 @@
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
+#include <memory>
#include <numeric>
#include <optional>
@@ -242,6 +243,24 @@ getDynamicGenericOp(TestDialect *dialect) {
[](Operation *op) { return success(); });
}
+/// Builds a dynamic op called `name` whose verifier and region verifier
+/// always succeed and that carries the single dynamic trait `trait`.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicOpWithTrait(TestDialect *dialect, StringRef name,
+                      std::unique_ptr<DynamicOpTrait> trait) {
+  auto def = DynamicOpDefinition::get(
+      name, dialect, [](Operation *op) { return success(); },
+      [](Operation *op) { return success(); });
+  def->addTrait(std::move(trait));
+  return def;
+}
+
+/// `dynamic_terminator`: a dynamic op carrying the IsTerminator trait.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicTerminatorOp(TestDialect *dialect) {
+  return getDynamicOpWithTrait(
+      dialect, "dynamic_terminator",
+      std::make_unique<DynamicOpTraits::IsTerminator>());
+}
+
+/// `dynamic_noterminator`: a dynamic op carrying the NoTerminator trait, so
+/// the blocks in its regions may be left without a terminator.
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicNoTerminatorOp(TestDialect *dialect) {
+  return getDynamicOpWithTrait(
+      dialect, "dynamic_noterminator",
+      std::make_unique<DynamicOpTraits::NoTerminator>());
+}
+
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
@@ -329,6 +348,8 @@ void TestDialect::initialize() {
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
+ registerDynamicOp(getDynamicTerminatorOp(this));
+ registerDynamicOp(getDynamicNoTerminatorOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
``````````
</details>
https://github.com/llvm/llvm-project/pull/177735
More information about the Mlir-commits
mailing list