[Mlir-commits] [mlir] 5b87e05 - [mlir][sparse][taco] Split the evaluate method into compile and compute.
Bixia Zheng
llvmlistbot at llvm.org
Mon Mar 7 16:58:50 PST 2022
Author: Bixia Zheng
Date: 2022-03-07T16:58:41-08:00
New Revision: 5b87e0521d648e697c684238f7236fc4c7a04ed8
URL: https://github.com/llvm/llvm-project/commit/5b87e0521d648e697c684238f7236fc4c7a04ed8
DIFF: https://github.com/llvm/llvm-project/commit/5b87e0521d648e697c684238f7236fc4c7a04ed8.diff
LOG: [mlir][sparse][taco] Split the evaluate method into compile and compute.
This aligns better with the PyTACO API.
Modify an existing unit test to test the new routines.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D121083
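As a rough usage sketch (not part of this commit), the new flow looks like
this, with A, B, i, j set up as in test_tensor_copy in the modified unit test
below:

  B[i, j] = A[i, j]  # Record the assignment; any cached engine is cleared.
  B.compile()        # Lower the assignment to MLIR and build an execution engine.
  B.compute()        # Run the engine and store the result into B.
  # evaluate() still works and is now simply compile() followed by compute():
  # B.evaluate()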
Added:
Modified:
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index da521bd33504e..2c452f1725258 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -30,6 +30,7 @@
import threading
# Import MLIR related modules.
+from mlir import execution_engine
from mlir import ir
from mlir import runtime
from mlir.dialects import arith
@@ -644,6 +645,7 @@ def __init__(self,
    dtype = dtype or DType(Type.FLOAT32)
    self._name = name or self._get_unique_name()
    self._assignment = None
+    self._engine = None
    self._sparse_value_location = _SparseValueInfo._UNPACKED
    self._dense_storage = None
    self._dtype = dtype
@@ -978,17 +980,72 @@ def __setitem__(self, key, value) -> None:
f"len({indices}) != {self.order}.")
self._assignment = _Assignment(indices, value)
+ self._engine = None
- def evaluate(self) -> None:
- """Evaluates the assignment to the tensor."""
- result = self._assignment.expression.evaluate(self,
- self._assignment.indices)
- self._assignment = None
+ def compile(self, force_recompile: bool = False) -> None:
+ """Compiles the tensor assignment to an execution engine.
+
+ Calling compile the second time does not do anything unless
+ force_recompile is True.
+
+ Args:
+ force_recompile: A boolean value to enable recompilation, such as for the
+ purpose of timing.
+
+ Raises:
+ ValueError: If the assignment is not proper or not supported.
+ """
+ if self._assignment is None or (self._engine is not None and
+ not force_recompile):
+ return
+
+ self._engine = self._assignment.expression.compile(self,
+ self._assignment.indices)
+
+ def compute(self) -> None:
+ """Executes the engine for the tensor assignment.
+
+ Raises:
+ ValueError: If the assignment hasn't been compiled yet.
+ """
+ if self._assignment is None:
+ return
+
+ if self._engine is None:
+ raise ValueError("Need to invoke compile() before invoking compute().")
+
+ input_accesses = self._assignment.expression.get_input_accesses()
+ # Gather the pointers for the input buffers.
+ input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
+ if self.is_dense():
+ # The pointer to receive dense output is the first argument to the
+ # execution engine.
+ arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
+ else:
+ # The pointer to receive the sparse tensor output is the last argument
+ # to the execution engine and is a pointer to pointer of char.
+ arg_pointers = input_pointers + [
+ ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
+ ]
+
+ # Invoke the execution engine to run the module.
+ self._engine.invoke(_ENTRY_NAME, *arg_pointers)
+
+ # Retrieve the result.
if self.is_dense():
+ result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
assert isinstance(result, np.ndarray)
self._dense_storage = result
else:
- self._set_packed_sparse_tensor(result)
+ self._set_packed_sparse_tensor(arg_pointers[-1][0])
+
+ self._assignment = None
+ self._engine = None
+
+ def evaluate(self) -> None:
+ """Evaluates the tensor assignment."""
+ self.compile()
+ self.compute()
def _sync_value(self) -> None:
"""Updates the tensor value by evaluating the pending assignment."""
@@ -1444,29 +1501,31 @@ def linalg_funcop(*args):
      linalg_funcop.func_op.attributes[
          "llvm.emit_c_interface"] = ir.UnitAttr.get()
-  def evaluate(
+  def get_input_accesses(self) -> List["Access"]:
+    """Computes the list of input accesses for the expression."""
+    input_accesses = []
+    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+    return input_accesses
+
+  def compile(
      self,
      dst: Tensor,
      dst_indices: Tuple[IndexVar, ...],
-  ) -> Union[np.ndarray, ctypes.c_void_p]:
-    """Evaluates tensor assignment dst[dst_indices] = expression.
+  ) -> execution_engine.ExecutionEngine:
+    """Compiles the tensor assignment dst[dst_indices] = expression.
    Args:
      dst: The destination tensor.
      dst_indices: The tuple of IndexVar used to access the destination tensor.
    Returns:
-      The result of the dense tensor represented in numpy ndarray or the pointer
-      to the MLIR sparse tensor.
+      The execution engine for the tensor assignment.
    Raises:
      ValueError: If the expression is not proper or not supported.
    """
    expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
-
-    # Compute a list of input accesses.
-    input_accesses = []
-    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+    input_accesses = self.get_input_accesses()
    # Build and compile the module to produce the execution engine.
    with ir.Context(), ir.Location.unknown():
@@ -1475,29 +1534,7 @@ def evaluate(
                            input_accesses)
      engine = utils.compile_and_build_engine(module)
-    # Gather the pointers for the input buffers.
-    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
-    if dst.is_dense():
-      # The pointer to receive dense output is the first argument to the
-      # execution engine.
-      arg_pointers = [dst.dense_dst_ctype_pointer()] + input_pointers
-    else:
-      # The pointer to receive sparse output is the last argument to the
-      # execution engine. The pointer to receive a sparse tensor output is a
-      # pointer to pointer of char.
-      arg_pointers = input_pointers + [
-          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
-      ]
-
-    # Invoke the execution engine to run the module and return the result.
-    engine.invoke(_ENTRY_NAME, *arg_pointers)
-
-    if dst.is_dense():
-      return runtime.ranked_memref_to_numpy(arg_pointers[0][0])
-
-    # Return the sparse tensor pointer.
-    return arg_pointers[-1][0]
-
+    return engine
@dataclasses.dataclass(frozen=True)
class Access(IndexExpr):
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
index 6b539ad06d1d4..1390f24d1027a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
@@ -279,13 +279,23 @@ def test_tensor_copy():
  A.insert([1, 2], 6.0)
  B = mlir_pytaco.Tensor([I, J])
  B[i, j] = A[i, j]
+  passed = (B._assignment is not None)
+  passed += (B._engine is None)
+  try:
+    B.compute()
+  except ValueError as e:
+    passed += (str(e).startswith("Need to invoke compile"))
+  B.compile()
+  passed += (B._engine is not None)
+  B.compute()
+  passed += (B._assignment is None)
+  passed += (B._engine is None)
  indices, values = B.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 1], [1, 2]])
+  passed += np.array_equal(indices, [[0, 1], [1, 2]])
  passed += np.allclose(values, [5.0, 6.0])
  # No temporary tensor is used.
  passed += (B._stats.get_total() == 0)
-
-  # CHECK: Number of passed: 3
+  # CHECK: Number of passed: 9
  print("Number of passed:", passed)