[Mlir-commits] [mlir] b57d6fe - [mlir][Python] Add casting constructor to Type and Attribute.
Stella Laurenzo
llvmlistbot at llvm.org
Mon May 3 10:12:44 PDT 2021
Author: Stella Laurenzo
Date: 2021-05-03T10:12:03-07:00
New Revision: b57d6fe42ed3bc6867fab25be0edcb124ea0629f
URL: https://github.com/llvm/llvm-project/commit/b57d6fe42ed3bc6867fab25be0edcb124ea0629f
DIFF: https://github.com/llvm/llvm-project/commit/b57d6fe42ed3bc6867fab25be0edcb124ea0629f.diff
LOG: [mlir][Python] Add casting constructor to Type and Attribute.
* This makes them consistent with custom types/attributes, whose constructors perform a type-checked conversion. Because the base classes can represent every type/attribute, their constructors never raise an error.
* More importantly, this makes it possible for out-of-tree code to subclass Type and Attribute in a sensible way.
Differential Revision: https://reviews.llvm.org/D101734
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/Bindings/Python/ir_attributes.py
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 160e35b21353a..d11edb1c688e8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2255,6 +2255,10 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of PyAttribute.
//----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
+ // Delegate to the PyAttribute copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirAttribute.
+ .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
+ "Casts the passed attribute to the generic Attribute")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyAttribute::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
@@ -2358,6 +2362,10 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of PyType.
//----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
+ // Delegate to the PyType copy constructor, which will also lifetime
+ // extend the backing context which owns the MlirType.
+ .def(py::init<PyType &>(), py::arg("cast_from_type"),
+ "Casts the passed type to the generic Type")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index c09a86ad2e5eb..bdda2e3843acf 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testParsePrint
+ at run
def testParsePrint():
with Context() as ctx:
t = Attribute.parse('"hello"')
@@ -22,12 +24,11 @@ def testParsePrint():
# CHECK: Attribute("hello")
print(repr(t))
-run(testParsePrint)
-
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
+ at run
def testParseError():
with Context():
try:
@@ -38,10 +39,9 @@ def testParseError():
else:
print("Exception not produced")
-run(testParseError)
-
# CHECK-LABEL: TEST: testAttrEq
+ at run
def testAttrEq():
with Context():
a1 = Attribute.parse('"attr1"')
@@ -56,10 +56,19 @@ def testAttrEq():
# CHECK: a1 == None: False
print("a1 == None:", a1 == None)
-run(testAttrEq)
+
+# CHECK-LABEL: TEST: testAttrCast
+# The generic Attribute constructor delegates to the PyAttribute copy
+# constructor, so Attribute(a1) is expected to compare equal to a1.
+ at run
+def testAttrCast():
+ with Context():
+ a1 = Attribute.parse('"attr1"')
+ a2 = Attribute(a1)
+ # CHECK: a1 == a2: True
+ print("a1 == a2:", a1 == a2)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
+ at run
def testAttrEqDoesNotRaise():
with Context():
a1 = Attribute.parse('"attr1"')
@@ -71,10 +80,9 @@ def testAttrEqDoesNotRaise():
# CHECK: True
print(a1 != None)
-run(testAttrEqDoesNotRaise)
-
# CHECK-LABEL: TEST: testAttrCapsule
+ at run
def testAttrCapsule():
with Context() as ctx:
a1 = Attribute.parse('"attr1"')
@@ -85,10 +93,9 @@ def testAttrCapsule():
assert a2 == a1
assert a2.context is ctx
-run(testAttrCapsule)
-
# CHECK-LABEL: TEST: testStandardAttrCasts
+ at run
def testStandardAttrCasts():
with Context():
a1 = Attribute.parse('"attr1"')
@@ -104,10 +111,9 @@ def testStandardAttrCasts():
else:
print("Exception not produced")
-run(testStandardAttrCasts)
-
# CHECK-LABEL: TEST: testAffineMapAttr
+ at run
def testAffineMapAttr():
with Context() as ctx:
d0 = AffineDimExpr.get(0)
@@ -122,10 +128,9 @@ def testAffineMapAttr():
attr_parsed = Attribute.parse(str(attr_built))
assert attr_built == attr_parsed
-run(testAffineMapAttr)
-
# CHECK-LABEL: TEST: testFloatAttr
+ at run
def testFloatAttr():
with Context(), Location.unknown():
fattr = FloatAttr(Attribute.parse("42.0 : f32"))
@@ -149,10 +154,9 @@ def testFloatAttr():
else:
print("Exception not produced")
-run(testFloatAttr)
-
# CHECK-LABEL: TEST: testIntegerAttr
+ at run
def testIntegerAttr():
with Context() as ctx:
iattr = IntegerAttr(Attribute.parse("42"))
@@ -166,10 +170,9 @@ def testIntegerAttr():
print("default_get:", IntegerAttr.get(
IntegerType.get_signless(32), 42))
-run(testIntegerAttr)
-
# CHECK-LABEL: TEST: testBoolAttr
+ at run
def testBoolAttr():
with Context() as ctx:
battr = BoolAttr(Attribute.parse("true"))
@@ -180,10 +183,9 @@ def testBoolAttr():
# CHECK: default_get: true
print("default_get:", BoolAttr.get(True))
-run(testBoolAttr)
-
# CHECK-LABEL: TEST: testFlatSymbolRefAttr
+ at run
def testFlatSymbolRefAttr():
with Context() as ctx:
sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
@@ -194,10 +196,9 @@ def testFlatSymbolRefAttr():
# CHECK: default_get: @foobar
print("default_get:", FlatSymbolRefAttr.get("foobar"))
-run(testFlatSymbolRefAttr)
-
# CHECK-LABEL: TEST: testStringAttr
+ at run
def testStringAttr():
with Context() as ctx:
sattr = StringAttr(Attribute.parse('"stringattr"'))
@@ -211,10 +212,9 @@ def testStringAttr():
print("typed_get:", StringAttr.get_typed(
IntegerType.get_signless(32), "12345"))
-run(testStringAttr)
-
# CHECK-LABEL: TEST: testNamedAttr
+ at run
def testNamedAttr():
with Context():
a = Attribute.parse('"stringattr"')
@@ -226,10 +226,9 @@ def testNamedAttr():
# CHECK: named: NamedAttribute(foobar="stringattr")
print("named:", named)
-run(testNamedAttr)
-
# CHECK-LABEL: TEST: testDenseIntAttr
+ at run
def testDenseIntAttr():
with Context():
raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
@@ -263,10 +262,8 @@ def testDenseIntAttr():
print(ShapedType(a.type).element_type)
-run(testDenseIntAttr)
-
-
# CHECK-LABEL: TEST: testDenseFPAttr
+ at run
def testDenseFPAttr():
with Context():
raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
@@ -286,10 +283,8 @@ def testDenseFPAttr():
print(ShapedType(a.type).element_type)
-run(testDenseFPAttr)
-
-
# CHECK-LABEL: TEST: testDictAttr
+ at run
def testDictAttr():
with Context():
dict_attr = {
@@ -327,10 +322,8 @@ def testDictAttr():
assert False, "expected IndexError on accessing an out-of-bounds attribute"
-
-run(testDictAttr)
-
# CHECK-LABEL: TEST: testTypeAttr
+ at run
def testTypeAttr():
with Context():
raw = Attribute.parse("vector<4xf32>")
@@ -341,10 +334,8 @@ def testTypeAttr():
print(ShapedType(type_attr.value).element_type)
-run(testTypeAttr)
-
-
# CHECK-LABEL: TEST: testArrayAttr
+ at run
def testArrayAttr():
with Context():
raw = Attribute.parse("[42, true, vector<4xf32>]")
@@ -391,5 +382,4 @@ def testArrayAttr():
except RuntimeError as e:
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
print("Error: ", e)
-run(testArrayAttr)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index ea05c1561f74b..a2cc2da894973 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testParsePrint
+ at run
def testParsePrint():
ctx = Context()
t = Type.parse("i32", ctx)
@@ -22,12 +24,11 @@ def testParsePrint():
# CHECK: Type(i32)
print(repr(t))
-run(testParsePrint)
-
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
+ at run
def testParseError():
ctx = Context()
try:
@@ -38,10 +39,9 @@ def testParseError():
else:
print("Exception not produced")
-run(testParseError)
-
# CHECK-LABEL: TEST: testTypeEq
+ at run
def testTypeEq():
ctx = Context()
t1 = Type.parse("i32", ctx)
@@ -56,10 +56,19 @@ def testTypeEq():
# CHECK: t1 == None: False
print("t1 == None:", t1 == None)
-run(testTypeEq)
+
+# CHECK-LABEL: TEST: testTypeCast
+# The generic Type constructor delegates to the PyType copy constructor,
+# so Type(t1) is expected to compare equal to t1.
+ at run
+def testTypeCast():
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type(t1)
+ # CHECK: t1 == t2: True
+ print("t1 == t2:", t1 == t2)
# CHECK-LABEL: TEST: testTypeIsInstance
+ at run
def testTypeIsInstance():
ctx = Context()
t1 = Type.parse("i32", ctx)
@@ -71,10 +80,9 @@ def testTypeIsInstance():
# CHECK: True
print(F32Type.isinstance(t2))
-run(testTypeIsInstance)
-
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+ at run
def testTypeEqDoesNotRaise():
ctx = Context()
t1 = Type.parse("i32", ctx)
@@ -86,10 +94,9 @@ def testTypeEqDoesNotRaise():
# CHECK: True
print(t1 != None)
-run(testTypeEqDoesNotRaise)
-
# CHECK-LABEL: TEST: testTypeCapsule
+ at run
def testTypeCapsule():
with Context() as ctx:
t1 = Type.parse("i32", ctx)
@@ -100,10 +107,9 @@ def testTypeCapsule():
assert t2 == t1
assert t2.context is ctx
-run(testTypeCapsule)
-
# CHECK-LABEL: TEST: testStandardTypeCasts
+ at run
def testStandardTypeCasts():
ctx = Context()
t1 = Type.parse("i32", ctx)
@@ -119,10 +125,9 @@ def testStandardTypeCasts():
else:
print("Exception not produced")
-run(testStandardTypeCasts)
-
# CHECK-LABEL: TEST: testIntegerType
+ at run
def testIntegerType():
with Context() as ctx:
i32 = IntegerType(Type.parse("i32"))
@@ -158,17 +163,16 @@ def testIntegerType():
# CHECK: unsigned: ui64
print("unsigned:", IntegerType.get_unsigned(64))
-run(testIntegerType)
-
# CHECK-LABEL: TEST: testIndexType
+ at run
def testIndexType():
with Context() as ctx:
# CHECK: index type: index
print("index type:", IndexType.get())
-run(testIndexType)
# CHECK-LABEL: TEST: testFloatType
+ at run
def testFloatType():
with Context():
# CHECK: float: bf16
@@ -180,17 +184,17 @@ def testFloatType():
# CHECK: float: f64
print("float:", F64Type.get())
-run(testFloatType)
# CHECK-LABEL: TEST: testNoneType
+ at run
def testNoneType():
with Context():
# CHECK: none type: none
print("none type:", NoneType.get())
-run(testNoneType)
# CHECK-LABEL: TEST: testComplexType
+ at run
def testComplexType():
with Context() as ctx:
complex_i32 = ComplexType(Type.parse("complex<i32>"))
@@ -210,13 +214,12 @@ def testComplexType():
else:
print("Exception not produced")
-run(testComplexType)
-
# CHECK-LABEL: TEST: testConcreteShapedType
# Shaped type is not a kind of builtin types, it is the base class for vectors,
# memrefs and tensors, so this test case uses an instance of vector to test the
# shaped type. The class hierarchy is preserved on the python side.
+ at run
def testConcreteShapedType():
with Context() as ctx:
vector = VectorType(Type.parse("vector<2x3xf32>"))
@@ -239,20 +242,20 @@ def testConcreteShapedType():
# CHECK: isinstance(ShapedType): True
print("isinstance(ShapedType):", isinstance(vector, ShapedType))
-run(testConcreteShapedType)
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
+ at run
def testAbstractShapedType():
ctx = Context()
vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
# CHECK: element type: f32
print("element type:", vector.element_type)
-run(testAbstractShapedType)
# CHECK-LABEL: TEST: testVectorType
+ at run
def testVectorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@@ -269,9 +272,9 @@ def testVectorType():
else:
print("Exception not produced")
-run(testVectorType)
# CHECK-LABEL: TEST: testRankedTensorType
+ at run
def testRankedTensorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@@ -291,9 +294,9 @@ def testRankedTensorType():
else:
print("Exception not produced")
-run(testRankedTensorType)
# CHECK-LABEL: TEST: testUnrankedTensorType
+ at run
def testUnrankedTensorType():
with Context(), Location.unknown():
f32 = F32Type.get()
@@ -333,9 +336,9 @@ def testUnrankedTensorType():
else:
print("Exception not produced")
-run(testUnrankedTensorType)
# CHECK-LABEL: TEST: testMemRefType
+ at run
def testMemRefType():
with Context(), Location.unknown():
f32 = F32Type.get()
@@ -369,9 +372,9 @@ def testMemRefType():
else:
print("Exception not produced")
-run(testMemRefType)
# CHECK-LABEL: TEST: testUnrankedMemRefType
+ at run
def testUnrankedMemRefType():
with Context(), Location.unknown():
f32 = F32Type.get()
@@ -411,9 +414,9 @@ def testUnrankedMemRefType():
else:
print("Exception not produced")
-run(testUnrankedMemRefType)
# CHECK-LABEL: TEST: testTupleType
+ at run
def testTupleType():
with Context() as ctx:
i32 = IntegerType(Type.parse("i32"))
@@ -428,10 +431,9 @@ def testTupleType():
# CHECK: pos-th type in the tuple type: f32
print("pos-th type in the tuple type:", tuple_type.get_type(1))
-run(testTupleType)
-
# CHECK-LABEL: TEST: testFunctionType
+ at run
def testFunctionType():
with Context() as ctx:
input_types = [IntegerType.get_signless(32),
@@ -442,6 +444,3 @@ def testFunctionType():
print("INPUTS:", func.inputs)
# CHECK: RESULTS: [Type(index)]
print("RESULTS:", func.results)
-
-
-run(testFunctionType)
More information about the Mlir-commits mailing list