[Mlir-commits] [mlir] a974667 - [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (#110656)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 1 13:48:04 PDT 2024
Author: Mateusz Sokół
Date: 2024-10-01T16:48:00-04:00
New Revision: a9746675a505bc891c97dfcd1dbb480cf93116d5
URL: https://github.com/llvm/llvm-project/commit/a9746675a505bc891c97dfcd1dbb480cf93116d5
DIFF: https://github.com/llvm/llvm-project/commit/a9746675a505bc891c97dfcd1dbb480cf93116d5.diff
LOG: [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (#110656)
Hi @xurui1995 @makslevental,
I think https://github.com/llvm/llvm-project/pull/103087 introduced an
unintended regression: users can no longer create sparse tensors
with `tensor.empty`.
Previously I could pass:
```python
out = tensor.empty(tensor_type, [])
```
where `tensor_type` contained `shape`, `dtype`, and `encoding`.
With the latest signature
```python
tensor.empty(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None)
```
it's no longer possible.
I propose to add `encoding` argument which is passed to
`RankedTensorType.get(static_sizes, element_type, encoding)` (I updated
one of the tests to check it).
Added:
Modified:
mlir/python/mlir/dialects/tensor.py
mlir/test/python/dialects/sparse_tensor/dialect.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 0b30d102099088..146b5f85d07f53 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -1,6 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
from ._tensor_ops_gen import *
from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
+ encoding: Optional[Attribute] = None,
loc=None,
ip=None,
):
@@ -40,7 +42,7 @@ def __init__(
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
- result_type = RankedTensorType.get(static_sizes, element_type)
+ result_type = RankedTensorType.get(static_sizes, element_type, encoding)
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
@@ -48,11 +50,14 @@ def empty(
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
+ encoding: Optional[Attribute] = None,
loc=None,
ip=None,
) -> _ods_cext.ir.Value:
return _get_op_result_or_op_results(
- EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+ EmptyOp(
+ sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip
+ )
)
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 3cc4575eb3e240..656979f3d9a1df 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -1,7 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-from mlir.dialects import sparse_tensor as st
+from mlir.dialects import sparse_tensor as st, tensor
import textwrap
@@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType():
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
print(tt.encoding)
assert tt.encoding == encoding
+
+
+# CHECK-LABEL: TEST: testEncodingEmptyTensor
+ at run
+def testEncodingEmptyTensor():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ levels = [st.LevelFormat.compressed]
+ ordering = AffineMap.get_permutation([0])
+ encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
+ tensor.empty((1024,), F32Type.get(), encoding=encoding)
+
+ # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
+ # CHECK: module {
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
+ # CHECK: }
+ print(module)
More information about the Mlir-commits
mailing list