[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining dialects in Python bindings (PR #169045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 05:28:04 PST 2026


================
@@ -0,0 +1,325 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import arith
+from mlir.dialects.ext import *
+from typing import Any, Optional, Sequence, TypeVar, Union
+import sys
+
+
+def run(f):
+    """Print a ``TEST: <name>`` header for ``f``, then call ``f``.
+
+    Applied as a decorator to each test function, so every test runs as
+    soon as it is defined. Because ``run`` returns nothing, the decorated
+    name is rebound to ``None``.
+    """
+    print("\nTEST:", f.__name__)
+    f()
+
+
+# CHECK: TEST: testMyInt
 at run
+def testMyInt():
+    """Define a small ``myint`` dialect with the Python DSL and exercise it.
+
+    Prints the generated IRDL module, builds and verifies a module that uses
+    the dialect's ``constant`` and ``add`` ops, and prints properties of the
+    generated op classes (names, segment info, accessors, constructor
+    signatures) so they can be compared with the expected-output comments.
+    """
+    class MyInt(Dialect, name="myint"):
+        pass
+
+    # Type constraint used in the operand/result annotations below.
+    i32 = IntegerType[32]
+
+    # One attribute (`value`) and one result (`cst`); no operands.
+    class ConstantOp(MyInt.Operation, name="constant"):
+        value: IntegerAttr
+        cst: Result[i32]
+
+    # Two i32 operands and one i32 result, all sharing one `irdl.is i32`.
+    class AddOp(MyInt.Operation, name="add"):
+        lhs: Operand[i32]
+        rhs: Operand[i32]
+        res: Result[i32]
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK:   irdl.operation @constant {
+    # CHECK:     %0 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"value" = %0}
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     irdl.results(cst: %1)
+    # CHECK:   }
+    # CHECK:   irdl.operation @add {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(lhs: %0, rhs: %0)
+    # CHECK:     irdl.results(res: %0)
+    # CHECK:   }
+    # CHECK: }
+    with Context(), Location.unknown():
+        # Load the dialect (presumably registering its IRDL definition in the
+        # active Context) and print the generated IRDL module.
+        MyInt.load()
+        print(MyInt.mlir_module)
+
+        # Operation classes are expected in definition order.
+        # CHECK: ['constant', 'add']
+        print([i._op_name for i in MyInt.operations])
+        # Rebind `i32` from the `IntegerType[32]` constraint to a concrete
+        # signless i32 type, used below to build IntegerAttr values.
+        i32 = IntegerType.get_signless(32)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            two = ConstantOp(IntegerAttr.get(i32, 2))
+            three = ConstantOp(IntegerAttr.get(i32, 3))
+            # Operands are passed as the previously built op instances.
+            add1 = AddOp(two, three)
+            add2 = AddOp(add1, two)
+            add3 = AddOp(add2, three)
+
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+        print(module)
+        # The module built from the DSL ops must pass verification.
+        assert module.operation.verify()
+
+        # CHECK: AddOp
+        print(type(add1).__name__)
+        # CHECK: ConstantOp
+        print(type(two).__name__)
+        # CHECK: myint.add
+        print(add1.OPERATION_NAME)
+        # Expected to be None here: AddOp has no optional or variadic
+        # operands or results.
+        # CHECK: None
+        print(add1._ODS_OPERAND_SEGMENTS)
+        # CHECK: None
+        print(add1._ODS_RESULT_SEGMENTS)
+        # Named operand accessors return values; `.owner` is the defining op.
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        print(add1.lhs.owner)
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        print(add1.rhs.owner)
+        # CHECK: 2 : i32
+        print(two.value)
+        # CHECK: OpResult(%0
+        print(two.cst)
+        # Generated constructors take operands/attributes positionally and
+        # `loc`/`ip` as keyword-only arguments.
+        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+        print(AddOp.__init__.__signature__)
+        # CHECK: (self, /, value, *, loc=None, ip=None)
+        print(ConstantOp.__init__.__signature__)
+
+
+# CHECK: TEST: testExtDialect
+ at run
+def testExtDialect():
+    class Test(Dialect, name="ext_test"):
+        pass
+
+    i32 = IntegerType[32]
+
+    class ConstraintOp(Test.Operation, name="constraint"):
+        a: Operand[i32 | IntegerType[64]]
+        b: Operand[Any]
+        # Here we use `F32Type[()]` instead of just `F32Type`
+        # because of an existing issue in IRDL implementation
+        # where `irdl.base` cannot exist in `irdl.any_of`.
+        c: Operand[F32Type[()] | i32]
+        d: Operand[Any]
+        x: IntegerAttr
+        y: FloatAttr
+
+    class OptionalOp(Test.Operation, name="optional"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        out1: Result[i32]
+        out2: Result[i32] | None
+        out3: Result[i32]
+
+    class Optional2Op(Test.Operation, name="optional2"):
+        a: Optional[Operand[i32]]
+        b: Optional[Result[i32]]
+
+    class VariadicOp(Test.Operation, name="variadic"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        c: Sequence[Operand[i32]]
+        out1: Sequence[Result[i32]]
+        out2: Sequence[Result[i32]]
+        out3: Optional[Result[i32]]
+        out4: Result[i32]
+
+    class Variadic2Op(Test.Operation, name="variadic2"):
+        a: Sequence[Operand[i32]]
+        b: Sequence[Result[i32]]
+
+    class MixedOpBase(Test.Operation):
+        out: Result[i32]
+        in1: Operand[i32]
+
+    class MixedOp(MixedOpBase, name="mixed"):
+        in2: IntegerAttr
+        in3: Optional[Operand[i32]]
+        in4: IntegerAttr
+        in5: Operand[i32]
+
+    T = TypeVar("T")
+    U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
+    V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])
+
+    class TypeVarOp(Test.Operation, name="type_var"):
+        in1: Operand[T]
+        in2: Operand[T]
+        in3: Operand[U]
+        in4: Operand[U | V]
+        in5: Operand[V]
+
+    # CHECK: irdl.dialect @ext_test {
+    # CHECK:   irdl.operation @constraint {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     %1 = irdl.is i64
+    # CHECK:     %2 = irdl.any_of(%0, %1)
+    # CHECK:     %3 = irdl.any
+    # CHECK:     %4 = irdl.is f32
+    # CHECK:     %5 = irdl.any_of(%4, %0)
+    # CHECK:     %6 = irdl.any
+    # CHECK:     irdl.operands(a: %2, b: %3, c: %5, d: %6)
+    # CHECK:     %7 = irdl.base "#builtin.integer"
+    # CHECK:     %8 = irdl.base "#builtin.float"
+    # CHECK:     irdl.attributes {"x" = %7, "y" = %8}
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0)
+    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: optional %0)
+    # CHECK:     irdl.results(b: optional %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: variadic %0)
+    # CHECK:     irdl.results(b: variadic %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @mixed {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     %1 = irdl.base "#builtin.integer"
+    # CHECK:     %2 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"in2" = %1, "in4" = %2}
+    # CHECK:     irdl.results(out: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @type_var {
+    # CHECK:     %0 = irdl.any
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     %2 = irdl.is i64
+    # CHECK:     %3 = irdl.any_of(%1, %2)
+    # CHECK:     %4 = irdl.is i8
+    # CHECK:     %5 = irdl.is i16
+    # CHECK:     %6 = irdl.any_of(%4, %5)
+    # CHECK:     %7 = irdl.any_of(%3, %6)
+    # CHECK:     irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
+    # CHECK:   }
+    # CHECK: }
+    with Context(), Location.unknown():
+        Test.load()
+        print(Test.mlir_module)
+
+        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+        print(ConstraintOp.__init__.__signature__)
+        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+        print(OptionalOp.__init__.__signature__)
+        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+        print(Optional2Op.__init__.__signature__)
+        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+        print(VariadicOp.__init__.__signature__)
+        # CHECK: (self, /, b, a, *, loc=None, ip=None)
+        print(Variadic2Op.__init__.__signature__)
+        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+        print(MixedOp.__init__.__signature__)
+
+        # CHECK: None None
+        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0] [1, 0, 1]
+        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [0] [0]
+        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [-1] [-1]
+        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        f32 = F32Type.get()
+
+        iattr = IntegerAttr.get(i32, 2)
+        fattr = FloatAttr.get_f32(2.3)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            ione = arith.constant(i32, 1)
+            fone = arith.constant(f32, 1.2)
+
+            # CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
+            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
----------------
PragmaTwice wrote:

It seems that an IRDL verification failure currently leads to an abort, so I have only added test cases for a wrong number of arguments.

https://github.com/llvm/llvm-project/pull/169045


More information about the Mlir-commits mailing list