[Mlir-commits] [mlir] 90f22ab - [mlir][sparse][taco] Add support for scalar tensors.

Bixia Zheng llvmlistbot at llvm.org
Fri Feb 25 07:20:19 PST 2022


Author: Bixia Zheng
Date: 2022-02-25T07:20:15-08:00
New Revision: 90f22ab3adcf2e516e7a15c870c30eb0f47fbc20

URL: https://github.com/llvm/llvm-project/commit/90f22ab3adcf2e516e7a15c870c30eb0f47fbc20
DIFF: https://github.com/llvm/llvm-project/commit/90f22ab3adcf2e516e7a15c870c30eb0f47fbc20.diff

LOG: [mlir][sparse][taco] Add support for scalar tensors.

This change allows scalar tensors, indexed with 0, to be used in tensor index
expressions. In this case, the scalar value is broadcast to match the
dimensions of the other tensors in the same expression.

Using scalar tensors as a destination in tensor index expressions is not
supported in the PyTACO DSL.

Add a PyTACO test to show the use of scalar tensors.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120524

Added: 
    mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py

Modified: 
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
new file mode 100644
index 0000000000000..d559943c2c3ec
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
@@ -0,0 +1,28 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+
+i, j = pt.get_index_vars(2)
+A = pt.tensor([2, 3])
+S = pt.tensor(3) # S is a scalar tensor.
+B = pt.tensor([2, 3], compressed)
+A.insert([0, 1], 10)
+A.insert([1, 2], 40)
+
+# Use [0] to index the scalar tensor.
+B[i, j] = A[i, j] * S[0]
+
+indices, values = B.get_coordinates_and_values()
+passed = np.array_equal(indices, [[0, 1], [1, 2]])
+passed += np.array_equal(values, [30.0, 120.0])
+
+# CHECK: Number of passed: 2
+print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 89c4687f06d41..ed73ab13dd8d0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -447,27 +447,6 @@ def _mlir_tensor_type(
   return ir.RankedTensorType.get(shape, ir_type, attr)
 
 
-def _verify_and_normalize_indices(indices) -> Tuple[IndexVar, ...]:
-  """Verifies and normalizes the indices for a tensor access.
-
-  Args:
-    indices: The index expression used to access a tensor, which could be any
-      Python object from user inputs.
-
-  Returns:
-    A tuple of IndexVar.
-
-  Raises:
-    ValueError: If indices is not an IndexVar or a tuple of IndexVar.
-  """
-  if isinstance(indices, IndexVar):
-    return (indices,)
-  elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
-    return indices
-
-  raise ValueError(f"Expected IndexVars: {indices}")
-
-
 @dataclasses.dataclass(frozen=True)
 class _StructOpInfo:
   """Information for generating a structured op in the linalg dialect.
@@ -761,7 +740,7 @@ def insert(self, coords: List[int], val: Union[float, int]) -> None:
 
   def is_dense(self) -> bool:
     """Returns true if the tensor doesn't have sparsity annotation."""
-    return self._format is None
+    return self.order == 0 or self._format is None
 
   def to_array(self) -> np.ndarray:
     """Returns the numpy array for the Tensor.
@@ -918,6 +897,32 @@ def shape(self) -> List[int]:
     """Returns the shape of the Tensor."""
     return self._shape
 
+  def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+    """Verifies and normalizes the indices to access the tensor.
+
+    Args:
+      indices: The index expression used to access a tensor, which could be any
+        Python object from user inputs.
+
+    Returns:
+      A tuple of IndexVar.
+
+    Raises:
+      ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+        a tuple of IndexVar for other tensors.
+    """
+    if self.order == 0:
+      if not isinstance(indices, int) or indices != 0:
+        raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+      return ()
+
+    if isinstance(indices, IndexVar):
+      return (indices,)
+    elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+      return indices
+
+    raise ValueError(f"Expected IndexVars: {indices}")
+
   def __getitem__(self, key) -> "Access":
     """Verifies and processes a tensor access.
 
@@ -936,7 +941,7 @@ def __getitem__(self, key) -> "Access":
     Raises:
       ValueError: If key is not an IndexVar or a tuple of IndexVar.
     """
-    indices = _verify_and_normalize_indices(key)
+    indices = self._verify_and_normalize_indices(key)
     return Access(self, indices)
 
   def __setitem__(self, key, value) -> None:
@@ -960,7 +965,7 @@ def __setitem__(self, key, value) -> None:
         or a tuple of IndexVar, or the length of the indices is not the same as
         the rank of the tensor.
     """
-    indices = _verify_and_normalize_indices(key)
+    indices = self._verify_and_normalize_indices(key)
     if len(indices) != self.order:
       raise ValueError("Mismatch between indices and tensor rank: "
                        f"len({indices}) != {self.order}.")
@@ -985,8 +990,8 @@ def _sync_value(self) -> None:
 
   def mlir_tensor_type(self) -> ir.RankedTensorType:
     """Returns the MLIR type for the tensor."""
-    mlir_attr = None if (
-        self._format is None) else self._format.mlir_tensor_attr()
+    mlir_attr = (None if (self._format is None or self.order == 0) else
+                 self._format.mlir_tensor_attr())
     return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
 
   def dense_dst_ctype_pointer(self) -> ctypes.pointer:


        


More information about the Mlir-commits mailing list