[Mlir-commits] [mlir] 3580721 - [mlir][sparse][taco] Support the use of index values in tensor expressions.

Bixia Zheng llvmlistbot at llvm.org
Tue Mar 15 15:31:00 PDT 2022


Author: Bixia Zheng
Date: 2022-03-15T15:30:55-07:00
New Revision: 3580721a59d9a6db365978d176044dc21a7eb3fa

URL: https://github.com/llvm/llvm-project/commit/3580721a59d9a6db365978d176044dc21a7eb3fa
DIFF: https://github.com/llvm/llvm-project/commit/3580721a59d9a6db365978d176044dc21a7eb3fa.diff

LOG: [mlir][sparse][taco] Support the use of index values in tensor expressions.

The PyTACO DSL doesn't support using index values in tensor expressions, as in A[i] = B[i] + i.
We extend the DSL to support such uses in MLIR-PyTACO.

Remove an obsolete unit test. Add unit tests and PyTACO tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D121716

Added: 
    

Modified: 
    mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
index d02bdce285ced..641d1afa01b04 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
@@ -31,5 +31,31 @@
 passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
 passed += np.allclose(values, [20.0, 5.0, 63.0])
 
-# CHECK: Number of passed: 2
+# PyTACO doesn't allow the use of index values, but MLIR-PyTACO removes this
+# restriction.
+E = pt.tensor([3])
+E[i] = i
+indices, values = E.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [1], [2]])
+passed += np.allclose(values, [0.0, 1.0, 2.0])
+
+F = pt.tensor([3])
+G = pt.tensor([3])
+F.insert([0], 10)
+F.insert([2], 40)
+G[i] = F[i] + i
+indices, values = G.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [1], [2]])
+passed += np.allclose(values, [10.0, 1.0, 42.0])
+
+H = pt.tensor([3])
+I = pt.tensor([3])
+H.insert([0], 10)
+H.insert([2], 40)
+I[i] = H[i] * i
+indices, values = I.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [2]])
+passed += np.allclose(values, [0.0, 80.0])
+
+# CHECK: Number of passed: 8
 print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index b0cb216944833..28b4ccc809f5b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -611,7 +611,7 @@ def _identify_structured_ops(
     # _StructOpInfo for the top level expression.
     expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
                                                      tuple(dst.shape),
-                                                     self.dtype(), dst.name,
+                                                     dst.dtype, dst.name,
                                                      dst.format)
 
     return structop_roots
@@ -650,7 +650,7 @@ def _validate_and_collect_expr_info(
         raise ValueError("Destination IndexVar not used in the "
                          f"source expression: {i}")
       else:
-        if d != index_to_dim_info[i].dim:
+        if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
           raise ValueError(f"Inconsistent destination dimension for {i}: "
                            f"{d} vs {index_to_dim_info[i].dim}")
 
@@ -739,7 +739,7 @@ def increment(self) -> int:
     return old_value
 
 
-class IndexVar:
+class IndexVar(IndexExpr):
   """The tensor index class.
 
   We support the TACO API index_var class with an alias of this class.
@@ -763,6 +763,34 @@ def name(self) -> str:
     """Returns the name of the IndexVar."""
     return self._name
 
+  def _visit(self,
+             func: _ExprVisitor,
+             args,
+             *,
+             leaf_checker: _SubtreeLeafChecker = None) -> None:
+    """A post-order visitor."""
+    if leaf_checker:
+      assert leaf_checker(self, *args)
+    func(self, *args)
+
+  def _emit_expression(
+      self,
+      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+      expr_to_info: _ExprInfoDict,
+  ) -> lang.ScalarExpression:
+    """Emits an index value cast to the data type of the tensor expression."""
+    dim = getattr(lang.D, self.name)
+    index = lang.index(dim)
+    int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
+    return lang.TypeFn.cast_unsigned(lang.T, int_value)
+
+  def dtype(self) -> DType:
+    """Returns the data type for the index value.
+
+    This is unreachable for IndexVar.
+    """
+    assert 0
+
 
 def get_index_vars(n: int) -> List[IndexVar]:
   """Returns a list of n IndexVar.
@@ -1527,6 +1555,11 @@ class _DimInfo:
   mode_format: ModeFormat
 
 
+def _get_dummy_dim_info() -> _DimInfo:
+  """Constructs a placeholder _DimInfo (dim -1) for an IndexVar operand."""
+  return _DimInfo(-1, ModeFormat.DENSE)
+
+
 @dataclasses.dataclass()
 class _ExprInfo:
   """Expression information for validation and code generation.
@@ -1788,9 +1821,12 @@ def _validate_and_collect_dim_info(
     if i not in index_to_dim_info:
       index_to_dim_info[i] = d
     else:
-      if d.dim != index_to_dim_info[i].dim:
+      dim = index_to_dim_info[i].dim
+      if dim == -1 or d.dim == -1:
+        dim = dim if dim != -1 else d.dim
+      elif dim != d.dim:
         raise ValueError(f"Inconsistent source dimension for {i}: "
-                         f"{d.dim} vs {index_to_dim_info[i].dim}")
+                         f"{d.dim} vs {dim}")
       mode_format = _mode_format_estimator(expr.op)(
           index_to_dim_info[i].mode_format, d.mode_format)
       index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
@@ -1823,7 +1859,10 @@ def _validate_and_collect_expr_info(
   if expr in expr_to_info:
     return
 
-  if isinstance(expr, Access):
+  if isinstance(expr, IndexVar):
+    src_indices = expr,  # A tuple with one element.
+    dim_infos = _get_dummy_dim_info(),  # A tuple with one element.
+  elif isinstance(expr, Access):
     src_indices = expr.indices
     src_dims = tuple(expr.tensor.shape)
     if expr.tensor.format is None:
@@ -1883,6 +1922,9 @@ def _mark_structured_op_root(
     reduce_index: The IndexVar which we want to find out the proper expression
       to perform a reduction.
     expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+
+  Raises:
+      ValueError: If the expression is malformed or not supported.
   """
   expr_info = expr_to_info[expr]
   if isinstance(expr, Access):
@@ -1890,6 +1932,9 @@ def _mark_structured_op_root(
     if reduce_index in expr_info.src_indices:
       expr_info.reduce_indices.add(reduce_index)
     return
+  elif isinstance(expr, IndexVar):
+    # A[i] = B[i] + j is not allowed.
+    raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
 
   assert (isinstance(expr, _BinaryExpr))
   a_info = expr_to_info[expr.a]
@@ -1933,6 +1978,11 @@ def _accumulate_reduce_indices(
     a_info = expr_to_info[expr.a]
     expr_info.acc_reduce_indices = (
         a_info.acc_reduce_indices | expr_info.reduce_indices)
+  elif isinstance(expr, IndexVar):
+    # If an IndexVar is reducing itself, it means the IndexVar is outside the
+    # iteration domain. This usage is not allowed and we should emit an error
+    # before reaching here.
+    assert not expr_info.reduce_indices
   else:
     assert isinstance(expr, Access)
     # Handle simple reduction expression in the format of A[i] = B[i, j].
@@ -2011,7 +2061,7 @@ def _is_structured_op_leaf(
   """
   return (expr != root and
           expr_to_info[expr].structop_info is not None) or isinstance(
-              expr, Access)
+              expr, Access) or isinstance(expr, IndexVar)
 
 
 def _gather_structured_op_input(

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 1390f24d1027a..f6289e537ed13 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -201,14 +201,6 @@ def test_index_var():
 C[i, j] = A[1, j] + B[i, j]
     """), "Expected IndexVars")
 
-# CHECK: test_invalid_operation: passed
-test_expect_error("invalid_operation", ("""
-i, j = mlir_pytaco.get_index_vars(2)
-A = mlir_pytaco.Tensor([2, 3])
-C = mlir_pytaco.Tensor([2, 3], _DENSE)
-C[i, j] = A[i, j] + i
-    """), "Expected IndexExpr")
-
 # CHECK: test_inconsistent_rank_indices: passed
 test_expect_error("inconsistent_rank_indices", ("""
 i, j = mlir_pytaco.get_index_vars(2)
@@ -245,6 +237,15 @@ def test_index_var():
 C.evaluate()
     """), "Inconsistent source dimension for IndexVar")
 
+# CHECK: test_index_var_outside_domain: passed
+test_expect_error("index_var_outside_domain", ("""
+i, j = mlir_pytaco.get_index_vars(2)
+A = mlir_pytaco.Tensor([3])
+B = mlir_pytaco.Tensor([3])
+B[i] = A[i] + j
+B.evaluate()
+    """), "IndexVar is not part of the iteration domain")
+
 
 # CHECK-LABEL: test_tensor_all_dense_sparse
 @testing_utils.run_test


        


More information about the Mlir-commits mailing list