[Mlir-commits] [mlir] 78f2dae - [mlir][python] Provide some methods and properties for API completeness
Alex Zinenko
llvmlistbot at llvm.org
Wed Oct 13 05:31:01 PDT 2021
Author: Alex Zinenko
Date: 2021-10-13T14:30:55+02:00
New Revision: 78f2dae00d32504d1f645f74c67bf4340ebcda82
URL: https://github.com/llvm/llvm-project/commit/78f2dae00d32504d1f645f74c67bf4340ebcda82
DIFF: https://github.com/llvm/llvm-project/commit/78f2dae00d32504d1f645f74c67bf4340ebcda82.diff
LOG: [mlir][python] Provide some methods and properties for API completeness
When writing the user-facing documentation, I noticed several inconsistencies
and asymmetries in the Python API we provide. Fix them by adding:
- the `owner` property to regions, similar to the one already available on blocks;
- the `isinstance` method to any class derived from `PyConcreteAttribute`,
`PyConcreteValue` and `PyConcreteAffineExpr`, similar to `PyConcreteType`, to
enable `isa`-like checks without having to handle exceptions;
- a mechanism to create the first block in a region: previously, blocks could
only be created relative to other blocks, which is impossible in an empty region.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D111556
Added:
Modified:
mlir/docs/Bindings/Python.md
mlir/lib/Bindings/Python/IRAffine.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/python/ir/affine_expr.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/blocks.py
mlir/test/python/ir/operation.py
mlir/test/python/ir/value.py
Removed:
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 9a5965b4a5580..32914be7c4a10 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -612,8 +612,22 @@ operations (unlike in C++ that supports detached regions).
Blocks can be created within a given region and inserted before or after another
block of the same region using `create_before()`, `create_after()` methods of
-the `Block` class. They are not expected to exist outside of regions (unlike in
-C++ that supports detached blocks).
+the `Block` class, or the `create_at_start()` static method of the same class.
+They are not expected to exist outside of regions (unlike in C++, which
+supports detached blocks).
+
+```python
+from mlir.ir import Block, Context, Operation
+
+with Context():
+ op = Operation.create("generic.op", regions=1)
+
+ # Create the first block in the region.
+ entry_block = Block.create_at_start(op.regions[0])
+
+ # Create further blocks.
+ other_block = entry_block.create_after()
+```
Blocks can be used to create `InsertionPoint`s, which can point to the beginning
or the end of the block, or just before its terminator. It is common for
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 5314badba64f0..0027b68ee073d 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -99,6 +99,9 @@ class PyConcreteAffineExpr : public BaseTy {
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
cls.def(py::init<PyAffineExpr &>());
+ cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
+ return DerivedTy::isaFunction(otherAffineExpr);
+ });
DerivedTy::bindDerived(cls);
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d53efd9c7bd16..7b1c998297658 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1548,6 +1548,9 @@ class PyConcreteValue : public PyValue {
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
+ cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ });
DerivedTy::bindDerived(cls);
}
@@ -2248,6 +2251,12 @@ void mlir::python::populateIRCore(py::module &m) {
return PyBlockList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of blocks.")
+ .def_property_readonly(
+ "owner",
+ [](PyRegion &self) {
+ return self.getParentOperation()->createOpView();
+ },
+ "Returns the operation owning this region.")
.def(
"__iter__",
[](PyRegion &self) {
@@ -2291,6 +2300,23 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOperationList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of operations.")
+ .def_static(
+ "create_at_start",
+ [](PyRegion &parent, py::list pyArgTypes) {
+ parent.checkValid();
+ llvm::SmallVector<MlirType, 4> argTypes;
+ argTypes.reserve(pyArgTypes.size());
+ for (auto &pyArg : pyArgTypes) {
+ argTypes.push_back(pyArg.cast<PyType &>());
+ }
+
+ MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
+ mlirRegionInsertOwnedBlock(parent, 0, block);
+ return PyBlock(parent.getParentOperation(), block);
+ },
+ py::arg("parent"), py::arg("pyArgTypes") = py::list(),
+ "Creates and returns a new Block at the beginning of the given "
+ "region (with given argument types).")
.def(
"create_before",
[](PyBlock &self, py::args pyArgTypes) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 702870487d751..ae85ef8507c30 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -533,6 +533,7 @@ class PyRegion {
: parentOperation(std::move(parentOperation)), region(region) {
assert(!mlirRegionIsNull(region) && "python region cannot be null");
}
+ operator MlirRegion() const { return region; }
MlirRegion get() { return region; }
PyOperationRef &getParentOperation() { return parentOperation; }
@@ -681,6 +682,9 @@ class PyConcreteAttribute : public BaseTy {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
pybind11::module_local());
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
+ cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
+ return DerivedTy::isaFunction(otherAttr);
+ });
DerivedTy::bindDerived(cls);
}
@@ -764,6 +768,7 @@ class PyValue {
public:
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(parentOperation), value(value) {}
+ operator MlirValue() const { return value; }
MlirValue get() { return value; }
PyOperationRef &getParentOperation() { return parentOperation; }
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index eb58579448cae..184466870a578 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -8,9 +8,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testAffineExprCapsule
+ at run
def testAffineExprCapsule():
with Context() as ctx:
affine_expr = AffineExpr.get_constant(42)
@@ -24,10 +26,9 @@ def testAffineExprCapsule():
assert affine_expr == affine_expr_2
assert affine_expr_2.context == ctx
-run(testAffineExprCapsule)
-
# CHECK-LABEL: TEST: testAffineExprEq
+ at run
def testAffineExprEq():
with Context():
a1 = AffineExpr.get_constant(42)
@@ -44,10 +45,9 @@ def testAffineExprEq():
# CHECK: False
print(a1 == "foo")
-run(testAffineExprEq)
-
# CHECK-LABEL: TEST: testAffineExprContext
+ at run
def testAffineExprContext():
with Context():
a1 = AffineExpr.get_constant(42)
@@ -61,6 +61,7 @@ def testAffineExprContext():
# CHECK-LABEL: TEST: testAffineExprConstant
+ at run
def testAffineExprConstant():
with Context():
a1 = AffineExpr.get_constant(42)
@@ -77,10 +78,9 @@ def testAffineExprConstant():
assert a1 == a2
-run(testAffineExprConstant)
-
# CHECK-LABEL: TEST: testAffineExprDim
+ at run
def testAffineExprDim():
with Context():
d1 = AffineExpr.get_dim(1)
@@ -100,10 +100,9 @@ def testAffineExprDim():
assert d1 == d11
assert d1 != d2
-run(testAffineExprDim)
-
# CHECK-LABEL: TEST: testAffineExprSymbol
+ at run
def testAffineExprSymbol():
with Context():
s1 = AffineExpr.get_symbol(1)
@@ -123,10 +122,9 @@ def testAffineExprSymbol():
assert s1 == s11
assert s1 != s2
-run(testAffineExprSymbol)
-
# CHECK-LABEL: TEST: testAffineAddExpr
+ at run
def testAffineAddExpr():
with Context():
d1 = AffineDimExpr.get(1)
@@ -143,10 +141,9 @@ def testAffineAddExpr():
assert d12.lhs == d1
assert d12.rhs == d2
-run(testAffineAddExpr)
-
# CHECK-LABEL: TEST: testAffineMulExpr
+ at run
def testAffineMulExpr():
with Context():
d1 = AffineDimExpr.get(1)
@@ -163,10 +160,9 @@ def testAffineMulExpr():
assert expr.lhs == d1
assert expr.rhs == c2
-run(testAffineMulExpr)
-
# CHECK-LABEL: TEST: testAffineModExpr
+ at run
def testAffineModExpr():
with Context():
d1 = AffineDimExpr.get(1)
@@ -183,10 +179,9 @@ def testAffineModExpr():
assert expr.lhs == d1
assert expr.rhs == c2
-run(testAffineModExpr)
-
# CHECK-LABEL: TEST: testAffineFloorDivExpr
+ at run
def testAffineFloorDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
@@ -198,10 +193,9 @@ def testAffineFloorDivExpr():
assert expr.lhs == d1
assert expr.rhs == c2
-run(testAffineFloorDivExpr)
-
# CHECK-LABEL: TEST: testAffineCeilDivExpr
+ at run
def testAffineCeilDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
@@ -213,10 +207,9 @@ def testAffineCeilDivExpr():
assert expr.lhs == d1
assert expr.rhs == c2
-run(testAffineCeilDivExpr)
-
# CHECK-LABEL: TEST: testAffineExprSub
+ at run
def testAffineExprSub():
with Context():
d1 = AffineDimExpr.get(1)
@@ -232,9 +225,8 @@ def testAffineExprSub():
# CHECK: -1
print(rhs.rhs)
-run(testAffineExprSub)
-
-
+# CHECK-LABEL: TEST: testClassHierarchy
+ at run
def testClassHierarchy():
with Context():
d1 = AffineDimExpr.get(1)
@@ -272,4 +264,28 @@ def testClassHierarchy():
# CHECK: Cannot cast affine expression to AffineBinaryExpr
print(e)
-run(testClassHierarchy)
+# CHECK-LABEL: TEST: testIsInstance
+ at run
+def testIsInstance():
+ with Context():
+ d1 = AffineDimExpr.get(1)
+ c2 = AffineConstantExpr.get(2)
+ add = AffineAddExpr.get(d1, c2)
+ mul = AffineMulExpr.get(d1, c2)
+
+ # CHECK: True
+ print(AffineDimExpr.isinstance(d1))
+ # CHECK: False
+ print(AffineConstantExpr.isinstance(d1))
+ # CHECK: True
+ print(AffineConstantExpr.isinstance(c2))
+ # CHECK: False
+ print(AffineMulExpr.isinstance(c2))
+ # CHECK: True
+ print(AffineAddExpr.isinstance(add))
+ # CHECK: False
+ print(AffineMulExpr.isinstance(add))
+ # CHECK: True
+ print(AffineMulExpr.isinstance(mul))
+ # CHECK: False
+ print(AffineAddExpr.isinstance(mul))
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index d2deb39a69df3..661b3ce1ccaba 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -89,6 +89,18 @@ def testAttrCast():
print("a1 == a2:", a1 == a2)
+# CHECK-LABEL: TEST: testAttrIsInstance
+ at run
+def testAttrIsInstance():
+ with Context():
+ a1 = Attribute.parse("42")
+ a2 = Attribute.parse("[42]")
+ assert IntegerAttr.isinstance(a1)
+ assert not IntegerAttr.isinstance(a2)
+ assert not ArrayAttr.isinstance(a1)
+ assert ArrayAttr.isinstance(a2)
+
+
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
@run
def testAttrEqDoesNotRaise():
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 81dccdd1f52d5..1bc38768949f4 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -51,3 +51,22 @@ def testBlockCreation():
print(module.operation)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region
+
+
+# CHECK-LABEL: TEST: testFirstBlockCreation
+# CHECK: func @test(%{{.*}}: f32)
+# CHECK: return
+ at run
+def testFirstBlockCreation():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ func = builtin.FuncOp("test", ([f32], []))
+ entry_block = Block.create_at_start(func.operation.regions[0], [f32])
+ with InsertionPoint(entry_block):
+ std.ReturnOp([])
+
+ print(module)
+ assert module.operation.verify()
+ assert func.body.blocks[0] == entry_block
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 7e4eac06382de..57a6d57a7f7bc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -11,10 +11,12 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
+ at run
def testTraverseOpRegionBlockIterators():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -69,11 +71,9 @@ def walk_operations(indent, op):
walk_operations("", op)
-run(testTraverseOpRegionBlockIterators)
-
-
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
+ at run
def testTraverseOpRegionBlockIndices():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -111,10 +111,30 @@ def walk_operations(indent, op):
walk_operations("", module.operation)
-run(testTraverseOpRegionBlockIndices)
+# CHECK-LABEL: TEST: testBlockAndRegionOwners
+ at run
+def testBlockAndRegionOwners():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ builtin.module {
+ builtin.func @f() {
+ std.return
+ }
+ }
+ """, ctx)
+
+ assert module.operation.regions[0].owner == module.operation
+ assert module.operation.regions[0].blocks[0].owner == module.operation
+
+ func = module.body.operations[0]
+ assert func.operation.regions[0].owner == func
+ assert func.operation.regions[0].blocks[0].owner == func
# CHECK-LABEL: TEST: testBlockArgumentList
+ at run
def testBlockArgumentList():
with Context() as ctx:
module = Module.parse(
@@ -158,10 +178,8 @@ def testBlockArgumentList():
print("Type: ", t)
-run(testBlockArgumentList)
-
-
# CHECK-LABEL: TEST: testOperationOperands
+ at run
def testOperationOperands():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@@ -181,10 +199,10 @@ def testOperationOperands():
print(f"Operand {i}, type {operand.type}")
-run(testOperationOperands)
# CHECK-LABEL: TEST: testOperationOperandsSlice
+ at run
def testOperationOperandsSlice():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@@ -239,10 +257,10 @@ def testOperationOperandsSlice():
print(operand)
-run(testOperationOperandsSlice)
# CHECK-LABEL: TEST: testOperationOperandsSet
+ at run
def testOperationOperandsSet():
with Context() as ctx, Location.unknown(ctx):
ctx.allow_unregistered_dialects = True
@@ -271,10 +289,10 @@ def testOperationOperandsSet():
print(consumer.operands[0])
-run(testOperationOperandsSet)
# CHECK-LABEL: TEST: testDetachedOperation
+ at run
def testDetachedOperation():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -295,10 +313,8 @@ def testDetachedOperation():
# TODO: Check successors once enough infra exists to do it properly.
-run(testDetachedOperation)
-
-
# CHECK-LABEL: TEST: testOperationInsertionPoint
+ at run
def testOperationInsertionPoint():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -335,10 +351,8 @@ def testOperationInsertionPoint():
assert False, "expected insert of attached op to raise"
-run(testOperationInsertionPoint)
-
-
# CHECK-LABEL: TEST: testOperationWithRegion
+ at run
def testOperationWithRegion():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -377,10 +391,8 @@ def testOperationWithRegion():
print(module)
-run(testOperationWithRegion)
-
-
# CHECK-LABEL: TEST: testOperationResultList
+ at run
def testOperationResultList():
ctx = Context()
module = Module.parse(
@@ -407,10 +419,10 @@ def testOperationResultList():
print(f"Result type {t}")
-run(testOperationResultList)
# CHECK-LABEL: TEST: testOperationResultListSlice
+ at run
def testOperationResultListSlice():
with Context() as ctx:
ctx.allow_unregistered_dialects = True
@@ -458,10 +470,10 @@ def testOperationResultListSlice():
print(f"Result {res.result_number}, type {res.type}")
-run(testOperationResultListSlice)
# CHECK-LABEL: TEST: testOperationAttributes
+ at run
def testOperationAttributes():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -506,10 +518,10 @@ def testOperationAttributes():
assert False, "expected IndexError on accessing an out-of-bounds attribute"
-run(testOperationAttributes)
# CHECK-LABEL: TEST: testOperationPrint
+ at run
def testOperationPrint():
ctx = Context()
module = Module.parse(
@@ -553,10 +565,10 @@ def testOperationPrint():
use_local_scope=True)
-run(testOperationPrint)
# CHECK-LABEL: TEST: testKnownOpView
+ at run
def testKnownOpView():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@@ -586,10 +598,8 @@ def testKnownOpView():
print(repr(custom))
-run(testKnownOpView)
-
-
# CHECK-LABEL: TEST: testSingleResultProperty
+ at run
def testSingleResultProperty():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@@ -620,10 +630,8 @@ def testSingleResultProperty():
print(module.body.operations[2])
-run(testSingleResultProperty)
-
-
# CHECK-LABEL: TEST: testPrintInvalidOperation
+ at run
def testPrintInvalidOperation():
ctx = Context()
with Location.unknown(ctx):
@@ -639,10 +647,8 @@ def testPrintInvalidOperation():
print(f".verify = {module.operation.verify()}")
-run(testPrintInvalidOperation)
-
-
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
+ at run
def testCreateWithInvalidAttributes():
ctx = Context()
with Location.unknown(ctx):
@@ -670,10 +676,8 @@ def testCreateWithInvalidAttributes():
print(e)
-run(testCreateWithInvalidAttributes)
-
-
# CHECK-LABEL: TEST: testOperationName
+ at run
def testOperationName():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -691,10 +695,8 @@ def testOperationName():
print(op.operation.name)
-run(testOperationName)
-
-
# CHECK-LABEL: TEST: testCapsuleConversions
+ at run
def testCapsuleConversions():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -706,10 +708,8 @@ def testCapsuleConversions():
assert m2 is m
-run(testCapsuleConversions)
-
-
# CHECK-LABEL: TEST: testOperationErase
+ at run
def testOperationErase():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -728,6 +728,3 @@ def testOperationErase():
# Ensure we can create another operation
Operation.create("custom.op2")
-
-
-run(testOperationErase)
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 1db9f33a65cda..230025a8f3306 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -9,9 +9,11 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
# CHECK-LABEL: TEST: testCapsuleConversions
+ at run
def testCapsuleConversions():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -24,10 +26,8 @@ def testCapsuleConversions():
assert value2 == value
-run(testCapsuleConversions)
-
-
# CHECK-LABEL: TEST: testOpResultOwner
+ at run
def testOpResultOwner():
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -37,4 +37,21 @@ def testOpResultOwner():
assert op.result.owner == op
-run(testOpResultOwner)
+# CHECK-LABEL: TEST: testValueIsInstance
+ at run
+def testValueIsInstance():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ module = Module.parse(
+ r"""
+ func @foo(%arg0: f32) {
+ %0 = "some_dialect.some_op"() : () -> f64
+ return
+ }""", ctx)
+ func = module.body.operations[0]
+ assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
+ assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
+
+ op = func.regions[0].blocks[0].operations[0]
+ assert not BlockArgument.isinstance(op.results[0])
+ assert OpResult.isinstance(op.results[0])
More information about the Mlir-commits
mailing list