[Mlir-commits] [mlir] b57d6fe - [mlir][Python] Add casting constructor to Type and Attribute.

Stella Laurenzo llvmlistbot at llvm.org
Mon May 3 10:12:44 PDT 2021


Author: Stella Laurenzo
Date: 2021-05-03T10:12:03-07:00
New Revision: b57d6fe42ed3bc6867fab25be0edcb124ea0629f

URL: https://github.com/llvm/llvm-project/commit/b57d6fe42ed3bc6867fab25be0edcb124ea0629f
DIFF: https://github.com/llvm/llvm-project/commit/b57d6fe42ed3bc6867fab25be0edcb124ea0629f.diff

LOG: [mlir][Python] Add casting constructor to Type and Attribute.

* This makes them consistent with custom types/attributes, whose constructors perform a type-checked conversion. Since the base classes can represent every type/attribute, their casting constructors never raise an error.
* More importantly, this makes it possible to subclass Type and Attribute out of tree in sensible ways.

Differential Revision: https://reviews.llvm.org/D101734

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/Bindings/Python/ir_attributes.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 160e35b21353a..d11edb1c688e8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2255,6 +2255,10 @@ void mlir::python::populateIRCore(py::module &m) {
   // Mapping of PyAttribute.
   //----------------------------------------------------------------------------
   py::class_<PyAttribute>(m, "Attribute")
+      // Delegate to the PyAttribute copy constructor, which will also lifetime
+      // extend the backing context which owns the MlirAttribute.
+      .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
+           "Casts the passed attribute to the generic Attribute")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyAttribute::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
@@ -2358,6 +2362,10 @@ void mlir::python::populateIRCore(py::module &m) {
   // Mapping of PyType.
   //----------------------------------------------------------------------------
   py::class_<PyType>(m, "Type")
+      // Delegate to the PyType copy constructor, which will also lifetime
+      // extend the backing context which owns the MlirType.
+      .def(py::init<PyType &>(), py::arg("cast_from_type"),
+           "Casts the passed type to the generic Type")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
       .def_static(

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index c09a86ad2e5eb..bdda2e3843acf 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -8,9 +8,11 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testParsePrint
+@run
 def testParsePrint():
   with Context() as ctx:
     t = Attribute.parse('"hello"')
@@ -22,12 +24,11 @@ def testParsePrint():
   # CHECK: Attribute("hello")
   print(repr(t))
 
-run(testParsePrint)
-
 
 # CHECK-LABEL: TEST: testParseError
 # TODO: Hook the diagnostic manager to capture a more meaningful error
 # message.
+@run
 def testParseError():
   with Context():
     try:
@@ -38,10 +39,9 @@ def testParseError():
     else:
       print("Exception not produced")
 
-run(testParseError)
-
 
 # CHECK-LABEL: TEST: testAttrEq
+@run
 def testAttrEq():
   with Context():
     a1 = Attribute.parse('"attr1"')
@@ -56,10 +56,19 @@ def testAttrEq():
     # CHECK: a1 == None: False
     print("a1 == None:", a1 == None)
 
-run(testAttrEq)
+
+# CHECK-LABEL: TEST: testAttrCast
+@run
+def testAttrCast():
+  with Context():
+    a1 = Attribute.parse('"attr1"')
+    a2 = Attribute(a1)
+    # CHECK: a1 == a2: True
+    print("a1 == a2:", a1 == a2)
 
 
 # CHECK-LABEL: TEST: testAttrEqDoesNotRaise
+@run
 def testAttrEqDoesNotRaise():
   with Context():
     a1 = Attribute.parse('"attr1"')
@@ -71,10 +80,9 @@ def testAttrEqDoesNotRaise():
     # CHECK: True
     print(a1 != None)
 
-run(testAttrEqDoesNotRaise)
-
 
 # CHECK-LABEL: TEST: testAttrCapsule
+@run
 def testAttrCapsule():
   with Context() as ctx:
     a1 = Attribute.parse('"attr1"')
@@ -85,10 +93,9 @@ def testAttrCapsule():
   assert a2 == a1
   assert a2.context is ctx
 
-run(testAttrCapsule)
-
 
 # CHECK-LABEL: TEST: testStandardAttrCasts
+@run
 def testStandardAttrCasts():
   with Context():
     a1 = Attribute.parse('"attr1"')
@@ -104,10 +111,9 @@ def testStandardAttrCasts():
     else:
       print("Exception not produced")
 
-run(testStandardAttrCasts)
-
 
 # CHECK-LABEL: TEST: testAffineMapAttr
+@run
 def testAffineMapAttr():
   with Context() as ctx:
     d0 = AffineDimExpr.get(0)
@@ -122,10 +128,9 @@ def testAffineMapAttr():
     attr_parsed = Attribute.parse(str(attr_built))
     assert attr_built == attr_parsed
 
-run(testAffineMapAttr)
-
 
 # CHECK-LABEL: TEST: testFloatAttr
+@run
 def testFloatAttr():
   with Context(), Location.unknown():
     fattr = FloatAttr(Attribute.parse("42.0 : f32"))
@@ -149,10 +154,9 @@ def testFloatAttr():
     else:
       print("Exception not produced")
 
-run(testFloatAttr)
-
 
 # CHECK-LABEL: TEST: testIntegerAttr
+@run
 def testIntegerAttr():
   with Context() as ctx:
     iattr = IntegerAttr(Attribute.parse("42"))
@@ -166,10 +170,9 @@ def testIntegerAttr():
     print("default_get:", IntegerAttr.get(
         IntegerType.get_signless(32), 42))
 
-run(testIntegerAttr)
-
 
 # CHECK-LABEL: TEST: testBoolAttr
+@run
 def testBoolAttr():
   with Context() as ctx:
     battr = BoolAttr(Attribute.parse("true"))
@@ -180,10 +183,9 @@ def testBoolAttr():
     # CHECK: default_get: true
     print("default_get:", BoolAttr.get(True))
 
-run(testBoolAttr)
-
 
 # CHECK-LABEL: TEST: testFlatSymbolRefAttr
+@run
 def testFlatSymbolRefAttr():
   with Context() as ctx:
     sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
@@ -194,10 +196,9 @@ def testFlatSymbolRefAttr():
     # CHECK: default_get: @foobar
     print("default_get:", FlatSymbolRefAttr.get("foobar"))
 
-run(testFlatSymbolRefAttr)
-
 
 # CHECK-LABEL: TEST: testStringAttr
+@run
 def testStringAttr():
   with Context() as ctx:
     sattr = StringAttr(Attribute.parse('"stringattr"'))
@@ -211,10 +212,9 @@ def testStringAttr():
     print("typed_get:", StringAttr.get_typed(
         IntegerType.get_signless(32), "12345"))
 
-run(testStringAttr)
-
 
 # CHECK-LABEL: TEST: testNamedAttr
+@run
 def testNamedAttr():
   with Context():
     a = Attribute.parse('"stringattr"')
@@ -226,10 +226,9 @@ def testNamedAttr():
     # CHECK: named: NamedAttribute(foobar="stringattr")
     print("named:", named)
 
-run(testNamedAttr)
-
 
 # CHECK-LABEL: TEST: testDenseIntAttr
+@run
 def testDenseIntAttr():
   with Context():
     raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
@@ -263,10 +262,8 @@ def testDenseIntAttr():
     print(ShapedType(a.type).element_type)
 
 
-run(testDenseIntAttr)
-
-
 # CHECK-LABEL: TEST: testDenseFPAttr
+@run
 def testDenseFPAttr():
   with Context():
     raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
@@ -286,10 +283,8 @@ def testDenseFPAttr():
     print(ShapedType(a.type).element_type)
 
 
-run(testDenseFPAttr)
-
-
 # CHECK-LABEL: TEST: testDictAttr
+@run
 def testDictAttr():
   with Context():
     dict_attr = {
@@ -327,10 +322,8 @@ def testDictAttr():
       assert False, "expected IndexError on accessing an out-of-bounds attribute"
 
 
-
-run(testDictAttr)
-
 # CHECK-LABEL: TEST: testTypeAttr
+@run
 def testTypeAttr():
   with Context():
     raw = Attribute.parse("vector<4xf32>")
@@ -341,10 +334,8 @@ def testTypeAttr():
     print(ShapedType(type_attr.value).element_type)
 
 
-run(testTypeAttr)
-
-
 # CHECK-LABEL: TEST: testArrayAttr
+@run
 def testArrayAttr():
   with Context():
     raw = Attribute.parse("[42, true, vector<4xf32>]")
@@ -391,5 +382,4 @@ def testArrayAttr():
     except RuntimeError as e:
       # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
       print("Error: ", e)
-run(testArrayAttr)
 

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index ea05c1561f74b..a2cc2da894973 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -8,9 +8,11 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # CHECK-LABEL: TEST: testParsePrint
+@run
 def testParsePrint():
   ctx = Context()
   t = Type.parse("i32", ctx)
@@ -22,12 +24,11 @@ def testParsePrint():
   # CHECK: Type(i32)
   print(repr(t))
 
-run(testParsePrint)
-
 
 # CHECK-LABEL: TEST: testParseError
 # TODO: Hook the diagnostic manager to capture a more meaningful error
 # message.
+@run
 def testParseError():
   ctx = Context()
   try:
@@ -38,10 +39,9 @@ def testParseError():
   else:
     print("Exception not produced")
 
-run(testParseError)
-
 
 # CHECK-LABEL: TEST: testTypeEq
+@run
 def testTypeEq():
   ctx = Context()
   t1 = Type.parse("i32", ctx)
@@ -56,10 +56,19 @@ def testTypeEq():
   # CHECK: t1 == None: False
   print("t1 == None:", t1 == None)
 
-run(testTypeEq)
+
+# CHECK-LABEL: TEST: testTypeCast
+@run
+def testTypeCast():
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
+  t2 = Type(t1)
+  # CHECK: t1 == t2: True
+  print("t1 == t2:", t1 == t2)
 
 
 # CHECK-LABEL: TEST: testTypeIsInstance
+@run
 def testTypeIsInstance():
   ctx = Context()
   t1 = Type.parse("i32", ctx)
@@ -71,10 +80,9 @@ def testTypeIsInstance():
   # CHECK: True
   print(F32Type.isinstance(t2))
 
-run(testTypeIsInstance)
-
 
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+@run
 def testTypeEqDoesNotRaise():
   ctx = Context()
   t1 = Type.parse("i32", ctx)
@@ -86,10 +94,9 @@ def testTypeEqDoesNotRaise():
   # CHECK: True
   print(t1 != None)
 
-run(testTypeEqDoesNotRaise)
-
 
 # CHECK-LABEL: TEST: testTypeCapsule
+@run
 def testTypeCapsule():
   with Context() as ctx:
     t1 = Type.parse("i32", ctx)
@@ -100,10 +107,9 @@ def testTypeCapsule():
   assert t2 == t1
   assert t2.context is ctx
 
-run(testTypeCapsule)
-
 
 # CHECK-LABEL: TEST: testStandardTypeCasts
+@run
 def testStandardTypeCasts():
   ctx = Context()
   t1 = Type.parse("i32", ctx)
@@ -119,10 +125,9 @@ def testStandardTypeCasts():
   else:
     print("Exception not produced")
 
-run(testStandardTypeCasts)
-
 
 # CHECK-LABEL: TEST: testIntegerType
+@run
 def testIntegerType():
   with Context() as ctx:
     i32 = IntegerType(Type.parse("i32"))
@@ -158,17 +163,16 @@ def testIntegerType():
     # CHECK: unsigned: ui64
     print("unsigned:", IntegerType.get_unsigned(64))
 
-run(testIntegerType)
-
 # CHECK-LABEL: TEST: testIndexType
+@run
 def testIndexType():
   with Context() as ctx:
     # CHECK: index type: index
     print("index type:", IndexType.get())
 
-run(testIndexType)
 
 # CHECK-LABEL: TEST: testFloatType
+@run
 def testFloatType():
   with Context():
     # CHECK: float: bf16
@@ -180,17 +184,17 @@ def testFloatType():
     # CHECK: float: f64
     print("float:", F64Type.get())
 
-run(testFloatType)
 
 # CHECK-LABEL: TEST: testNoneType
+@run
 def testNoneType():
   with Context():
     # CHECK: none type: none
     print("none type:", NoneType.get())
 
-run(testNoneType)
 
 # CHECK-LABEL: TEST: testComplexType
+@run
 def testComplexType():
   with Context() as ctx:
     complex_i32 = ComplexType(Type.parse("complex<i32>"))
@@ -210,13 +214,12 @@ def testComplexType():
     else:
       print("Exception not produced")
 
-run(testComplexType)
-
 
 # CHECK-LABEL: TEST: testConcreteShapedType
 # Shaped type is not a kind of builtin types, it is the base class for vectors,
 # memrefs and tensors, so this test case uses an instance of vector to test the
 # shaped type. The class hierarchy is preserved on the python side.
+@run
 def testConcreteShapedType():
   with Context() as ctx:
     vector = VectorType(Type.parse("vector<2x3xf32>"))
@@ -239,20 +242,20 @@ def testConcreteShapedType():
     # CHECK: isinstance(ShapedType): True
     print("isinstance(ShapedType):", isinstance(vector, ShapedType))
 
-run(testConcreteShapedType)
 
 # CHECK-LABEL: TEST: testAbstractShapedType
 # Tests that ShapedType operates as an abstract base class of a concrete
 # shaped type (using vector as an example).
+@run
 def testAbstractShapedType():
   ctx = Context()
   vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
   # CHECK: element type: f32
   print("element type:", vector.element_type)
 
-run(testAbstractShapedType)
 
 # CHECK-LABEL: TEST: testVectorType
+@run
 def testVectorType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
@@ -269,9 +272,9 @@ def testVectorType():
     else:
       print("Exception not produced")
 
-run(testVectorType)
 
 # CHECK-LABEL: TEST: testRankedTensorType
+@run
 def testRankedTensorType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
@@ -291,9 +294,9 @@ def testRankedTensorType():
     else:
       print("Exception not produced")
 
-run(testRankedTensorType)
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
+@run
 def testUnrankedTensorType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
@@ -333,9 +336,9 @@ def testUnrankedTensorType():
     else:
       print("Exception not produced")
 
-run(testUnrankedTensorType)
 
 # CHECK-LABEL: TEST: testMemRefType
+@run
 def testMemRefType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
@@ -369,9 +372,9 @@ def testMemRefType():
     else:
       print("Exception not produced")
 
-run(testMemRefType)
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
+@run
 def testUnrankedMemRefType():
   with Context(), Location.unknown():
     f32 = F32Type.get()
@@ -411,9 +414,9 @@ def testUnrankedMemRefType():
     else:
       print("Exception not produced")
 
-run(testUnrankedMemRefType)
 
 # CHECK-LABEL: TEST: testTupleType
+@run
 def testTupleType():
   with Context() as ctx:
     i32 = IntegerType(Type.parse("i32"))
@@ -428,10 +431,9 @@ def testTupleType():
     # CHECK: pos-th type in the tuple type: f32
     print("pos-th type in the tuple type:", tuple_type.get_type(1))
 
-run(testTupleType)
-
 
 # CHECK-LABEL: TEST: testFunctionType
+@run
 def testFunctionType():
   with Context() as ctx:
     input_types = [IntegerType.get_signless(32),
@@ -442,6 +444,3 @@ def testFunctionType():
     print("INPUTS:", func.inputs)
     # CHECK: RESULTS: [Type(index)]
     print("RESULTS:", func.results)
-
-
-run(testFunctionType)


        


More information about the Mlir-commits mailing list