[Mlir-commits] [mlir] 746c68e - [mlir][sparse][taco] Handle tensor copy and trivial reduction expression.
Bixia Zheng
llvmlistbot at llvm.org
Tue Feb 15 15:57:24 PST 2022
Author: Bixia Zheng
Date: 2022-02-15T15:57:18-08:00
New Revision: 746c68eafde341555e59bf5d54602b9d406596f9
URL: https://github.com/llvm/llvm-project/commit/746c68eafde341555e59bf5d54602b9d406596f9
DIFF: https://github.com/llvm/llvm-project/commit/746c68eafde341555e59bf5d54602b9d406596f9.diff
LOG: [mlir][sparse][taco] Handle tensor copy and trivial reduction expression.
Handle tensor copy, such as A[i, j] = B[i, j]. Also, handle trivial
reduction expression, such as A[i] = B[i, j].
Add unit tests.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D119867
Added:
Modified:
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 9d32b2c6accb6..69dedf39c68cf 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -1683,10 +1683,16 @@ def _mark_structured_op_root(
to perform a reduction.
expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
"""
+ expr_info = expr_to_info[expr]
+ if isinstance(expr, Access):
+ # Handle simple reduction expression in the format of A[i] = B[i, j].
+ if reduce_index in expr_info.src_indices:
+ expr_info.reduce_indices.add(reduce_index)
+ return
+
assert (isinstance(expr, _BinaryExpr))
a_info = expr_to_info[expr.a]
b_info = expr_to_info[expr.b]
- expr_info = expr_to_info[expr]
if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
expr_info.reduce_indices.add(reduce_index)
@@ -1724,6 +1730,9 @@ def _accumulate_reduce_indices(
| expr_info.reduce_indices)
else:
assert isinstance(expr, Access)
+ # Handle simple reduction expression in the format of A[i] = B[i, j].
+ expr_info.acc_reduce_indices = expr_info.reduce_indices
+
def _gather_structured_op(
@@ -1821,9 +1830,10 @@ def _gather_structured_op_input(
structop_inputs: The resulting list of IndexExpr that provide input to the
current structured op.
"""
- if (expr != root and expr not in structop_inputs) and (
- isinstance(expr, Access) or
- (expr in expr_to_info and expr_to_info[expr].structop_info)):
+ if ((expr != root or isinstance(expr, Access)) and
+ expr not in structop_inputs) and (isinstance(expr, Access) or
+ (expr in expr_to_info and
+ expr_to_info[expr].structop_info)):
structop_inputs.append(expr)
@@ -1843,7 +1853,7 @@ def _emit_structured_op_input(
An OperandDef in the linalg dialect for the input IndexExpr.
"""
op_info = expr_to_info[expr].structop_info
- if op_info:
+ if op_info and not isinstance(expr, Access):
# The input is a temporary tensor produced by another structured op.
indices = op_info.dst_indices
name = op_info.dst_name
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 6b770f7eacc60..8703ef9126c8d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -37,3 +37,42 @@ def test_tensor_true_dense():
passed += (a.shape[0] == 5)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
+
+
+# CHECK-LABEL: test_tensor_copy
+@testing_utils.run_test
+def test_tensor_copy():
+ i, j = mlir_pytaco.get_index_vars(2)
+ I = 2
+ J = 3
+ A = mlir_pytaco.Tensor([I, J])
+ A.insert([0, 1], 5.0)
+ A.insert([1, 2], 6.0)
+ B = mlir_pytaco.Tensor([I, J])
+ B[i, j] = A[i, j]
+ indices, values = B.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0, 1], [1, 2]])
+ passed += np.allclose(values, [5.0, 6.0])
+
+ # CHECK: Number of passed: 2
+ print("Number of passed:", passed)
+
+
+# CHECK-LABEL: test_tensor_trivial_reduction
+@testing_utils.run_test
+def test_tensor_trivial_reduction():
+ i, j = mlir_pytaco.get_index_vars(2)
+ I = 2
+ J = 3
+ A = mlir_pytaco.Tensor([I, J])
+ A.insert([0, 1], 5.0)
+ A.insert([0, 2], 3.0)
+ A.insert([1, 2], 6.0)
+ B = mlir_pytaco.Tensor([I])
+ B[i] = A[i, j]
+ indices, values = B.get_coordinates_and_values()
+ passed = np.array_equal(indices, [[0], [1]])
+ passed += np.allclose(values, [8.0, 6.0])
+
+ # CHECK: Number of passed: 2
+ print("Number of passed:", passed)
More information about the Mlir-commits
mailing list