[Mlir-commits] [mlir] [MLIR][Python] enhance python api for tensor.empty (PR #103087)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 13 07:12:55 PDT 2024
https://github.com/xurui1995 created https://github.com/llvm/llvm-project/pull/103087
Since we have already extended `EmptyOp`, it may make sense to also provide a corresponding `tensor.empty` function. In downstream usage, I tend to create ops through the all-lowercase APIs, so having `tensor.empty` available as a replacement for the extended `tensor.EmptyOp` would keep my code style consistent.
From 9c3ee672ca58ca938498b85463109f30d0aa1827 Mon Sep 17 00:00:00 2001
From: "Xu, Rui" <rui.xu at intel.com>
Date: Tue, 13 Aug 2024 07:04:01 -0700
Subject: [PATCH] enhance python api for tensor.empty
---
mlir/python/mlir/dialects/tensor.py | 13 +++++++++++++
.../python/dialects/linalg/opdsl/emit_matmul.py | 4 ++--
mlir/test/python/dialects/linalg/ops.py | 4 ++--
3 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 79dd9476ad0ff9..0b30d102099088 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -13,6 +13,7 @@
from typing import Sequence, Union
from ._ods_common import _cext as _ods_cext
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -43,6 +44,18 @@ def __init__(
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
+def empty(
+ sizes: Sequence[Union[int, Value]],
+ element_type: Type,
+ *,
+ loc=None,
+ ip=None,
+) -> _ods_cext.ir.Value:
+ return _get_op_result_or_op_results(
+ EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+ )
+
+
generate = region_op(
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
terminator=lambda args: YieldOp(args[0]),
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
index 18c237c68081a1..64df4e1276222f 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
@@ -63,8 +63,8 @@ def matmul_poly(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
)
def test_matmul_mono(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- return matmul_mono(lhs, rhs, outs=[init_result.result])
+ init_result = tensor.empty([4, 8], f32)
+ return matmul_mono(lhs, rhs, outs=[init_result])
# CHECK-LABEL: @test_i8i8i32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b147551c2e73db..3bfbcf7d7f7c81 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -97,7 +97,7 @@ def testNamedStructuredOpGenericForm():
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
)
def named_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
+ init_result = tensor.empty([4, 8], f32)
# CHECK: "linalg.matmul"(%{{.*}})
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
# CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
@@ -106,7 +106,7 @@ def named_form(lhs, rhs):
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
- return linalg.matmul(lhs, rhs, outs=[init_result.result])
+ return linalg.matmul(lhs, rhs, outs=[init_result])
module.operation.print(print_generic_op_form=True)
More information about the Mlir-commits
mailing list