[Mlir-commits] [mlir] [MLIR] Support dynamic traits in `DynamicDialect` (PR #177735)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 21:47:20 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/177735

>From 7684a00ac1fc7466583e5f26b5527fc10cec6ee4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 12:49:47 +0800
Subject: [PATCH 1/2] [MLIR] Support dynamic traits in DynamicDialect

---
 mlir/include/mlir/IR/ExtensibleDialect.h   | 76 ++++++++++++++++++++--
 mlir/test/IR/dynamic.mlir                  | 44 +++++++++++++
 mlir/test/lib/Dialect/Test/TestDialect.cpp | 21 ++++++
 3 files changed, 137 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 955faaad9408b..37eea4b7fdd1b 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,64 @@ class DynamicType
   void print(AsmPrinter &printer);
 };
 
+class DynamicOpTrait {
+public:
+  virtual LogicalResult verifyTrait(Operation *op) const { return success(); };
+  virtual LogicalResult verifyRegionTrait(Operation *op) const {
+    return success();
+  };
+
+  virtual TypeID getTypeID() const = 0;
+  virtual ~DynamicOpTrait() = default;
+};
+
+class DynamicOpTraitList {
+public:
+  void insert(std::unique_ptr<DynamicOpTrait> trait) {
+    traits.try_emplace(trait->getTypeID(), std::move(trait));
+  }
+
+  bool contains(TypeID id) const { return traits.contains(id); }
+
+  /// Runs each trait's `verifyTrait`, in insertion order, until one fails.
+  LogicalResult verifyTraits(Operation *op) const {
+    for (const auto &[_, trait] : traits)
+      if (failed(trait->verifyTrait(op)))
+        return failure();
+    return success();
+  }
+  /// Like `verifyTraits`, but runs each trait's `verifyRegionTrait` hook.
+  LogicalResult verifyRegionTraits(Operation *op) const {
+    for (const auto &[_, trait] : traits)
+      if (failed(trait->verifyRegionTrait(op)))
+        return failure();
+    return success();
+  }
+
+private:
+  /// Insertion-ordered (unlike DenseMap) so verification is deterministic.
+  llvm::MapVector<TypeID, std::unique_ptr<DynamicOpTrait>> traits;
+};
+
+template <template <typename T> class Trait>
+class DynamicOpTraitImpl : public DynamicOpTrait {
+public:
+  TypeID getTypeID() const final { return TypeID::get<Trait>(); }
+};
+
+namespace DynamicOpTraits {
+/// Dynamic counterpart of `OpTrait::IsTerminator`.
+class IsTerminator : public DynamicOpTraitImpl<OpTrait::IsTerminator> {
+public:
+  LogicalResult verifyTrait(Operation *op) const override {
+    return OpTrait::impl::verifyIsTerminator(op);
+  }
+};
+
+/// Dynamic counterpart of `OpTrait::NoTerminator`; it has nothing to verify.
+class NoTerminator : public DynamicOpTraitImpl<OpTrait::NoTerminator> {};
+} // namespace DynamicOpTraits
+
 //===----------------------------------------------------------------------===//
 // Dynamic operation
 //===----------------------------------------------------------------------===//
@@ -437,6 +495,10 @@ class DynamicOpDefinition : public OperationName::Impl {
     populateDefaultAttrsFn = std::move(populateDefaultAttrs);
   }
 
+  void addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+    traits.insert(std::move(trait));
+  }
+
   LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
                          SmallVectorImpl<OpFoldResult> &results) final {
     return foldHookFn(op, attrs, results);
@@ -445,7 +507,7 @@ class DynamicOpDefinition : public OperationName::Impl {
                                    MLIRContext *context) final {
     getCanonicalizationPatternsFn(set, context);
   }
-  bool hasTrait(TypeID id) final { return false; }
+  bool hasTrait(TypeID id) final { return traits.contains(id); }
   OperationName::ParseAssemblyFn getParseAssemblyFn() final {
     return [&](OpAsmParser &parser, OperationState &state) {
       return parseFn(parser, state);
@@ -459,9 +521,12 @@ class DynamicOpDefinition : public OperationName::Impl {
                      StringRef name) final {
     printFn(op, printer, name);
   }
-  LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
+  LogicalResult verifyInvariants(Operation *op) final {
+    return failure(failed(traits.verifyTraits(op)) || failed(verifyFn(op)));
+  }
   LogicalResult verifyRegionInvariants(Operation *op) final {
-    return verifyRegionFn(op);
+    return failure(failed(traits.verifyRegionTraits(op)) ||
+                   failed(verifyRegionFn(op)));
   }
 
   /// Implementation for properties (unsupported right now here).
@@ -494,7 +559,9 @@ class DynamicOpDefinition : public OperationName::Impl {
   }
   Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
   void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
-  bool compareProperties(OpaqueProperties, OpaqueProperties) final { return false; }
+  bool compareProperties(OpaqueProperties, OpaqueProperties) final {
+    return false;
+  }
   llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
 
 private:
@@ -518,6 +585,7 @@ class DynamicOpDefinition : public OperationName::Impl {
   OperationName::FoldHookFn foldHookFn;
   GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
   OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
+  DynamicOpTraitList traits;
 
   friend ExtensibleDialect;
 };
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index cf03414c89a35..3d261e6ca8efd 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -125,6 +125,50 @@ func.func @customOpParserPrinter() {
   return
 }
 
+// -----
+
+func.func @failedDynamicGenericOpNoTerminator() {
+  // expected-error@+1 {{empty block: expect at least a terminator}}
+  "test.dynamic_generic"() ({
+    ^bb1:
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @dynamicTerminatorOp() {
+  // CHECK: "test.dynamic_generic"()
+  "test.dynamic_generic"() ({
+    ^bb1:
+      // CHECK: "test.dynamic_terminator"()
+      "test.dynamic_terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedDynamicTerminatorOp() {
+  "test.dynamic_generic"() ({
+    ^bb1:
+      // expected-error@+1 {{'test.dynamic_terminator' op must be the last operation in the parent block}}
+      "test.dynamic_terminator"() : () -> ()
+      "test.dynamic_generic"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @dynamicNoTerminatorOp() {
+  // CHECK: "test.dynamic_noterminator"()
+  "test.dynamic_noterminator"() ({
+    ^bb1:
+  }) : () -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Dynamic dialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 43392d78b37d2..7c1db3884be10 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -42,6 +42,7 @@
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include <cstdint>
+#include <memory>
 #include <numeric>
 #include <optional>
 
@@ -242,6 +243,24 @@ getDynamicGenericOp(TestDialect *dialect) {
       [](Operation *op) { return success(); });
 }
 
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicTerminatorOp(TestDialect *dialect) {
+  // Verifiers are no-ops: the IsTerminator trait does the checking.
+  auto ok = [](Operation *) { return success(); };
+  auto def = DynamicOpDefinition::get("dynamic_terminator", dialect, ok, ok);
+  def->addTrait(std::make_unique<DynamicOpTraits::IsTerminator>());
+  return def;
+}
+
+static std::unique_ptr<DynamicOpDefinition>
+getDynamicNoTerminatorOp(TestDialect *dialect) {
+  // Verifiers are no-ops: the behaviour under test comes from NoTerminator.
+  auto ok = [](Operation *) { return success(); };
+  auto def = DynamicOpDefinition::get("dynamic_noterminator", dialect, ok, ok);
+  def->addTrait(std::make_unique<DynamicOpTraits::NoTerminator>());
+  return def;
+}
+
 static std::unique_ptr<DynamicOpDefinition>
 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
   return DynamicOpDefinition::get(
@@ -329,6 +348,8 @@ void TestDialect::initialize() {
   addOperations<ManualCppOpWithFold>();
   registerTestDialectOperations(this);
   registerDynamicOp(getDynamicGenericOp(this));
+  registerDynamicOp(getDynamicTerminatorOp(this));
+  registerDynamicOp(getDynamicNoTerminatorOp(this));
   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
   registerInterfaces();

>From 09f137587962fe1f1c801b48500e2d3bbb1546ee Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 24 Jan 2026 13:46:59 +0800
Subject: [PATCH 2/2] add comments

---
 mlir/include/mlir/IR/ExtensibleDialect.h | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 37eea4b7fdd1b..1021bd870d809 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -354,6 +354,7 @@ class DynamicType
   void print(AsmPrinter &printer);
 };
 
+/// Base class for traits attached to dynamically defined operations.
 class DynamicOpTrait {
 public:
   virtual LogicalResult verifyTrait(Operation *op) const { return success(); };
@@ -361,14 +362,18 @@ class DynamicOpTrait {
     return success();
   };
 
+  /// Returns the TypeID of the trait. It must equal the TypeID of the
+  /// corresponding static trait, because `hasTrait(TypeID)` is queried with
+  /// the static trait's ID (e.g. `hasTrait<OpTrait::IsTerminator>()`).
   virtual TypeID getTypeID() const = 0;
   virtual ~DynamicOpTrait() = default;
 };
 
+/// Holds a dynamically defined operation's traits, at most one per TypeID.
 class DynamicOpTraitList {
 public:
-  void insert(std::unique_ptr<DynamicOpTrait> trait) {
-    traits.try_emplace(trait->getTypeID(), std::move(trait));
+  bool insert(std::unique_ptr<DynamicOpTrait> trait) {
+    return traits.try_emplace(trait->getTypeID(), std::move(trait)).second;
   }
 
   bool contains(TypeID id) const { return traits.contains(id); }
@@ -495,8 +500,9 @@ class DynamicOpDefinition : public OperationName::Impl {
     populateDefaultAttrsFn = std::move(populateDefaultAttrs);
   }
 
-  void addTrait(std::unique_ptr<DynamicOpTrait> trait) {
-    traits.insert(std::move(trait));
+  /// Attaches `trait`. Returns false (and drops it) on a duplicate TypeID.
+  bool addTrait(std::unique_ptr<DynamicOpTrait> trait) {
+    return traits.insert(std::move(trait));
   }
 
   LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,



More information about the Mlir-commits mailing list