[Mlir-commits] [mlir] [mlir][python] value casting (PR #69644)
Maksim Levental
llvmlistbot at llvm.org
Tue Oct 31 16:38:01 PDT 2023
================
@@ -270,3 +271,119 @@ def testValueSetType():
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
print(value.owner)
+
+
+# CHECK-LABEL: TEST: testValueCasters
+@run
+def testValueCasters():
+ class NOPResult(OpResult):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPResult.__name__)
+
+ class NOPValue(Value):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPValue.__name__)
+
+ class NOPBlockArg(BlockArgument):
+ def __init__(self, v):
+ super().__init__(v)
+
+ def __str__(self):
+ return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
+
+ @register_value_caster(IntegerType.static_typeid)
+ def cast_int(v) -> SubClassValueT:
+ print("in caster", v.__class__.__name__)
+ if isinstance(v, OpResult):
+ return NOPResult(v)
+ if isinstance(v, BlockArgument):
+ return NOPBlockArg(v)
+ elif isinstance(v, Value):
+ return NOPValue(v)
+
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ values = Operation.create("custom.op1", results=[i32, i32]).results
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[0].result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", values[1].result_number, values[1])
+
+ # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("results slice", values[:1][0].result_number, values[:1][0])
+
+ value0, value1 = values
+ # CHECK: in caster OpResult
+ # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value0.result_number, values[0])
+ # CHECK: in caster OpResult
+ # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("result", value1.result_number, values[1])
+
+ op1 = Operation.create("custom.op2", operands=[value0, value1])
+ # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
+ print(op1)
+
+ # CHECK: in caster Value
+ # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 0", op1.operands[0])
+ # CHECK: in caster Value
+ # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ print("operand 1", op1.operands[1])
+
+ # CHECK: in caster BlockArgument
+ # CHECK: in caster BlockArgument
+ @func.FuncOp.from_py_func(i32, i32)
+ def reduction(arg0, arg1):
+ # CHECK: as func arg 0 NOPBlockArg
+ print("as func arg", arg0.arg_number, arg0.__class__.__name__)
+ # CHECK: as func arg 1 NOPBlockArg
+ print("as func arg", arg1.arg_number, arg1.__class__.__name__)
+
+ # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
+ print(
+ "args slice",
+ reduction.func_op.arguments[:1][0].arg_number,
+ reduction.func_op.arguments[:1][0],
+ )
+
+ try:
+
+ @register_value_caster(IntegerType.static_typeid)
+ def dont_cast_int_shouldnt_register(v):
+ ...
+
+ except RuntimeError as e:
+ # CHECK: Value caster is already registered: <function testValueCasters.<locals>.cast_int at
----------------
makslevental wrote:
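Match on just `cast_int` here, rather than on the full function repr (`<function testValueCasters.<locals>.cast_int at ...`).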
https://github.com/llvm/llvm-project/pull/69644
More information about the Mlir-commits
mailing list