[Mlir-commits] [mlir] 4eefc8d - [MLIR][Python] enhance python api for tensor.empty (#103087)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Aug 18 18:06:52 PDT 2024
Author: Bimo
Date: 2024-08-19T09:06:48+08:00
New Revision: 4eefc8d4cee1808f44710622c8c3b66281feb8a3
URL: https://github.com/llvm/llvm-project/commit/4eefc8d4cee1808f44710622c8c3b66281feb8a3
DIFF: https://github.com/llvm/llvm-project/commit/4eefc8d4cee1808f44710622c8c3b66281feb8a3.diff
LOG: [MLIR][Python] enhance python api for tensor.empty (#103087)
Since `EmptyOp` has been extended, it makes sense to also provide a
corresponding lowercase `tensor.empty` function. In downstream code, I tend
to use the all-lowercase APIs to create ops, so having `tensor.empty` as a
replacement for the extended `tensor.EmptyOp` keeps my code style
consistent.
Added:
Modified:
mlir/python/mlir/dialects/tensor.py
mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
mlir/test/python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 79dd9476ad0ff9..0b30d102099088 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -13,6 +13,7 @@
from typing import Sequence, Union
from ._ods_common import _cext as _ods_cext
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -43,6 +44,18 @@ def __init__(
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
+def empty(
+ sizes: Sequence[Union[int, Value]],
+ element_type: Type,
+ *,
+ loc=None,
+ ip=None,
+) -> _ods_cext.ir.Value:
+ return _get_op_result_or_op_results(
+ EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+ )
+
+
generate = region_op(
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
terminator=lambda args: YieldOp(args[0]),
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 18c237c68081a1..64df4e1276222f 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -63,8 +63,8 @@ def matmul_poly(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
)
def test_matmul_mono(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- return matmul_mono(lhs, rhs, outs=[init_result.result])
+ init_result = tensor.empty([4, 8], f32)
+ return matmul_mono(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_i8i8i32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b147551c2e73db..3bfbcf7d7f7c81 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -97,7 +97,7 @@ def testNamedStructuredOpGenericForm():
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
)
def named_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
+ init_result = tensor.empty([4, 8], f32)
# CHECK: "linalg.matmul"(%{{.*}})
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
# CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
@@ -106,7 +106,7 @@ def named_form(lhs, rhs):
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
- return linalg.matmul(lhs, rhs, outs=[init_result.result])
+ return linalg.matmul(lhs, rhs, outs=[init_result])
module.operation.print(print_generic_op_form=True)
More information about the Mlir-commits mailing list