[Mlir-commits] [mlir] 61a3dd7 - [mlir][taco] Use sparse_tensor.out to write sparse tensors to files.
Bixia Zheng
llvmlistbot at llvm.org
Tue Feb 8 08:47:10 PST 2022
Author: Bixia Zheng
Date: 2022-02-08T08:47:05-08:00
New Revision: 61a3dd70ff8dc4470d4a8d766ad09a2707bfb552
URL: https://github.com/llvm/llvm-project/commit/61a3dd70ff8dc4470d4a8d766ad09a2707bfb552
DIFF: https://github.com/llvm/llvm-project/commit/61a3dd70ff8dc4470d4a8d766ad09a2707bfb552.diff
LOG: [mlir][taco] Use sparse_tensor.out to write sparse tensors to files.
Add a Python method, output_sparse_tensor, to use sparse_tensor.out to write
a sparse tensor value to a file.
Modify the method that evaluates a tensor expression to return a pointer to the
MLIR sparse tensor holding the result, which delays the extraction of the
coordinates and non-zero values until they are needed.
Implement the Tensor to_file method to evaluate the tensor assignment and write
the result to a file.
Add unit tests. Modify test golden files to reflect the change that TNS outputs
now have a comment line and two meta data lines.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D118956
Added:
mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
Modified:
mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_A.tns
mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns
mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_y.tns
mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_A.tns b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_A.tns
index b66caa12106a9..f06646b51feca 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_A.tns
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_A.tns
@@ -1,3 +1,6 @@
+# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
+2 50
+2 25
1 1 12
1 2 12
1 3 12
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns
index e5c1ec14c4030..9f5aec56d500a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns
@@ -1,3 +1,6 @@
+# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
+2 9
+3 3
1.0 1.0 100.0
1.0 2.0 107.0
1.0 3.0 114.0
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_y.tns b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_y.tns
index a9eab90a0627a..832cb1795aaaa 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_y.tns
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_y.tns
@@ -1,4 +1,6 @@
# See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
+1 3
+3
1 37102
2 -20.4138
3 804927
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
index 1fda4f4406393..1e35a85755382 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
@@ -7,7 +7,9 @@
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
+
from tools import mlir_pytaco_api as pt
+from tools import testing_utils as utils
###### This PyTACO part is taken from the TACO open-source project. ######
# See http://tensor-compiler.org/docs/data_analytics/index.html.
@@ -42,12 +44,12 @@
##########################################################################
-# CHECK: Compare result True
# Perform the MTTKRP computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
- actual_file = os.path.join(test_dir, "A.tns")
- pt.write(actual_file, A)
- actual = np.loadtxt(actual_file, np.float64)
- expected = np.loadtxt(
- os.path.join(_SCRIPT_PATH, "data/gold_A.tns"), np.float64)
- print(f"Compare result {np.allclose(actual, expected, rtol=0.01)}")
+ golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
+ out_file = os.path.join(test_dir, "A.tns")
+ pt.write(out_file, A)
+ #
+ # CHECK: Compare result True
+ #
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
index 872c73681c7d0..6092301feaa7a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
@@ -10,6 +10,7 @@
sys.path.append(_SCRIPT_PATH)
from tools import mlir_pytaco_api as pt
+from tools import testing_utils as utils
# Define the CSR format.
csr = pt.format([pt.dense, pt.compressed], [0, 1])
@@ -33,6 +34,6 @@
out_file = os.path.join(test_dir, "C.tns")
pt.write(out_file, C)
#
- # CHECK: Compare files True
+ # CHECK: Compare result True
#
- print(f"Compare files {filecmp.cmp(golden_file, out_file)}")
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
index 80bb023360ff8..41ee71fab4310 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
@@ -7,7 +7,9 @@
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
+
from tools import mlir_pytaco_api as pt
+from tools import testing_utils as utils
###### This PyTACO part is taken from the TACO open-source project. ######
# See http://tensor-compiler.org/docs/scientific_computing/index.html.
@@ -43,12 +45,12 @@
##########################################################################
-# CHECK: Compare result True
# Perform the SpMV computation and write the result to file
with tempfile.TemporaryDirectory() as test_dir:
- actual_file = os.path.join(test_dir, "y.tns")
- pt.write(actual_file, y)
- actual = np.loadtxt(actual_file, np.float64)
- expected = np.loadtxt(
- os.path.join(_SCRIPT_PATH, "data/gold_y.tns"), np.float64)
- print(f"Compare result {np.allclose(actual, expected, rtol=0.01)}")
+ golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
+ out_file = os.path.join(test_dir, "y.tns")
+ pt.write(out_file, y)
+ #
+ # CHECK: Compare result True
+ #
+ print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 24f114dba64a9..2d3b23e5ed864 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -667,6 +667,11 @@ def __init__(self,
"Must be a tuple or list for a shape or a single value"
f"if initializing a scalar tensor: {value_or_shape}.")
+ def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
+   """Stores `pointer` as the packed value and marks the tensor as packed.
+
+   Args:
+     pointer: The MLIR sparse tensor pointer to record.
+   """
+   self._packed_sparse_value = pointer
+   self._sparse_value_location = _SparseValueInfo._PACKED
+
def is_unpacked(self) -> bool:
"""Returns true if the tensor value is not packed as MLIR sparse tensor."""
return (self._sparse_value_location == _SparseValueInfo._UNPACKED)
@@ -826,11 +831,39 @@ def from_file(
sparse_tensor, shape = utils.create_sparse_tensor(filename,
fmt.format_pack.formats)
tensor = Tensor(shape.tolist(), fmt)
- tensor._sparse_value_location = _SparseValueInfo._PACKED
- tensor._packed_sparse_value = sparse_tensor
+ tensor._set_packed_sparse_tensor(sparse_tensor)
return tensor
+ def to_file(self, filename: str) -> None:
+   """Output the tensor value to a file.
+
+   This method evaluates any pending assignment to the tensor and outputs the
+   tensor value. A value that is not unpacked is written through
+   utils.output_sparse_tensor. Otherwise the coordinates and values are
+   written from Python as a comment line, a "rank nnz" line, a line with the
+   dimension sizes, and one "coordinates value" line per stored element.
+
+   Args:
+     filename: A string file name.
+   """
+   self._sync_value()
+   if not self.is_unpacked():
+     utils.output_sparse_tensor(self._packed_sparse_value, filename,
+                                self._format.format_pack.formats)
+     return
+
+   # TODO: Use MLIR code to output the value.
+   coords, values = self.get_coordinates_and_values()
+   assert len(coords) == len(values)
+   with open(filename, "w") as file:
+     # Output a comment line and the meta data. The comment line starts with
+     # "#", the same marker the golden .tns files use.
+     file.write("# extended FROSTT format\n")
+     file.write(f"{self.order} {len(coords)}\n")
+     file.write(f"{' '.join(map(str, self.shape))}\n")
+     # Output each (coordinate value) pair in a line.
+     for c, v in zip(coords, values):
+       # The coordinates are 1-based in the text file and 0-based in memory.
+       file.write(f"{' '.join(str(x + 1) for x in c)} {v}\n")
+
@property
def dtype(self) -> DType:
"""Returns the data type for the Tensor."""
@@ -914,9 +947,7 @@ def evaluate(self) -> None:
assert isinstance(result, np.ndarray)
self._dense_storage = result
else:
- assert _all_instance_of(result, np.ndarray) and len(result) == 2
- assert (result[0].ndim, result[1].ndim) == (1, 2)
- (self._values, self._coords) = result
+ self._set_packed_sparse_tensor(result)
def _sync_value(self) -> None:
"""Updates the tensor value by evaluating the pending assignment."""
@@ -1349,7 +1380,7 @@ def evaluate(
self,
dst: Tensor,
dst_indices: Tuple[IndexVar, ...],
- ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+ ) -> Union[np.ndarray, ctypes.c_void_p]:
"""Evaluates tensor assignment dst[dst_indices] = expression.
Args:
@@ -1357,9 +1388,8 @@ def evaluate(
dst_indices: The tuple of IndexVar used to access the destination tensor.
Returns:
- The result of the dense tensor represented in numpy ndarray or the sparse
- tensor represented by two numpy ndarray for its non-zero values and
- indices.
+ The result of the dense tensor represented in numpy ndarray or the pointer
+ to the MLIR sparse tensor.
Raises:
ValueError: If the expression is not proper or not supported.
@@ -1397,17 +1427,8 @@ def evaluate(
if dst.is_dense():
return runtime.ranked_memref_to_numpy(arg_pointers[0][0])
- # Check and return the sparse tensor output.
- rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
- ctypes.cast(arg_pointers[-1][0], ctypes.c_void_p),
- np.float64,
- )
- assert (np.equal(rank, dst.order)
- and np.array_equal(shape, np.array(dst.shape)) and
- np.equal(values.ndim, 1) and np.equal(values.shape[0], nse) and
- np.equal(indices.ndim, 2) and np.equal(indices.shape[0], nse) and
- np.equal(indices.shape[1], rank))
- return (values, indices)
+ # Return the sparse tensor pointer.
+ return arg_pointers[-1][0]
@dataclasses.dataclass(frozen=True)
@@ -1438,6 +1459,13 @@ def __post_init__(self) -> None:
raise ValueError("Invalid indices for rank: "
f"str{self.tensor.order} != len({str(self.indices)}).")
+ def __repr__(self) -> str:
+   """Returns the tensor name and index names without evaluating the tensor."""
+   # Tensor.__repr__ evaluates the pending assignment to the tensor. Only the
+   # tensor name and the index names are used here so that printing an access
+   # does not trigger that evaluation.
+   names = [index.name for index in self.indices]
+   return f"Tensor({self.tensor.name}) Indices({', '.join(names)})"
+
def _emit_expression(
self,
expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
index 5d446d6af1636..f66eb9b6fdd0e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
@@ -31,17 +31,6 @@
_TNS_FILENAME_SUFFIX = ".tns"
-def _write_tns(file: TextIO, tensor: Tensor) -> None:
- """Outputs a tensor to a file using .tns format."""
- coords, non_zeros = tensor.get_coordinates_and_values()
- assert len(coords) == len(non_zeros)
- # Output a coordinate and the corresponding value in a line.
- for c, v in zip(coords, non_zeros):
- # The coordinates are 1-based in the text file and 0-based in memory.
- plus_one_to_str = lambda x: str(x + 1)
- file.write(f"{' '.join(map(plus_one_to_str,c))} {v}\n")
-
-
def read(filename: str, fmt: Format) -> Tensor:
"""Inputs a tensor from a given file.
@@ -88,7 +77,4 @@ def write(filename: str, tensor: Tensor) -> None:
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")
- # TODO: combine the evaluation and the outputing into one step.
- tensor._sync_value()
- with open(filename, "w") as file:
- return _write_tns(file, tensor)
+ tensor.to_file(filename)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index a719cd6d9f63f..c815220815126 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -270,3 +270,67 @@ def create_sparse_tensor(
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape
+
+
+# TODO: With better support from MLIR, we may improve the current implementation
+# by using Python code to generate the kernel instead of doing MLIR text code
+# stitching.
+def _get_output_sparse_tensor_kernel(
+    sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> str:
+  """Builds the MLIR text of a kernel that writes a sparse tensor to a file.
+
+  The kernel takes the tensor and a file name pointer, and returns void.
+
+  Args:
+    sparsity_codes: A sequence of DimLevelType values, one for each dimension
+      of the tensor.
+
+  Returns:
+    The MLIR text for the kernel.
+  """
+  # Every dimension is dynamic, written as "?" in the shape string.
+  shape = "x".join("?" for _ in sparsity_codes)
+
+  # A truthy encoded value means "compressed"; anything else means "dense".
+  sparsity = ", ".join(
+      '"compressed"' if code.value else '"dense"' for code in sparsity_codes)
+
+  return f"""
+!Ptr = type !llvm.ptr<i8>
+#enc = #sparse_tensor.encoding<{{
+ dimLevelType = [ {sparsity} ]
+}}>
+func @{_ENTRY_NAME}(%t: tensor<{shape}xf64, #enc>, %filename: !Ptr)
+attributes {{ llvm.emit_c_interface }} {{
+ sparse_tensor.out %t, %filename : tensor<{shape}xf64, #enc>, !Ptr
+ std.return
+}}"""
+
+
+def output_sparse_tensor(
+ tensor: ctypes.c_void_p, filename: str,
+ sparsity: Sequence[sparse_tensor.DimLevelType]) -> None:
+ """Outputs an MLIR sparse tensor to the given file.
+
+ Generates the MLIR text kernel for the given sparsity, parses it into a
+ module, builds an execution engine for it, and invokes the kernel with the
+ tensor pointer and the file name.
+
+ Args:
+ tensor: A C pointer to the MLIR sparse tensor.
+ filename: A string for the name of the file to which the tensor data is
+ written in a COO-flavored format.
+ sparsity: A sequence of DimLevelType values, one for each dimension of the
+ tensor.
+
+ Raises:
+ OSError: If there is any problem in loading the supporting C shared library.
+ ValueError: If the shared library doesn't contain the needed routine.
+ """
+ with ir.Context() as ctx, ir.Location.unknown():
+ module = _get_output_sparse_tensor_kernel(sparsity)
+ module = ir.Module.parse(module)
+ engine = compile_and_build_engine(module)
+
+ # Convert the filename to a byte stream. The c_char_p object is kept in a
+ # local so that it can be passed by reference below.
+ c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+
+ # Each kernel argument is passed as a reference to its C value.
+ arg_pointers = [
+ ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
+ ctypes.byref(c_filename)
+ ]
+
+ # Invoke the execution engine to run the kernel. The kernel returns void, so
+ # there is no result to collect.
+ engine.invoke(_ENTRY_NAME, *arg_pointers)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
new file mode 100644
index 0000000000000..4e277d60f3eb4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
@@ -0,0 +1,32 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains the utilities to support testing.
+
+import numpy as np
+
+
+def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
+  """Compares sparse tensor actual output file with expected output file.
+
+  This routine assumes the input files are in FROSTT format. See
+  http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
+
+  It also assumes the first line in each file is a comment line, followed by
+  two lines of meta data and then the data lines.
+
+  Args:
+    expected: The name of the file with the expected output.
+    actual: The name of the file with the actual output.
+    rtol: The relative tolerance used to compare the data lines.
+
+  Returns:
+    True if the two meta data lines are textually equal and the data lines are
+    numerically close, False otherwise.
+  """
+  with open(actual, "r") as actual_f, open(expected, "r") as expected_f:
+    # Skip the first comment line; its text is allowed to differ.
+    actual_f.readline()
+    expected_f.readline()
+
+    # Compare the two lines of meta data.
+    for _ in range(2):
+      if actual_f.readline() != expected_f.readline():
+        return False
+
+  # Skip the comment line and the two meta data lines when loading the data.
+  actual_data = np.loadtxt(actual, np.float64, skiprows=3)
+  expected_data = np.loadtxt(expected, np.float64, skiprows=3)
+  return bool(np.allclose(actual_data, expected_data, rtol=rtol))
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
index 1466dc841dcd6..87246fd65f289 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
@@ -108,3 +108,46 @@ def test_read_tns():
passed += np.allclose(values, [2.0, 3.0, 4.0])
# CHECK: 4
print(passed)
+
+
+# CHECK-LABEL: test_write_unpacked_tns
+@_run
+def test_write_unpacked_tns():
+  # Writes a tensor whose value was inserted element by element and checks
+  # the meta data and data lines of the resulting file.
+  a = mlir_pytaco.Tensor([2, 3])
+  a.insert([0, 1], 10)
+  a.insert([1, 2], 40)
+  a.insert([0, 0], 20)
+  with tempfile.TemporaryDirectory() as test_dir:
+    file_name = os.path.join(test_dir, "data.tns")
+    mlir_pytaco_io.write(file_name, a)
+    with open(file_name, "r") as file:
+      lines = file.readlines()
+  passed = 0
+  # Skip the comment line in the output.
+  if lines[1:] == ["2 3\n", "2 3\n", "1 2 10.0\n", "2 3 40.0\n", "1 1 20.0\n"]:
+    passed = 1
+  # CHECK: 1
+  print(passed)
+
+
+# CHECK-LABEL: test_write_packed_tns
+@_run
+def test_write_packed_tns():
+  # Writes a tensor that is the result of a tensor assignment and checks the
+  # meta data and data lines of the resulting file.
+  a = mlir_pytaco.Tensor([2, 3])
+  a.insert([0, 1], 10)
+  a.insert([1, 2], 40)
+  a.insert([0, 0], 20)
+  b = mlir_pytaco.Tensor([2, 3])
+  i, j = mlir_pytaco.get_index_vars(2)
+  b[i, j] = a[i, j] + a[i, j]
+  with tempfile.TemporaryDirectory() as test_dir:
+    file_name = os.path.join(test_dir, "data.tns")
+    mlir_pytaco_io.write(file_name, b)
+    with open(file_name, "r") as file:
+      lines = file.readlines()
+  passed = 0
+  # Skip the comment line in the output.
+  if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
+    passed = 1
+  # CHECK: 1
+  print(passed)
More information about the Mlir-commits
mailing list