[Mlir-commits] [mlir] 6dabcef - [MLIR][IRDL][Python] Fix error while composing `irdl.any_of` and `irdl.base` (#187914)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 22 17:47:59 PDT 2026
Author: Twice
Date: 2026-03-23T08:47:53+08:00
New Revision: 6dabcef0b3ffa0beadd426e0fd56c61c45b5b396
URL: https://github.com/llvm/llvm-project/commit/6dabcef0b3ffa0beadd426e0fd56c61c45b5b396
DIFF: https://github.com/llvm/llvm-project/commit/6dabcef0b3ffa0beadd426e0fd56c61c45b5b396.diff
LOG: [MLIR][IRDL][Python] Fix error while composing `irdl.any_of` and `irdl.base` (#187914)
Previously, when users composed `irdl.any_of` with `irdl.base`, e.g.
```mlir
module {
  irdl.dialect @ext_attr_in_op {
    irdl.operation @op_with_attr {
      %0 = irdl.base "#builtin.integer"
      %1 = irdl.base "#builtin.string"
      %2 = irdl.any_of(%0, %1)
      irdl.attributes {"a" = %2}
    }
  }
}
```
the program would crash at `llvm_unreachable("unknown IRDL constraint")`.
This PR implements `getBases(..)` for `irdl::BaseOp` to make this work.
This makes fields like `attr: IntegerAttr | StringAttr` work in
Python-defined dialects; previously they led to a crash.
Added:
Modified:
mlir/lib/Dialect/IRDL/IRDLLoading.cpp
mlir/python/mlir/dialects/ext.py
mlir/test/python/dialects/ext.py
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 8d10aacb53ec9..54c7d17a97b50 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -533,6 +533,41 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
return false;
}
+ if (auto base = dyn_cast<BaseOp>(op)) {
+ if (base.getBaseName()) {
+ StringRef baseName = *base.getBaseName();
+ if (baseName[0] == '!') {
+ auto abstractType =
+ AbstractType::lookup(baseName.drop_front(1), op->getContext());
+ assert(abstractType && "type name should refer to an existing type");
+ paramIds.insert(abstractType->get().getTypeID());
+ } else if (baseName[0] == '#') {
+ auto abstractAttr =
+ AbstractAttribute::lookup(baseName.drop_front(1), op->getContext());
+ assert(abstractAttr && "attribute name should refer to an existing "
+ "attribute");
+ paramIds.insert(abstractAttr->get().getTypeID());
+ } else {
+ llvm_unreachable(
+ "invalid `irdl.base` operation: base name should start "
+ "with '!' for types or '#' for attributes");
+ }
+ return false;
+ }
+
+ if (base.getBaseRef()) {
+ SymbolRefAttr symRef = *base.getBaseRef();
+ Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
+ assert(defOp && "symbol reference should refer to an existing operation");
+ paramIrdlOps.insert(defOp);
+ return false;
+ }
+
+ llvm_unreachable(
+ "invalid `irdl.base` operation: expected either a base name "
+ "or a base symbol reference");
+ }
+
// For `irdl.any`, we return `false` since we can match any type or attribute
// base.
if (auto isA = dyn_cast<AnyOp>(op))
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 1900b8c162456..dfcd7f2d641d0 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -85,29 +85,22 @@ def _lower(self, type_) -> ir.Value:
return irdl.any()
elif isinstance(type_, TypeVar):
return self.lower(type_)
+ elif origin and issubclass(origin, Type | Attribute):
+ return irdl.parametric(
+ base_type=[origin._dialect_name, origin._name],
+ args=[self.lower(arg) for arg in get_args(type_)],
+ )
elif origin and issubclass(origin, ir.Type):
- if issubclass(origin, Type):
- return irdl.parametric(
- base_type=[origin._dialect_name, origin._name],
- args=[self.lower(arg) for arg in get_args(type_)],
- )
t = construct_instance(origin, get_args(type_))
return irdl.is_(ir.TypeAttr.get(t))
elif origin and issubclass(origin, ir.Attribute):
- if issubclass(origin, Attribute):
- return irdl.parametric(
- base_type=[origin._dialect_name, origin._name],
- args=[self.lower(arg) for arg in get_args(type_)],
- )
attr = construct_instance(origin, get_args(type_))
return irdl.is_(attr)
+ elif issubclass(type_, Type | Attribute):
+ return irdl.base(base_ref=[type_._dialect_name, type_._name])
elif issubclass(type_, ir.Type):
- if issubclass(type_, Type):
- return irdl.base(base_ref=[type_._dialect_name, type_._name])
return irdl.base(base_name=f"!{type_.type_name}")
elif issubclass(type_, ir.Attribute):
- if issubclass(type_, Attribute):
- return irdl.base(base_ref=[type_._dialect_name, type_._name])
return irdl.base(base_name=f"#{type_.attr_name}")
raise TypeError(f"unsupported type in constraints: {type_}")
@@ -197,13 +190,9 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
get_args(type_)[0],
kw_only=specifier.kw_only(),
)
- elif issubclass(origin or type_, ir.Attribute):
- return AttributeDef(name, variadicity, type_)
elif type_ is ir.Region:
return RegionDef(name, variadicity, Any)
- raise TypeError(
- f"unsupported type for field '{name}' in operation definition: {type_}"
- )
+ return AttributeDef(name, variadicity, type_)
@dataclass
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index a1593c35855ea..78c74684cef77 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -736,3 +736,44 @@ class AssignNoneOnNonOptionalOp(
except ValueError as e:
# CHECK: only optional operand can be a keyword parameter
print(e)
+
+
+# CHECK: TEST: testExtDialectWithAttrInOp
+@run
+def testExtDialectWithAttrInOp():
+ class TestAttrInOp(Dialect, name="ext_attr_in_op"):
+ pass
+
+ class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"):
+ a: IntegerAttr | StringAttr
+ b: IntegerType[32] | IntegerType[64]
+
+ with Context(), Location.unknown():
+ TestAttrInOp.load()
+ # CHECK: irdl.dialect @ext_attr_in_op {
+ # CHECK: irdl.operation @op_with_attr {
+ # CHECK: %0 = irdl.base "#builtin.integer"
+ # CHECK: %1 = irdl.base "#builtin.string"
+ # CHECK: %2 = irdl.any_of(%0, %1)
+ # CHECK: %3 = irdl.is i32
+ # CHECK: %4 = irdl.is i64
+ # CHECK: %5 = irdl.any_of(%3, %4)
+ # CHECK: irdl.attributes {"a" = %2, "b" = %5}
+ # CHECK: }
+ # CHECK: }
+ print(TestAttrInOp._mlir_module)
+
+ i32 = IntegerType.get_signless(32)
+ i64 = IntegerType.get_signless(64)
+ iattr = IntegerAttr.get(i32, 42)
+ sattr = StringAttr.get("hello")
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ OpWithAttr(iattr, TypeAttr.get(i32))
+ OpWithAttr(sattr, TypeAttr.get(i64))
+
+ assert module.operation.verify()
+ # CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
+ # CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
+ print(module)
More information about the Mlir-commits
mailing list