[Mlir-commits] [mlir] [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (PR #110656)
Mateusz Sokół
llvmlistbot at llvm.org
Tue Oct 1 06:42:22 PDT 2024
https://github.com/mtsokol updated https://github.com/llvm/llvm-project/pull/110656
>From d2b4e9e96ab157e1a317cdd92da137c845c2a8e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <mat646 at gmail.com>
Date: Tue, 1 Oct 2024 12:41:38 +0000
Subject: [PATCH] Add `encoding` argument to `tensor.empty` Python function
---
mlir/python/mlir/dialects/tensor.py | 7 +++++--
.../python/dialects/sparse_tensor/dialect.py | 20 ++++++++++++++++++-
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 0b30d102099088..a1a9fd6eceb3e6 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -1,6 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
from ._tensor_ops_gen import *
from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
+ encoding: Optional[Attribute] = None,
loc=None,
ip=None,
):
@@ -40,7 +42,7 @@ def __init__(
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
- result_type = RankedTensorType.get(static_sizes, element_type)
+ result_type = RankedTensorType.get(static_sizes, element_type, encoding)
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
@@ -48,11 +50,12 @@ def empty(
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
+ encoding: Optional[Attribute] = None,
loc=None,
ip=None,
) -> _ods_cext.ir.Value:
return _get_op_result_or_op_results(
- EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
+ EmptyOp(sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip)
)
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 3cc4575eb3e240..656979f3d9a1df 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -1,7 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-from mlir.dialects import sparse_tensor as st
+from mlir.dialects import sparse_tensor as st, tensor
import textwrap
@@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType():
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
print(tt.encoding)
assert tt.encoding == encoding
+
+
+# CHECK-LABEL: TEST: testEncodingEmptyTensor
+@run
+def testEncodingEmptyTensor():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            levels = [st.LevelFormat.compressed]
+            ordering = AffineMap.get_permutation([0])
+            encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
+            tensor.empty((1024,), F32Type.get(), encoding=encoding)
+
+        # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
+        # CHECK: module {
+        # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
+        # CHECK: }
+        print(module)
More information about the Mlir-commits
mailing list