[Mlir-commits] [mlir] [MLIR][IRDL][Python] Fix error while composing `irdl.any_of` and `irdl.base` (PR #187914)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 21 23:50:54 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-irdl

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

Previously, when users composed `irdl.any_of` with `irdl.base`, e.g.
```mlir
module {
  irdl.dialect @ext_attr_in_op {
    irdl.operation @op_with_attr {
      %0 = irdl.base "#builtin.integer" 
      %1 = irdl.base "#builtin.string" 
      %2 = irdl.any_of(%0, %1) 
      irdl.attributes {"a" = %2}
    }
  }
}
```

the program would crash on `llvm_unreachable("unknown IRDL constraint")`.

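This PR teaches `getBases(..)` to handle `irdl::BaseOp`, covering both the base-name form (`"!..."` for types, `"#..."` for attributes) and the symbol-reference form, so this composition now works. It also adjusts the Python `ext` dialect lowering accordingly and adds a Python test that exercises `irdl.any_of` over `irdl.base` constraints.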

---
Full diff: https://github.com/llvm/llvm-project/pull/187914.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/IRDL/IRDLLoading.cpp (+35) 
- (modified) mlir/python/mlir/dialects/ext.py (+8-19) 
- (modified) mlir/test/python/dialects/ext.py (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 8d10aacb53ec9..54c7d17a97b50 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -533,6 +533,41 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
     return false;
   }
 
+  if (auto base = dyn_cast<BaseOp>(op)) {
+    if (base.getBaseName()) {
+      StringRef baseName = *base.getBaseName();
+      if (baseName[0] == '!') {
+        auto abstractType =
+            AbstractType::lookup(baseName.drop_front(1), op->getContext());
+        assert(abstractType && "type name should refer to an existing type");
+        paramIds.insert(abstractType->get().getTypeID());
+      } else if (baseName[0] == '#') {
+        auto abstractAttr =
+            AbstractAttribute::lookup(baseName.drop_front(1), op->getContext());
+        assert(abstractAttr && "attribute name should refer to an existing "
+                               "attribute");
+        paramIds.insert(abstractAttr->get().getTypeID());
+      } else {
+        llvm_unreachable(
+            "invalid `irdl.base` operation: base name should start "
+            "with '!' for types or '#' for attributes");
+      }
+      return false;
+    }
+
+    if (base.getBaseRef()) {
+      SymbolRefAttr symRef = *base.getBaseRef();
+      Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
+      assert(defOp && "symbol reference should refer to an existing operation");
+      paramIrdlOps.insert(defOp);
+      return false;
+    }
+
+    llvm_unreachable(
+        "invalid `irdl.base` operation: expected either a base name "
+        "or a base symbol reference");
+  }
+
   // For `irdl.any`, we return `false` since we can match any type or attribute
   // base.
   if (auto isA = dyn_cast<AnyOp>(op))
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 1900b8c162456..dfcd7f2d641d0 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -85,29 +85,22 @@ def _lower(self, type_) -> ir.Value:
             return irdl.any()
         elif isinstance(type_, TypeVar):
             return self.lower(type_)
+        elif origin and issubclass(origin, Type | Attribute):
+            return irdl.parametric(
+                base_type=[origin._dialect_name, origin._name],
+                args=[self.lower(arg) for arg in get_args(type_)],
+            )
         elif origin and issubclass(origin, ir.Type):
-            if issubclass(origin, Type):
-                return irdl.parametric(
-                    base_type=[origin._dialect_name, origin._name],
-                    args=[self.lower(arg) for arg in get_args(type_)],
-                )
             t = construct_instance(origin, get_args(type_))
             return irdl.is_(ir.TypeAttr.get(t))
         elif origin and issubclass(origin, ir.Attribute):
-            if issubclass(origin, Attribute):
-                return irdl.parametric(
-                    base_type=[origin._dialect_name, origin._name],
-                    args=[self.lower(arg) for arg in get_args(type_)],
-                )
             attr = construct_instance(origin, get_args(type_))
             return irdl.is_(attr)
+        elif issubclass(type_, Type | Attribute):
+            return irdl.base(base_ref=[type_._dialect_name, type_._name])
         elif issubclass(type_, ir.Type):
-            if issubclass(type_, Type):
-                return irdl.base(base_ref=[type_._dialect_name, type_._name])
             return irdl.base(base_name=f"!{type_.type_name}")
         elif issubclass(type_, ir.Attribute):
-            if issubclass(type_, Attribute):
-                return irdl.base(base_ref=[type_._dialect_name, type_._name])
             return irdl.base(base_name=f"#{type_.attr_name}")
 
         raise TypeError(f"unsupported type in constraints: {type_}")
@@ -197,13 +190,9 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
                 get_args(type_)[0],
                 kw_only=specifier.kw_only(),
             )
-        elif issubclass(origin or type_, ir.Attribute):
-            return AttributeDef(name, variadicity, type_)
         elif type_ is ir.Region:
             return RegionDef(name, variadicity, Any)
-        raise TypeError(
-            f"unsupported type for field '{name}' in operation definition: {type_}"
-        )
+        return AttributeDef(name, variadicity, type_)
 
 
 @dataclass
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index a1593c35855ea..78c74684cef77 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -736,3 +736,44 @@ class AssignNoneOnNonOptionalOp(
     except ValueError as e:
         # CHECK: only optional operand can be a keyword parameter
         print(e)
+
+
+# CHECK: TEST: testExtDialectWithAttrInOp
+@run
+def testExtDialectWithAttrInOp():
+    class TestAttrInOp(Dialect, name="ext_attr_in_op"):
+        pass
+
+    class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"):
+        a: IntegerAttr | StringAttr
+        b: IntegerType[32] | IntegerType[64]
+
+    with Context(), Location.unknown():
+        TestAttrInOp.load()
+        # CHECK: irdl.dialect @ext_attr_in_op {
+        # CHECK:   irdl.operation @op_with_attr {
+        # CHECK:     %0 = irdl.base "#builtin.integer"
+        # CHECK:     %1 = irdl.base "#builtin.string"
+        # CHECK:     %2 = irdl.any_of(%0, %1)
+        # CHECK:     %3 = irdl.is i32
+        # CHECK:     %4 = irdl.is i64
+        # CHECK:     %5 = irdl.any_of(%3, %4)
+        # CHECK:     irdl.attributes {"a" = %2, "b" = %5}
+        # CHECK:   }
+        # CHECK: }
+        print(TestAttrInOp._mlir_module)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        iattr = IntegerAttr.get(i32, 42)
+        sattr = StringAttr.get("hello")
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            OpWithAttr(iattr, TypeAttr.get(i32))
+            OpWithAttr(sattr, TypeAttr.get(i64))
+
+        assert module.operation.verify()
+        # CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
+        # CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
+        print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/187914


More information about the Mlir-commits mailing list