[Mlir-commits] [mlir] 3580721 - [mlir][sparse][taco] Support the use of index values in tensor expressions.
Bixia Zheng
llvmlistbot at llvm.org
Tue Mar 15 15:31:00 PDT 2022
Author: Bixia Zheng
Date: 2022-03-15T15:30:55-07:00
New Revision: 3580721a59d9a6db365978d176044dc21a7eb3fa
URL: https://github.com/llvm/llvm-project/commit/3580721a59d9a6db365978d176044dc21a7eb3fa
DIFF: https://github.com/llvm/llvm-project/commit/3580721a59d9a6db365978d176044dc21a7eb3fa.diff
LOG: [mlir][sparse][taco] Support the use of index values in tensor expressions.
The PyTACO DSL doesn't support the use of index values in tensor expressions, as in A[i] = B[i] + i.
We extend the DSL to support such uses in MLIR-PyTACO.
Remove an obsolete unit test. Add unit tests and PyTACO tests.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D121716
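With this change, a bare IndexVar can appear on the right-hand side of a
tensor expression. A minimal usage sketch (assuming the PyTACO API is
imported as pt, as the in-tree integration tests do), mirroring the tests
added below:

  i, = pt.get_index_vars(1)

  E = pt.tensor([3])
  E[i] = i                 # E holds the index values [0, 1, 2]

  F = pt.tensor([3])
  G = pt.tensor([3])
  F.insert([0], 10)
  F.insert([2], 40)
  G[i] = F[i] + i          # G holds [10, 1, 42]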
Added:
Modified:
mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
index d02bdce285ced..641d1afa01b04 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_simple_tensor_algebra.py
@@ -31,5 +31,31 @@
passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 5.0, 63.0])
-# CHECK: Number of passed: 2
+# PyTACO doesn't allow the use of index values, but MLIR-PyTACO removes this
+# restriction.
+E = pt.tensor([3])
+E[i] = i
+indices, values = E.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [1], [2]])
+passed += np.allclose(values, [0.0, 1.0, 2.0])
+
+F = pt.tensor([3])
+G = pt.tensor([3])
+F.insert([0], 10)
+F.insert([2], 40)
+G[i] = F[i] + i
+indices, values = G.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [1], [2]])
+passed += np.allclose(values, [10.0, 1.0, 42.0])
+
+H = pt.tensor([3])
+I = pt.tensor([3])
+H.insert([0], 10)
+H.insert([2], 40)
+I[i] = H[i] * i
+indices, values = I.get_coordinates_and_values()
+passed += np.array_equal(indices, [[0], [2]])
+passed += np.allclose(values, [0.0, 80.0])
+
+# CHECK: Number of passed: 8
print("Number of passed:", passed)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index b0cb216944833..28b4ccc809f5b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -611,7 +611,7 @@ def _identify_structured_ops(
# _StructOpInfo for the top level expression.
expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
tuple(dst.shape),
- self.dtype(), dst.name,
+ dst.dtype, dst.name,
dst.format)
return structop_roots
@@ -650,7 +650,7 @@ def _validate_and_collect_expr_info(
raise ValueError("Destination IndexVar not used in the "
f"source expression: {i}")
else:
- if d != index_to_dim_info[i].dim:
+ if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
raise ValueError(f"Inconsistent destination dimension for {i}: "
f"{d} vs {index_to_dim_info[i].dim}")
@@ -739,7 +739,7 @@ def increment(self) -> int:
return old_value
-class IndexVar:
+class IndexVar(IndexExpr):
"""The tensor index class.
We support the TACO API index_var class with an alias of this class.
@@ -763,6 +763,34 @@ def name(self) -> str:
"""Returns the name of the IndexVar."""
return self._name
+ def _visit(self,
+ func: _ExprVisitor,
+ args,
+ *,
+ leaf_checker: _SubtreeLeafChecker = None) -> None:
+ """A post-order visitor."""
+ if leaf_checker:
+ assert leaf_checker(self, *args)
+ func(self, *args)
+
+ def _emit_expression(
+ self,
+ expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+ expr_to_info: _ExprInfoDict,
+ ) -> lang.ScalarExpression:
+ """Emits a index value casted to the data type of the tensor expression."""
+ dim = getattr(lang.D, self.name)
+ index = lang.index(dim)
+ int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
+ return lang.TypeFn.cast_unsigned(lang.T, int_value)
+
+ def dtype(self) -> DType:
+ """Returns the data type for the index value.
+
+ This is unreachable for IndexVar.
+ """
+ assert 0
+
def get_index_vars(n: int) -> List[IndexVar]:
"""Returns a list of n IndexVar.
@@ -1527,6 +1555,11 @@ class _DimInfo:
mode_format: ModeFormat
+def _get_dummy_dim_info() -> _DimInfo:
+ """Constructs the _DimInfo for an index used in tensor expressions."""
+ return _DimInfo(-1, ModeFormat.DENSE)
+
+
@dataclasses.dataclass()
class _ExprInfo:
"""Expression information for validation and code generation.
@@ -1788,9 +1821,12 @@ def _validate_and_collect_dim_info(
if i not in index_to_dim_info:
index_to_dim_info[i] = d
else:
- if d.dim != index_to_dim_info[i].dim:
+ dim = index_to_dim_info[i].dim
+ if dim == -1 or d.dim == -1:
+ dim = dim if dim != -1 else d.dim
+ elif dim != d.dim:
raise ValueError(f"Inconsistent source dimension for {i}: "
- f"{d.dim} vs {index_to_dim_info[i].dim}")
+ f"{d.dim} vs {dim}")
mode_format = _mode_format_estimator(expr.op)(
index_to_dim_info[i].mode_format, d.mode_format)
index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
@@ -1823,7 +1859,10 @@ def _validate_and_collect_expr_info(
if expr in expr_to_info:
return
- if isinstance(expr, Access):
+ if isinstance(expr, IndexVar):
+ src_indices = expr, # A tuple with one element.
+ dim_infos = _get_dummy_dim_info(), # A tuple with one element.
+ elif isinstance(expr, Access):
src_indices = expr.indices
src_dims = tuple(expr.tensor.shape)
if expr.tensor.format is None:
@@ -1883,6 +1922,9 @@ def _mark_structured_op_root(
reduce_index: The IndexVar which we want to find out the proper expression
to perform a reduction.
expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+
+ Raises:
+ ValueError: If the expression is not proper or not supported.
"""
expr_info = expr_to_info[expr]
if isinstance(expr, Access):
@@ -1890,6 +1932,9 @@ def _mark_structured_op_root(
if reduce_index in expr_info.src_indices:
expr_info.reduce_indices.add(reduce_index)
return
+ elif isinstance(expr, IndexVar):
+ # A[i] = B[i] + j is not allowed.
+ raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
assert (isinstance(expr, _BinaryExpr))
a_info = expr_to_info[expr.a]
@@ -1933,6 +1978,11 @@ def _accumulate_reduce_indices(
a_info = expr_to_info[expr.a]
expr_info.acc_reduce_indices = (
a_info.acc_reduce_indices | expr_info.reduce_indices)
+ elif isinstance(expr, IndexVar):
+ # If an IndexVar is being reduced over, it means the IndexVar is outside the
+ # iteration domain. This usage is not allowed and we should have emitted an
+ # error before reaching here.
+ assert not expr_info.reduce_indices
else:
assert isinstance(expr, Access)
# Handle simple reduction expression in the format of A[i] = B[i, j].
@@ -2011,7 +2061,7 @@ def _is_structured_op_leaf(
"""
return (expr != root and
expr_to_info[expr].structop_info is not None) or isinstance(
- expr, Access)
+ expr, Access) or isinstance(expr, IndexVar)
def _gather_structured_op_input(
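For reference, the OpDSL emission added above (linalg.index of the dimension,
cast to i64, then cast to the destination element type) corresponds roughly to
the following standalone structured op. This is an illustrative sketch written
against the MLIR Python OpDSL decorator API, not code from the commit;
fill_with_index is a hypothetical name:

  from mlir.dialects.linalg.opdsl.lang import *

  @linalg_structured_op
  def fill_with_index(O=TensorDef(T, S.N, output=True)):
      # O[n] = cast<T>(cast<i64>(linalg.index(n)))
      O[D.n] = TypeFn.cast_unsigned(T, TypeFn.cast_unsigned(TV.I64, index(D.n)))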
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 1390f24d1027a..f6289e537ed13 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -201,14 +201,6 @@ def test_index_var():
C[i, j] = A[1, j] + B[i, j]
"""), "Expected IndexVars")
-# CHECK: test_invalid_operation: passed
-test_expect_error("invalid_operation", ("""
-i, j = mlir_pytaco.get_index_vars(2)
-A = mlir_pytaco.Tensor([2, 3])
-C = mlir_pytaco.Tensor([2, 3], _DENSE)
-C[i, j] = A[i, j] + i
- """), "Expected IndexExpr")
-
# CHECK: test_inconsistent_rank_indices: passed
test_expect_error("inconsistent_rank_indices", ("""
i, j = mlir_pytaco.get_index_vars(2)
@@ -245,6 +237,15 @@ def test_index_var():
C.evaluate()
"""), "Inconsistent source dimension for IndexVar")
+# CHECK: test_index_var_outside_domain: passed
+test_expect_error("index_var_outside_domain", ("""
+i, j = mlir_pytaco.get_index_vars(2)
+A = mlir_pytaco.Tensor([3])
+B = mlir_pytaco.Tensor([3])
+B[i] = A[i] + j
+B.evaluate()
+ """), "IndexVar is not part of the iteration domain")
+
# CHECK-LABEL: test_tensor_all_dense_sparse
@testing_utils.run_test
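One note on the new negative test above: in B[i] = A[i] + j, the IndexVar j is
not bound by the destination, so it would have to be reduced over, and reducing
a bare index value has no defined meaning. MLIR-PyTACO therefore rejects it at
evaluation time. A hypothetical session (assuming mlir_pytaco is importable as
in the unit tests):

  i, j = mlir_pytaco.get_index_vars(2)
  A = mlir_pytaco.Tensor([3])
  B = mlir_pytaco.Tensor([3])
  try:
      B[i] = A[i] + j
      B.evaluate()
  except ValueError as e:
      print(e)   # IndexVar is not part of the iteration domain: ...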