[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)

Maksim Levental llvmlistbot at llvm.org
Tue Oct 31 16:38:01 PDT 2023


================
@@ -270,3 +271,119 @@ def testValueSetType():
 
             # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
             print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+    class NOPResult(OpResult):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+    class NOPValue(Value):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+    class NOPBlockArg(BlockArgument):
+        def __init__(self, v):
+            super().__init__(v)
+
+        def __str__(self):
+            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+    @register_value_caster(IntegerType.static_typeid)
+    def cast_int(v) -> SubClassValueT:
+        print("in caster", v.__class__.__name__)
+        if isinstance(v, OpResult):
+            return NOPResult(v)
+        if isinstance(v, BlockArgument):
+            return NOPBlockArg(v)
+        elif isinstance(v, Value):
+            return NOPValue(v)
+
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            values = Operation.create("custom.op1", results=[i32, i32]).results
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[0].result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", values[1].result_number, values[1])
+
+            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("results slice", values[:1][0].result_number, values[:1][0])
+
+            value0, value1 = values
+            # CHECK: in caster OpResult
+            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value0.result_number, values[0])
+            # CHECK: in caster OpResult
+            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("result", value1.result_number, values[1])
+
+            op1 = Operation.create("custom.op2", operands=[value0, value1])
+            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+            print(op1)
+
+            # CHECK: in caster Value
+            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 0", op1.operands[0])
+            # CHECK: in caster Value
+            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            print("operand 1", op1.operands[1])
+
+            # CHECK: in caster BlockArgument
+            # CHECK: in caster BlockArgument
+            @func.FuncOp.from_py_func(i32, i32)
+            def reduction(arg0, arg1):
+                # CHECK: as func arg 0 NOPBlockArg
+                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+                # CHECK: as func arg 1 NOPBlockArg
+                print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
+            print(
+                "args slice",
+                reduction.func_op.arguments[:1][0].arg_number,
+                reduction.func_op.arguments[:1][0],
+            )
+
+    try:
+
+        @register_value_caster(IntegerType.static_typeid)
+        def dont_cast_int_shouldnt_register(v):
+            ...
+
+    except RuntimeError as e:
+        # CHECK: Value caster is already registered: <function testValueCasters.<locals>.cast_int at
----------------
makslevental wrote:

Match on just `cast_int` in this CHECK line rather than on the full function repr.

https://github.com/llvm/llvm-project/pull/69644


More information about the Mlir-commits mailing list