[Mlir-commits] [mlir] [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (PR #110656)

Mateusz Sokół llvmlistbot at llvm.org
Tue Oct 1 06:42:22 PDT 2024


https://github.com/mtsokol updated https://github.com/llvm/llvm-project/pull/110656

>From d2b4e9e96ab157e1a317cdd92da137c845c2a8e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <mat646 at gmail.com>
Date: Tue, 1 Oct 2024 12:41:38 +0000
Subject: [PATCH] Add `encoding` argument to `tensor.empty` Python function

---
 mlir/python/mlir/dialects/tensor.py           |  7 +++++--
 .../python/dialects/sparse_tensor/dialect.py  | 20 ++++++++++++++++++-
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 0b30d102099088..a1a9fd6eceb3e6 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -1,6 +1,7 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
 
 from ._tensor_ops_gen import *
 from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
         sizes: Sequence[Union[int, Value]],
         element_type: Type,
         *,
+        encoding: Optional[Attribute] = None,
         loc=None,
         ip=None,
     ):
@@ -40,7 +42,7 @@ def __init__(
             else:
                 static_sizes.append(ShapedType.get_dynamic_size())
                 dynamic_sizes.append(s)
-        result_type = RankedTensorType.get(static_sizes, element_type)
+        result_type = RankedTensorType.get(static_sizes, element_type, encoding)
         super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
 
 
@@ -48,11 +50,12 @@ def empty(
     sizes: Sequence[Union[int, Value]],
     element_type: Type,
     *,
+    encoding: Optional[Attribute] = None,
     loc=None,
     ip=None,
 ) -> _ods_cext.ir.Value:
     return _get_op_result_or_op_results(
-        EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+        EmptyOp(sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip)
     )
 
 
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 3cc4575eb3e240..656979f3d9a1df 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -1,7 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-from mlir.dialects import sparse_tensor as st
+from mlir.dialects import sparse_tensor as st, tensor
 import textwrap
 
 
@@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType():
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
         print(tt.encoding)
         assert tt.encoding == encoding
+
+
+# CHECK-LABEL: TEST: testEncodingEmptyTensor
+@run
+def testEncodingEmptyTensor():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            levels = [st.LevelFormat.compressed]
+            ordering = AffineMap.get_permutation([0])
+            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
+            tensor.empty((1024,), F32Type.get(), encoding=encoding)
+
+        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
+        # CHECK: module {
+        # CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
+        # CHECK: }
+        print(module)



More information about the Mlir-commits mailing list