[Mlir-commits] [mlir] [mlir:python] Small optimization to get_op_result_or_results. (PR #123866)
Peter Hawkins
llvmlistbot at llvm.org
Tue Jan 21 17:48:28 PST 2025
https://github.com/hawkinsp created https://github.com/llvm/llvm-project/pull/123866
* We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway.
* If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects.
This saves a few hundred milliseconds during IR construction in a JAX LLM benchmark.
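For readers without the MLIR Python bindings at hand, the control flow described in the two bullets above can be sketched in isolation. This is only an illustration under stated assumptions: `FakeOperation`, `FakeOpView`, and `result_or_results` are hypothetical stand-ins invented here (not part of `_cext.ir`), and plain strings play the role of `OpResult` values.

```python
from typing import Sequence


class FakeOperation:
    """Hypothetical stand-in for an Operation; it only models `.results`."""

    def __init__(self, results: Sequence[str]):
        self.results = list(results)


class FakeOpView:
    """Hypothetical stand-in for an OpView that wraps an operation."""

    def __init__(self, operation: FakeOperation):
        self.operation = operation

    @property
    def results(self):
        # Like the real OpView, expose the wrapped operation's results directly.
        return self.operation.results


def result_or_results(op):
    """Mirror the branch order of the patched function on the fake types."""
    # `.results` is available on both kinds of object, so no type check is
    # needed on the common path where the op has at least one result.
    results = op.results
    num_results = len(results)
    if num_results == 1:
        return results[0]
    elif num_results > 1:
        return results
    elif isinstance(op, FakeOpView):
        # Zero results: unwrap the view so callers get the operation itself.
        return op.operation
    else:
        return op
```

The `isinstance` check is only reached for zero-result ops, which is where the description expects the saving to come from: ops with results never pay for it.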
From 06deeea9aacda9af9091385f7c74c649557215c1 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Wed, 22 Jan 2025 01:42:09 +0000
Subject: [PATCH] [mlir:python] Small optimization to get_op_result_or_results.
* We can call .results without figuring out whether we have an Operation
or an OpView, and that's likely the common case anyway.
* If we have one or more results, we can return them directly, with no
need for a call to get_op_result_or_value. We're guaranteed that
.results returns a PyOpResultList, so we have either an OpResult or
sequence of OpResults, just as the API expects.
This saves a few 100ms during IR construction in an LLM JAX benchmark.
---
mlir/python/mlir/dialects/_ods_common.py | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..5b67ab03d6f494 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -133,15 +133,16 @@ def get_op_results_or_values(
 def get_op_result_or_op_results(
     op: _Union[_cext.ir.OpView, _cext.ir.Operation],
 ) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
-    if isinstance(op, _cext.ir.OpView):
-        op = op.operation
-    return (
-        list(get_op_results_or_values(op))
-        if len(op.results) > 1
-        else get_op_result_or_value(op)
-        if len(op.results) > 0
-        else op
-    )
+    results = op.results
+    num_results = len(results)
+    if num_results == 1:
+        return results[0]
+    elif num_results > 1:
+        return results
+    elif isinstance(op, _cext.ir.OpView):
+        return op.operation
+    else:
+        return op
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]