[Mlir-commits] [mlir] 90f22ab - [mlir][sparse][taco] Add support for scalar tensors.
Bixia Zheng
llvmlistbot at llvm.org
Fri Feb 25 07:20:19 PST 2022
Author: Bixia Zheng
Date: 2022-02-25T07:20:15-08:00
New Revision: 90f22ab3adcf2e516e7a15c870c30eb0f47fbc20
URL: https://github.com/llvm/llvm-project/commit/90f22ab3adcf2e516e7a15c870c30eb0f47fbc20
DIFF: https://github.com/llvm/llvm-project/commit/90f22ab3adcf2e516e7a15c870c30eb0f47fbc20.diff
LOG: [mlir][sparse][taco] Add support for scalar tensors.
This change allows scalar tensors to be used in tensor index expressions by
indexing them with 0 (for example, S[0]). In this case, the scalar value is
broadcast to match the dimensions of the other tensors in the same expression.
Scalar tensors are always treated as dense and carry no sparse encoding.
Using scalar tensors as a destination in tensor index expressions is not
supported in the PyTACO DSL.
Add a PyTACO test to show the use of scalar tensors.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D120524
Added:
mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
Modified:
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
new file mode 100644
index 0000000000000..d559943c2c3ec
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
@@ -0,0 +1,28 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+S = pt.tensor(3) # S is a scalar tensor.
+B = pt.tensor([2, 3], compressed)
+A.insert([0, 1], 10)
+A.insert([1, 2], 40)
+
+# Use [0] to index the scalar tensor.
+B[i, j] = A[i, j] * S[0]
+
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [1, 2]])
+passed += np.array_equal(values, [30.0, 120.0])
+
+# CHECK: Number of passed: 2
+print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 89c4687f06d41..ed73ab13dd8d0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -447,27 +447,6 @@ def _mlir_tensor_type(
return ir.RankedTensorType.get(shape, ir_type, attr)
-def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
- """Verifies and normalizes the indices for a tensor access.
-
- Args:
- indices: The index expression used to access a tensor, which could be any
- Python object from user inputs.
-
- Returns:
- A tuple of IndexVar.
-
- Raises:
- ValueError: If indices is not an IndexVar or a tuple of IndexVar.
- """
- if isinstance(indices, IndexVar):
- return (indices,)
- elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
- return indices
-
- raise ValueError(f"Expected IndexVars: {indices}")
-
-
@dataclasses.dataclass(frozen=True)
class _StructOpInfo:
"""Information for generating a structured op in the linalg dialect.
@@ -761,7 +740,7 @@ def insert(self, coords: List[int], val: Union[float, int]) -> None:
def is_dense(self) -> bool:
"""Returns true if the tensor doesn't have sparsity annotation."""
- return self._format is None
+ return self.order == 0 or self._format is None
def to_array(self) -> np.ndarray:
"""Returns the numpy array for the Tensor.
@@ -918,6 +897,32 @@ def shape(self) -> List[int]:
"""Returns the shape of the Tensor."""
return self._shape
+ def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+ """Verifies and normalizes the indices to access the tensor.
+
+ Args:
+ indices: The index expression used to access a tensor, which could be any
+ Python object from user inputs.
+
+ Returns:
+ A tuple of IndexVar.
+
+ Raises:
+ ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+ a tuple of IndexVar for other tensors.
+ """
+ if self.order == 0:
+ if not isinstance(indices, int) or indices != 0:
+ raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+ return ()
+
+ if isinstance(indices, IndexVar):
+ return (indices,)
+ elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+ return indices
+
+ raise ValueError(f"Expected IndexVars: {indices}")
+
def __getitem__(self, key) -> "Access":
"""Verifies and processes a tensor access.
@@ -936,7 +941,7 @@ def __getitem__(self, key) -> "Access":
Raises:
ValueError: If key is not an IndexVar or a tuple of IndexVar.
"""
- indices = _verify_and_normalize_indices(key)
+ indices = self._verify_and_normalize_indices(key)
return Access(self, indices)
def __setitem__(self, key, value) -> None:
@@ -960,7 +965,7 @@ def __setitem__(self, key, value) -> None:
or a tuple of IndexVar, or the length of the indices is not the same as
the rank of the tensor.
"""
- indices = _verify_and_normalize_indices(key)
+ indices = self._verify_and_normalize_indices(key)
if len(indices) != self.order:
raise ValueError("Mismatch between indices and tensor rank: "
f"len({indices}) != {self.order}.")
@@ -985,8 +990,8 @@ def _sync_value(self) -> None:
def mlir_tensor_type(self) -> ir.RankedTensorType:
"""Returns the MLIR type for the tensor."""
- mlir_attr = None if (
- self._format is None) else self._format.mlir_tensor_attr()
+ mlir_attr = (None if (self._format is None or self.order == 0) else
+ self._format.mlir_tensor_attr())
return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
def dense_dst_ctype_pointer(self) -> ctypes.pointer:
More information about the Mlir-commits
mailing list