[Mlir-commits] [mlir] [MLIR] Support dynamic traits in `DynamicDialect` (PR #177735)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 21:47:20 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/177735
>From 7684a00ac1fc7466583e5f26b5527fc10cec6ee4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 12:49:47 +0800
Subject: [PATCH 1/2] [MLIR] Support dynamic traits in DynamicDialect
---
mlir/include/mlir/IR/ExtensibleDialect.h | 76 ++++++++++++++++++++--
mlir/test/IR/dynamic.mlir | 44 +++++++++++++
mlir/test/lib/Dialect/Test/TestDialect.cpp | 21 ++++++
3 files changed, 137 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 955faaad9408b..37eea4b7fdd1b 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,64 @@ class DynamicType
void print(AsmPrinter &printer);
};
+class DynamicOpTrait {
+public:
+ virtual LogicalResult verifyTrait(Operation *op) const { return success(); };
+ virtual LogicalResult verifyRegionTrait(Operation *op) const {
+ return success();
+ };
+
+ virtual TypeID getTypeID() const = 0;
+ virtual ~DynamicOpTrait() = default;
+};
+
+class DynamicOpTraitList {
+public:
+ void insert(std::unique_ptr<DynamicOpTrait> trait) {
+ traits.try_emplace(trait->getTypeID(), std::move(trait));
+ }
+
+ bool contains(TypeID id) const { return traits.contains(id); }
+
+ LogicalResult verifyTraits(Operation *op) const {
+ for (const auto &[_, trait] : traits) {
+ if (failed(trait->verifyTrait(op)))
+ return failure();
+ }
+ return success();
+ }
+
+ LogicalResult verifyRegionTraits(Operation *op) const {
+ for (const auto &[_, trait] : traits) {
+ if (failed(trait->verifyRegionTrait(op)))
+ return failure();
+ }
+ return success();
+ }
+
+private:
+ DenseMap<TypeID, std::unique_ptr<DynamicOpTrait>> traits;
+};
+
+template <template <typename T> class Trait>
+class DynamicOpTraitImpl : public DynamicOpTrait {
+public:
+ TypeID getTypeID() const override { return TypeID::get<Trait>(); }
+};
+
+namespace DynamicOpTraits {
+
+class IsTerminator : public DynamicOpTraitImpl<OpTrait::IsTerminator> {
+public:
+ LogicalResult verifyTrait(Operation *op) const override {
+ return OpTrait::impl::verifyIsTerminator(op);
+ }
+};
+
+class NoTerminator : public DynamicOpTraitImpl<OpTrait::NoTerminator> {};
+
+} // namespace DynamicOpTraits
+
//===----------------------------------------------------------------------===//
// Dynamic operation
//===----------------------------------------------------------------------===//
@@ -437,6 +495,10 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
+ void addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+ traits.insert(std::move(trait));
+ }
+
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
SmallVectorImpl<OpFoldResult> &results) final {
return foldHookFn(op, attrs, results);
@@ -445,7 +507,7 @@ class DynamicOpDefinition : public OperationName::Impl {
MLIRContext *context) final {
getCanonicalizationPatternsFn(set, context);
}
- bool hasTrait(TypeID id) final { return false; }
+ bool hasTrait(TypeID id) final { return traits.contains(id); }
OperationName::ParseAssemblyFn getParseAssemblyFn() final {
return [&](OpAsmParser &parser, OperationState &state) {
return parseFn(parser, state);
@@ -459,9 +521,12 @@ class DynamicOpDefinition : public OperationName::Impl {
StringRef name) final {
printFn(op, printer, name);
}
- LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+ LogicalResult verifyInvariants(Operation *op) final {
+ return failure(failed(traits.verifyTraits(op)) || failed(verifyFn(op)));
+ }
LogicalResult verifyRegionInvariants(Operation *op) final {
- return verifyRegionFn(op);
+ return failure(failed(traits.verifyRegionTraits(op)) ||
+ failed(verifyRegionFn(op)));
}
/// Implementation for properties (unsupported right now here).
@@ -494,7 +559,9 @@ class DynamicOpDefinition : public OperationName::Impl {
}
Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
- bool compareProperties(OpaqueProperties, OpaqueProperties) final { return false; }
+ bool compareProperties(OpaqueProperties, OpaqueProperties) final {
+ return false;
+ }
llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
private:
@@ -518,6 +585,7 @@ class DynamicOpDefinition : public OperationName::Impl {
OperationName::FoldHookFn foldHookFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
+ DynamicOpTraitList traits;
friend ExtensibleDialect;
};
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index cf03414c89a35..3d261e6ca8efd 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -125,6 +125,50 @@ func.func @customOpParserPrinter() {
return
}
+// -----
+
+func.func @failedDynamicGenericOpNoTerminator() {
+ // expected-error @+1 {{empty block: expect at least a terminator}}
+ "test.dynamic_generic"() ({
+ ^bb1:
+ }) : () -> ()
+ return
+}
+
+// -----
+
+func.func @dynamicTerminatorOp() {
+ // CHECK: "test.dynamic_generic"()
+ "test.dynamic_generic"() ({
+ ^bb1:
+ // CHECK: "test.dynamic_terminator"()
+ "test.dynamic_terminator"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+func.func @failedDynamicTerminatorOp() {
+ "test.dynamic_generic"() ({
+ ^bb1:
+ // expected-error @+1 {{'test.dynamic_terminator' op must be the last operation in the parent block}}
+ "test.dynamic_terminator"() : () -> ()
+ "test.dynamic_generic"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+func.func @dynamicNoTerminatorOp() {
+ // CHECK: "test.dynamic_noterminator"()
+ "test.dynamic_noterminator"() ({
+ ^bb1:
+ }) : () -> ()
+ return
+}
+
//===----------------------------------------------------------------------===//
// Dynamic dialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 43392d78b37d2..7c1db3884be10 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -42,6 +42,7 @@
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
+#include <memory>
#include <numeric>
#include <optional>
@@ -242,6 +243,24 @@ getDynamicGenericOp(TestDialect *dialect) {
[](Operation *op) { return success(); });
}
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicTerminatorOp(TestDialect *dialect) {
+ auto def = DynamicOpDefinition::get(
+ "dynamic_terminator", dialect, [](Operation *op) { return success(); },
+ [](Operation *op) { return success(); });
+ def->addTrait(std::make_unique<DynamicOpTraits::IsTerminator>());
+ return def;
+}
+
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicNoTerminatorOp(TestDialect *dialect) {
+ auto def = DynamicOpDefinition::get(
+ "dynamic_noterminator", dialect, [](Operation *op) { return success(); },
+ [](Operation *op) { return success(); });
+ def->addTrait(std::make_unique<DynamicOpTraits::NoTerminator>());
+ return def;
+}
+
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
return DynamicOpDefinition::get(
@@ -329,6 +348,8 @@ void TestDialect::initialize() {
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
+ registerDynamicOp(getDynamicTerminatorOp(this));
+ registerDynamicOp(getDynamicNoTerminatorOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
>From 09f137587962fe1f1c801b48500e2d3bbb1546ee Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 13:46:59 +0800
Subject: [PATCH 2/2] add comments
---
mlir/include/mlir/IR/ExtensibleDialect.h | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 37eea4b7fdd1b..1021bd870d809 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,7 @@ class DynamicType
void print(AsmPrinter &printer);
};
+/// Base class of traits for dynamically defined operations.
class DynamicOpTrait {
public:
virtual LogicalResult verifyTrait(Operation *op) const { return success(); };
@@ -361,14 +362,18 @@ class DynamicOpTrait {
return success();
};
+ /// Returns the TypeID of the trait.
+ /// It must equal the TypeID of the corresponding static trait,
+ /// because that is the ID looked up by `hasTrait(TypeID)`.
virtual TypeID getTypeID() const = 0;
virtual ~DynamicOpTrait() = default;
};
+/// This class holds a list of traits for dynamically defined operations.
class DynamicOpTraitList {
public:
- void insert(std::unique_ptr<DynamicOpTrait> trait) {
- traits.try_emplace(trait->getTypeID(), std::move(trait));
+ bool insert(std::unique_ptr<DynamicOpTrait> trait) {
+ return traits.try_emplace(trait->getTypeID(), std::move(trait)).second;
}
bool contains(TypeID id) const { return traits.contains(id); }
@@ -495,8 +500,9 @@ class DynamicOpDefinition : public OperationName::Impl {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
- void addTrait(std::unique_ptr<DynamicOpTrait> trait) {
- traits.insert(std::move(trait));
+ /// Attach a trait to this dynamically defined op.
+ bool addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+ return traits.insert(std::move(trait));
}
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
More information about the Mlir-commits
mailing list