[Mlir-commits] [mlir] a974667 - [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (#110656)

llvmlistbot at llvm.org
Tue Oct 1 13:48:04 PDT 2024


Author: Mateusz Sokół
Date: 2024-10-01T16:48:00-04:00
New Revision: a9746675a505bc891c97dfcd1dbb480cf93116d5

URL: https://github.com/llvm/llvm-project/commit/a9746675a505bc891c97dfcd1dbb480cf93116d5
DIFF: https://github.com/llvm/llvm-project/commit/a9746675a505bc891c97dfcd1dbb480cf93116d5.diff

LOG: [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (#110656)

Hi @xurui1995 @makslevental,

I think https://github.com/llvm/llvm-project/pull/103087 introduced an
unintended regression where users can no longer create sparse tensors
with `tensor.empty`.

Previously I could pass:
```python
out = tensor.empty(tensor_type, [])
```
where `tensor_type` contained `shape`, `dtype`, and `encoding`.

With the latest signature
```python
tensor.empty(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None)
```
it's no longer possible.

I propose adding an `encoding` argument that is passed through to
`RankedTensorType.get(static_sizes, element_type, encoding)` (I updated
one of the tests to check it).
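A minimal pure-Python sketch of the forwarding being proposed. The `RankedTensorType` class below is a mock that only records its arguments, not the real `mlir.ir.RankedTensorType`, so the snippet runs without the MLIR bindings; the point is just that `encoding` is threaded through unchanged:

```python
from typing import Optional, Sequence

class RankedTensorType:
    """Mock of mlir.ir.RankedTensorType: records what it was built with."""
    def __init__(self, shape, element_type, encoding):
        self.shape = shape
        self.element_type = element_type
        self.encoding = encoding

    @classmethod
    def get(cls, shape, element_type, encoding=None):
        return cls(shape, element_type, encoding)

def empty(sizes: Sequence[int], element_type: str, *,
          encoding: Optional[str] = None) -> RankedTensorType:
    """Sketch of the patched tensor.empty: the new keyword-only `encoding`
    is forwarded as-is to RankedTensorType.get."""
    return RankedTensorType.get(list(sizes), element_type, encoding)

# Without an encoding the result type stays dense (encoding=None);
# with one, the attribute lands on the result type.
dense = empty([1024], "f32")
sparse = empty([1024], "f32", encoding="#sparse")
```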

Added: 
    

Modified: 
    mlir/python/mlir/dialects/tensor.py
    mlir/test/python/dialects/sparse_tensor/dialect.py

Removed: 
    


################################################################################
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 0b30d102099088..146b5f85d07f53 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -1,6 +1,7 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
 
 from ._tensor_ops_gen import *
 from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
         sizes: Sequence[Union[int, Value]],
         element_type: Type,
         *,
+        encoding: Optional[Attribute] = None,
         loc=None,
         ip=None,
     ):
@@ -40,7 +42,7 @@ def __init__(
             else:
                 static_sizes.append(ShapedType.get_dynamic_size())
                 dynamic_sizes.append(s)
-        result_type = RankedTensorType.get(static_sizes, element_type)
+        result_type = RankedTensorType.get(static_sizes, element_type, encoding)
         super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
 
 
@@ -48,11 +50,14 @@ def empty(
     sizes: Sequence[Union[int, Value]],
     element_type: Type,
     *,
+    encoding: Optional[Attribute] = None,
     loc=None,
     ip=None,
 ) -> _ods_cext.ir.Value:
     return _get_op_result_or_op_results(
-        EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+        EmptyOp(
+            sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip
+        )
     )
 
 

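The changed line above sits at the end of a loop that splits the requested sizes into static dimensions and dynamic-size SSA values. That splitting can be sketched in plain Python; the `Value` class and the `DYNAMIC` sentinel are stand-ins for the MLIR bindings (`mlir.ir.Value` and `ShapedType.get_dynamic_size()`, which returns INT64_MIN), not the real API:

```python
from typing import List, Sequence, Tuple, Union

class Value:
    """Stand-in for mlir.ir.Value; here it only marks a dynamic size."""
    pass

# Stand-in for ShapedType.get_dynamic_size() (ShapedType::kDynamic, INT64_MIN).
DYNAMIC = -(2**63)

def split_sizes(
    sizes: Sequence[Union[int, Value]]
) -> Tuple[List[int], List[Value]]:
    """Mimic EmptyOp.__init__: ints become static dims; anything else is
    recorded as a dynamic dim plus an operand for the op."""
    static_sizes: List[int] = []
    dynamic_sizes: List[Value] = []
    for s in sizes:
        if isinstance(s, int):
            static_sizes.append(s)
        else:
            static_sizes.append(DYNAMIC)
            dynamic_sizes.append(s)
    return static_sizes, dynamic_sizes

v = Value()
static, dynamic = split_sizes([1024, v, 8])
```

The static list (with `encoding`, after this patch) feeds `RankedTensorType.get`, while the dynamic list becomes the op's operands.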
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 3cc4575eb3e240..656979f3d9a1df 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -1,7 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-from mlir.dialects import sparse_tensor as st
+from mlir.dialects import sparse_tensor as st, tensor
 import textwrap
 
 
@@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType():
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
         print(tt.encoding)
         assert tt.encoding == encoding
+
+
+# CHECK-LABEL: TEST: testEncodingEmptyTensor
+@run
+def testEncodingEmptyTensor():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            levels = [st.LevelFormat.compressed]
+            ordering = AffineMap.get_permutation([0])
+            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
+            tensor.empty((1024,), F32Type.get(), encoding=encoding)
+
+        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
+        # CHECK: module {
+        # CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
+        # CHECK: }
+        print(module)


        

