[Mlir-commits] [mlir] [mlir:python] Change PyOperation::create to actually return a PyOperation. (PR #114542)
llvmlistbot at llvm.org
Fri Nov 1 07:43:35 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Peter Hawkins (hawkinsp)
<details>
<summary>Changes</summary>
In the tablegen-generated Python bindings, we typically see a pattern like:
```
class ConstantOp(_ods_ir.OpView):
  ...
  def __init__(self, value, *, loc=None, ip=None):
    ...
    super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
```
That is, the generated code calls `OpView.__init__()` with the output of `build_generic`. The purpose of `OpView` is to wrap another operation object, and `OpView.__init__` accepts any `PyOperationBase` subclass. Presumably the intention is that `build_generic` returns a `PyOperation`, so the user ends up with a `PyOpView` wrapping a `PyOperation`.
However, `PyOpView::buildGeneric` calls `PyOperation::create`, which does not just build a `PyOperation`: it also calls `createOpView` to wrap that operation in a subclass of `PyOpView`, and returns that view. That is pointless here: this code is called from the constructor of an `OpView` subclass, so we already have a view object ready to go and don't need to build another one.
If we change `PyOperation::create` to return the underlying `PyOperation`, rather than a view wrapper, we can save allocating a useless `PyOpView` object for each ODS-generated Python object.
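The redundant allocation can be illustrated with a small pure-Python model. This is only a sketch: `Operation`, `OpView`, `create_returning_view`, `create_returning_operation`, and `views_allocated` are hypothetical stand-ins written for illustration, not the actual MLIR binding classes, and the real `createOpView` is implemented in C++ and may construct views differently.

```python
class Operation:
    """Hypothetical stand-in for PyOperation: the underlying operation."""

    def __init__(self, name):
        self.name = name

    @property
    def operation(self):
        # A raw operation is its own underlying operation.
        return self


class OpView:
    """Hypothetical stand-in for PyOpView: wraps an operation or a view."""

    allocations = 0  # counts every view object that gets initialized

    def __init__(self, wrapped):
        OpView.allocations += 1
        # Accept either a raw operation or another view and unwrap it,
        # loosely modelling "accepts any PyOperationBase subclass".
        self.operation = wrapped.operation


def create_returning_view(name):
    """Models the old behaviour: build the op, then wrap it in a new view."""
    return OpView(Operation(name))


def create_returning_operation(name):
    """Models the new behaviour: return the underlying operation itself."""
    return Operation(name)


class ConstantOp(OpView):
    """Hypothetical stand-in for an ODS-generated op class."""

    def __init__(self, create=create_returning_operation):
        # The subclass is already a view; it only needs something to wrap.
        super().__init__(create("arith.constant"))


def views_allocated(create):
    """Count the view objects initialized while building one ConstantOp."""
    OpView.allocations = 0
    op = ConstantOp(create)
    assert op.operation.name == "arith.constant"
    return OpView.allocations
```

In this model, `views_allocated(create_returning_view)` is 2 (the intermediate view plus the `ConstantOp` itself), while `views_allocated(create_returning_operation)` is 1. The intermediate view is discarded as soon as the constructor unwraps it.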
This saves approximately 1.5s of Python time in a JAX LLM benchmark that generates a mixture of upstream dialects and StableHLO.
Flame graph for calls to `arith_ops_gen.ConstantOp` in that benchmark before:
<img width="2672" alt="image" src="https://github.com/user-attachments/assets/3e8bfa8e-af58-42b6-9545-9e11fc9c35d6">
and after:
<img width="2675" alt="image" src="https://github.com/user-attachments/assets/adc92d96-f26e-4001-818c-984cf048f382">
---
Full diff: https://github.com/llvm/llvm-project/pull/114542.diff
1 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+1-1)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c12f75e7d224a8..3562ff38201dc3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1534,7 +1534,7 @@ py::object PyOperation::create(const std::string &name,
       PyOperation::createDetached(location->getContext(), operation);
   maybeInsertOperation(created, maybeIp);
 
-  return created->createOpView();
+  return created.getObject();
 }
 
 py::object PyOperation::clone(const py::object &maybeIp) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/114542