[Mlir-commits] [mlir] [MLIR][Python] enhance python api for tensor.empty (PR #103087)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 13 07:13:26 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Bimo (xurui1995)

<details>
<summary>Changes</summary>

Since `EmptyOp` is already extended in the Python bindings, this PR adds a corresponding lowercase `tensor.empty` function that builds the op and returns its result value. In downstream code, I tend to create ops through the all-lowercase APIs, so using `tensor.empty` in place of the extended `tensor.EmptyOp` keeps my code style consistent.
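
To illustrate the difference between the two call styles, here is a minimal, self-contained sketch of the wrapper pattern. It does not import the real MLIR bindings: `Value`, `EmptyOp`, and `get_op_result_or_op_results` below are simplified, hypothetical stand-ins written only for this example, and element types are modeled as plain strings.

```python
from typing import List, Sequence


class Value:
    """Hypothetical stand-in for an SSA value (not the real mlir.ir.Value)."""

    def __init__(self, type_str: str):
        self.type_str = type_str


class EmptyOp:
    """Hypothetical stand-in for the extended tensor.EmptyOp class."""

    def __init__(self, sizes: Sequence[int], element_type: str):
        shape = "x".join(str(s) for s in sizes)
        self.results: List[Value] = [Value(f"tensor<{shape}x{element_type}>")]

    @property
    def result(self) -> Value:
        # Class-style API: callers must unwrap the single result themselves.
        return self.results[0]


def get_op_result_or_op_results(op):
    """Simplified stand-in for the helper imported in the diff."""
    if len(op.results) > 1:
        return list(op.results)
    return op.results[0] if op.results else op


def empty(sizes: Sequence[int], element_type: str) -> Value:
    # Lowercase builder: construct the op and hand back its result directly.
    return get_op_result_or_op_results(EmptyOp(sizes, element_type))


# Class style: an extra `.result` is needed before passing the value on.
op = EmptyOp([4, 8], "f32")
# Builder style: the returned object is already the value.
value = empty([4, 8], "f32")

print(op.result.type_str)
print(value.type_str)
```

This mirrors the test changes in the diff, where `outs=[init_result.result]` becomes `outs=[init_result]` once `tensor.empty` is used.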

---
Full diff: https://github.com/llvm/llvm-project/pull/103087.diff


3 Files Affected:

- (modified) mlir/python/mlir/dialects/tensor.py (+13) 
- (modified) mlir/test/python/dialects/linalg/opdsl/emit_matmul.py (+2-2) 
- (modified) mlir/test/python/dialects/linalg/ops.py (+2-2) 


``````````diff
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 79dd9476ad0ff9..0b30d102099088 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -13,6 +13,7 @@
 
 from typing import Sequence, Union
 from ._ods_common import _cext as _ods_cext
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -43,6 +44,18 @@ def __init__(
         super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
 
 
+def empty(
+    sizes: Sequence[Union[int, Value]],
+    element_type: Type,
+    *,
+    loc=None,
+    ip=None,
+) -> _ods_cext.ir.Value:
+    return _get_op_result_or_op_results(
+        EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+    )
+
+
 generate = region_op(
     lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
     terminator=lambda args: YieldOp(args[0]),
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 18c237c68081a1..64df4e1276222f 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -63,8 +63,8 @@ def matmul_poly(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
         )
         def test_matmul_mono(lhs, rhs):
-            init_result = tensor.EmptyOp([4, 8], f32)
-            return matmul_mono(lhs, rhs, outs=[init_result.result])
+            init_result = tensor.empty([4, 8], f32)
+            return matmul_mono(lhs, rhs, outs=[init_result])
 
         # CHECK-LABEL: @test_i8i8i32_matmul
         # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b147551c2e73db..3bfbcf7d7f7c81 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -97,7 +97,7 @@ def testNamedStructuredOpGenericForm():
                 RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
             )
             def named_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
+                init_result = tensor.empty([4, 8], f32)
                 #      CHECK: "linalg.matmul"(%{{.*}})
                 # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
                 # CHECK-SAME:    operandSegmentSizes = array<i32: 2, 1>
@@ -106,7 +106,7 @@ def named_form(lhs, rhs):
                 # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
                 # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
                 # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result.result])
+                return linalg.matmul(lhs, rhs, outs=[init_result])
 
     module.operation.print(print_generic_op_form=True)
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/103087

