[Mlir-commits] [mlir] 1cd13e6 - [mlir][sparse][taco] Support more data types.

Bixia Zheng llvmlistbot at llvm.org
Wed May 4 10:05:30 PDT 2022


Author: Bixia Zheng
Date: 2022-05-04T10:05:20-07:00
New Revision: 1cd13e6e9851b8c933603f6bd0236690b4f0df90

URL: https://github.com/llvm/llvm-project/commit/1cd13e6e9851b8c933603f6bd0236690b4f0df90
DIFF: https://github.com/llvm/llvm-project/commit/1cd13e6e9851b8c933603f6bd0236690b4f0df90.diff

LOG: [mlir][sparse][taco] Support more data types.

Support int8, int16, int32 and int64. Also fix source code format in mlir_pytaco_utils.py.

Add tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D124925

Added: 
    mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py

Modified: 
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
new file mode 100644
index 0000000000000..b4fee5269c427
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
@@ -0,0 +1,33 @@
+# RUN: SUPPORTLIB=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext %PYTHON %s | FileCheck %s
+
+import numpy as np
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import mlir_pytaco_api as pt
+
+compressed = pt.compressed
+dense = pt.dense
+
+passed = 0
+all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64]
+for t in all_types:
+  i, j = pt.get_index_vars(2)
+  A = pt.tensor([2, 3], dtype=t)
+  B = pt.tensor([2, 3], dtype=t)
+  C = pt.tensor([2, 3], compressed, dtype=t)
+  A.insert([0, 1], 10)
+  A.insert([1, 2], 40)
+  B.insert([0, 0], 20)
+  B.insert([1, 2], 30)
+  C[i, j] = A[i, j] + B[i, j]
+
+  indices, values = C.get_coordinates_and_values()
+  passed += isinstance(values[0], t.value)
+  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+  passed += np.allclose(values, [20.0, 10.0, 70.0])
+
+# CHECK: Number of passed: 18
+print("Number of passed:", passed)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 8c77c8d55fc71..220a9ce5f9ec4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -67,6 +67,7 @@ class Type(enum.Enum):
 
   We use numpy data types to implement the enum data types.
   """
+  INT8 = np.int8
   INT16 = np.int16
   INT32 = np.int32
   INT64 = np.int64
@@ -78,10 +79,11 @@ class Type(enum.Enum):
 # All floating point type enums.
 _FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
 # All integral type enums.
-_INT_TYPES = (Type.INT16, Type.INT32, Type.INT64)
+_INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
-_AnyRuntimeType = Union[np.int16, np.int32, np.int64, np.float32, np.float64]
+_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
+                        np.float64]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -117,6 +119,7 @@ def value(self) -> _AnyRuntimeType:
 def _dtype_to_mlir_str(dtype: DType) -> str:
   """Returns the MLIR string for the given dtype."""
   dtype_to_str = {
+      Type.INT8: "i8",
       Type.INT16: "i16",
       Type.INT32: "i32",
       Type.INT64: "i64",
@@ -129,6 +132,7 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
 def _nptype_to_taco_type(ty: np.dtype) -> DType:
   """Returns the TACO type for the given numpy type."""
   nptype_to_dtype = {
+      np.int8: Type.INT8,
       np.int16: Type.INT16,
       np.int32: Type.INT32,
       np.int64: Type.INT64,
@@ -141,6 +145,7 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
 def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
   """Returns the MLIR type corresponding to the given TACO type."""
   dtype_to_irtype = {
+      Type.INT8: ir.IntegerType.get_signless(8),
       Type.INT16: ir.IntegerType.get_signless(16),
       Type.INT32: ir.IntegerType.get_signless(32),
       Type.INT64: ir.IntegerType.get_signless(64),

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
index d6072a407cf52..7573d6655360f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_api.py
@@ -35,6 +35,7 @@
 access = mlir_pytaco.Access
 
 # Data type constants defined by PyTACO API.
+int8 = mlir_pytaco.DType(mlir_pytaco.Type.INT8)
 int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
 int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
 int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
index 50d570f63b5cc..dce0d9e684111 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
@@ -37,29 +37,29 @@
 
 @functools.lru_cache()
 def _get_support_lib_name() -> str:
-    """Gets the string name for the supporting C shared library."""
-    return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
+  """Gets the string name for the supporting C shared library."""
+  return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
 
 
 @functools.lru_cache()
 def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
-    """Gets the MLIR sparse compiler with default setting."""
-    return mlir_sparse_compiler.SparseCompiler(
-        options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
+  """Gets the MLIR sparse compiler with default setting."""
+  return mlir_sparse_compiler.SparseCompiler(
+      options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
 
 
 def _record_support_funcs(
     ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
     ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
-    """Records the two supporting functions for a given data type."""
-    to_func.restype = ctypes.c_void_p
-    from_func.restype = ctypes.c_void_p
-    ty_to_funcs[ty] = (to_func, from_func)
+  """Records the two supporting functions for a given data type."""
+  to_func.restype = ctypes.c_void_p
+  from_func.restype = ctypes.c_void_p
+  ty_to_funcs[ty] = (to_func, from_func)
 
 
 @functools.lru_cache()
 def _get_support_func_locator() -> _SupportFuncLocator:
-    """Constructs a function to locate the supporting functions for a data type.
+  """Constructs a function to locate the supporting functions for a data type.
 
   Loads the supporting C shared library with the needed routines. Constructs a
   dictionary from the supported data types to the routines for the data types,
@@ -75,36 +75,42 @@ def _get_support_func_locator() -> _SupportFuncLocator:
     OSError: If there is any problem in loading the shared library.
     ValueError: If the shared library doesn't contain the needed routines.
   """
-    # This raises OSError exception if there is any problem in loading the shared
-    # library.
-    c_lib = ctypes.CDLL(_get_support_lib_name())
-
-    type_to_funcs = {}
-    try:
-        _record_support_funcs(np.float32, c_lib.convertToMLIRSparseTensorF32,
-                              c_lib.convertFromMLIRSparseTensorF32, type_to_funcs)
-    except Exception as e:
-        raise ValueError(f"Missing supporting function: {e}") from e
-
-    try:
-        _record_support_funcs(np.float64, c_lib.convertToMLIRSparseTensorF64,
-                              c_lib.convertFromMLIRSparseTensorF64, type_to_funcs)
-    except Exception as e:
-        raise ValueError(f"Missing supporting function: {e}") from e
-
-    def get_support_funcs(ty: np.dtype):
-        funcs = type_to_funcs[ty]
-        assert funcs is not None
-        return funcs
-
-    return get_support_funcs
+  # This raises OSError exception if there is any problem in loading the shared
+  # library.
+  c_lib = ctypes.CDLL(_get_support_lib_name())
+
+  type_to_funcs = {}
+  try:
+    support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
+                      c_lib.convertFromMLIRSparseTensorI8),
+                     (np.int16, c_lib.convertToMLIRSparseTensorI16,
+                      c_lib.convertFromMLIRSparseTensorI16),
+                     (np.int32, c_lib.convertToMLIRSparseTensorI32,
+                      c_lib.convertFromMLIRSparseTensorI32),
+                     (np.int64, c_lib.convertToMLIRSparseTensorI64,
+                      c_lib.convertFromMLIRSparseTensorI64),
+                     (np.float32, c_lib.convertToMLIRSparseTensorF32,
+                      c_lib.convertFromMLIRSparseTensorF32),
+                     (np.float64, c_lib.convertToMLIRSparseTensorF64,
+                      c_lib.convertFromMLIRSparseTensorF64)]
+  except Exception as e:
+    raise ValueError(f"Missing supporting function: {e}") from e
+  for i, info in enumerate(support_types):
+    _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
+
+  def get_support_funcs(ty: np.dtype):
+    funcs = type_to_funcs[ty]
+    assert funcs is not None
+    return funcs
+
+  return get_support_funcs
 
 
 def sparse_tensor_to_coo_tensor(
     sparse_tensor: ctypes.c_void_p,
     dtype: np.dtype,
 ) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
-    """Converts an MLIR sparse tensor to a COO-flavored format tensor.
+  """Converts an MLIR sparse tensor to a COO-flavored format tensor.
 
   Args:
      sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
@@ -124,26 +130,26 @@ def sparse_tensor_to_coo_tensor(
     OSError: If there is any problem in loading the shared library.
     ValueError: If the shared library doesn't contain the needed routines.
   """
-    convert_from = _get_support_func_locator()(dtype)[1]
-    rank = ctypes.c_ulonglong(0)
-    nse = ctypes.c_ulonglong(0)
-    shape = ctypes.POINTER(ctypes.c_ulonglong)()
-    values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
-    indices = ctypes.POINTER(ctypes.c_ulonglong)()
-    convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
-                 ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
-
-    # Convert the returned values to the corresponding numpy types.
-    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-    values = np.ctypeslib.as_array(values, shape=[nse.value])
-    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
-    return rank.value, nse.value, shape, values, indices
+  convert_from = _get_support_func_locator()(dtype)[1]
+  rank = ctypes.c_ulonglong(0)
+  nse = ctypes.c_ulonglong(0)
+  shape = ctypes.POINTER(ctypes.c_ulonglong)()
+  values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
+  indices = ctypes.POINTER(ctypes.c_ulonglong)()
+  convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
+               ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
+
+  # Convert the returned values to the corresponding numpy types.
+  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+  values = np.ctypeslib.as_array(values, shape=[nse.value])
+  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+  return rank.value, nse.value, shape, values, indices
 
 
 def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
                                 np_indices: np.ndarray, np_perm: np.ndarray,
                                 np_sparse: np.ndarray) -> int:
-    """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
+  """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
 
   Args:
      np_shape: A 1D numpy array of integers, for the shape of the tensor.
@@ -164,26 +170,26 @@ def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
     ValueError: If the shared library doesn't contain the needed routines.
   """
 
-    r = len(np_shape)
-    rank = ctypes.c_ulonglong(r)
-    nse = ctypes.c_ulonglong(len(np_values))
-    shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-    values = np_values.ctypes.data_as(
-        ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
-    indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+  r = len(np_shape)
+  rank = ctypes.c_ulonglong(r)
+  nse = ctypes.c_ulonglong(len(np_values))
+  shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+  values = np_values.ctypes.data_as(
+      ctypes.POINTER(np.ctypeslib.as_ctypes_type(np_values.dtype)))
+  indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
 
-    perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-    sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
+  perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+  sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
 
-    convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
-    ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
-    assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
-    return ptr
+  convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
+  ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
+  assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
+  return ptr
 
 
 def compile_and_build_engine(
     module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Compiles an MLIR module and builds a JIT execution engine.
+  """Compiles an MLIR module and builds a JIT execution engine.
 
   Args:
     module: The MLIR module.
@@ -192,22 +198,22 @@ def compile_and_build_engine(
     A JIT execution engine for the MLIR module.
 
   """
-    return _get_sparse_compiler().compile_and_jit(module)
+  return _get_sparse_compiler().compile_and_jit(module)
 
 
 class _SparseTensorDescriptor(ctypes.Structure):
-    """A C structure for an MLIR sparse tensor."""
-    _fields_ = [
-        # A pointer for the MLIR sparse tensor storage.
-        ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
-        # An MLIR MemRef descriptor for the shape of the sparse tensor.
-        ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
-    ]
+  """A C structure for an MLIR sparse tensor."""
+  _fields_ = [
+      # A pointer for the MLIR sparse tensor storage.
+      ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
+      # An MLIR MemRef descriptor for the shape of the sparse tensor.
+      ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
+  ]
 
 
 def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
-    """Produces the MLIR text code to output the size for the given dimension."""
-    return f"""
+  """Produces the MLIR text code to output the size for the given dimension."""
+  return f"""
   %c{dim} = arith.constant {dim} : index
   %d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
   memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
@@ -222,25 +228,25 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
 #     when tensor.dim supports non-constant dimension value.
 def _get_create_sparse_tensor_kernel(
     sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
-    """Creates an MLIR text kernel to contruct a sparse tensor from a file.
+  """Creates an MLIR text kernel to construct a sparse tensor from a file.
 
   The kernel returns a _SparseTensorDescriptor structure.
   """
-    rank = len(sparsity_codes)
+  rank = len(sparsity_codes)
 
-    # Use ? to represent a dimension in the dynamic shape string representation.
-    shape = "x".join(map(lambda d: "?", range(rank)))
+  # Use ? to represent a dimension in the dynamic shape string representation.
+  shape = "x".join(map(lambda d: "?", range(rank)))
 
-    # Convert the encoded sparsity values to a string representation.
-    sparsity = ", ".join(
-        map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
+  # Convert the encoded sparsity values to a string representation.
+  sparsity = ", ".join(
+      map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
 
-    # Get the MLIR text code to write the dimension sizes to the output buffer.
-    output_dims = "\n".join(
-        map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
+  # Get the MLIR text code to write the dimension sizes to the output buffer.
+  output_dims = "\n".join(
+      map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
 
-    # Return the MLIR text kernel.
-    return f"""
+  # Return the MLIR text kernel.
+  return f"""
 !Ptr = type !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   dimLevelType = [ {sparsity} ]
@@ -257,7 +263,7 @@ def _get_create_sparse_tensor_kernel(
 def create_sparse_tensor(filename: str,
                          sparsity: Sequence[sparse_tensor.DimLevelType],
                          type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
-    """Creates an MLIR sparse tensor from the input file.
+  """Creates an MLIR sparse tensor from the input file.
 
   Args:
     filename: A string for the name of the file that contains the tensor data in
@@ -274,25 +280,25 @@ def create_sparse_tensor(filename: str,
     OSError: If there is any problem in loading the supporting C shared library.
     ValueError:  If the shared library doesn't contain the needed routine.
   """
-    with ir.Context() as ctx, ir.Location.unknown():
-        module = _get_create_sparse_tensor_kernel(sparsity, type)
-        module = ir.Module.parse(module)
-        engine = compile_and_build_engine(module)
+  with ir.Context() as ctx, ir.Location.unknown():
+    module = _get_create_sparse_tensor_kernel(sparsity, type)
+    module = ir.Module.parse(module)
+    engine = compile_and_build_engine(module)
 
-    # A sparse tensor descriptor to receive the kernel result.
-    c_tensor_desc = _SparseTensorDescriptor()
-    # Convert the filename to a byte stream.
-    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+  # A sparse tensor descriptor to receive the kernel result.
+  c_tensor_desc = _SparseTensorDescriptor()
+  # Convert the filename to a byte stream.
+  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
 
-    arg_pointers = [
-        ctypes.byref(ctypes.pointer(c_tensor_desc)),
-        ctypes.byref(c_filename)
-    ]
+  arg_pointers = [
+      ctypes.byref(ctypes.pointer(c_tensor_desc)),
+      ctypes.byref(c_filename)
+  ]
 
-    # Invoke the execution engine to run the module and return the result.
-    engine.invoke(_ENTRY_NAME, *arg_pointers)
-    shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
-    return c_tensor_desc.storage, shape
+  # Invoke the execution engine to run the module and return the result.
+  engine.invoke(_ENTRY_NAME, *arg_pointers)
+  shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
+  return c_tensor_desc.storage, shape
 
 
 # TODO: With better support from MLIR, we may improve the current implementation
@@ -301,22 +307,22 @@ def create_sparse_tensor(filename: str,
 def _get_output_sparse_tensor_kernel(
         sparsity_codes: Sequence[sparse_tensor.DimLevelType],
         type: str) -> str:
-    """Creates an MLIR text kernel to output a sparse tensor to a file.
+  """Creates an MLIR text kernel to output a sparse tensor to a file.
 
   The kernel returns void.
   """
-    rank = len(sparsity_codes)
+  rank = len(sparsity_codes)
 
-    # Use ? to represent a dimension in the dynamic shape string representation.
-    shape = "x".join(map(lambda d: "?", range(rank)))
+  # Use ? to represent a dimension in the dynamic shape string representation.
+  shape = "x".join(map(lambda d: "?", range(rank)))
 
-    # Convert the encoded sparsity values to a string representation.
-    sparsity = ", ".join(
-        map(lambda s: '"compressed"'
-            if s.value else '"dense"', sparsity_codes))
+  # Convert the encoded sparsity values to a string representation.
+  sparsity = ", ".join(
+      map(lambda s: '"compressed"'
+          if s.value else '"dense"', sparsity_codes))
 
-    # Return the MLIR text kernel.
-    return f"""
+  # Return the MLIR text kernel.
+  return f"""
 !Ptr = type !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   dimLevelType = [ {sparsity} ]
@@ -331,7 +337,7 @@ def _get_output_sparse_tensor_kernel(
 def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
                          sparsity: Sequence[sparse_tensor.DimLevelType],
                          type: str) -> None:
-    """Outputs an MLIR sparse tensor to the given file.
+  """Outputs an MLIR sparse tensor to the given file.
 
   Args:
     tensor: A C pointer to the MLIR sparse tensor.
@@ -345,18 +351,18 @@ def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
     OSError: If there is any problem in loading the supporting C shared library.
     ValueError:  If the shared library doesn't contain the needed routine.
   """
-    with ir.Context() as ctx, ir.Location.unknown():
-        module = _get_output_sparse_tensor_kernel(sparsity, type)
-        module = ir.Module.parse(module)
-        engine = compile_and_build_engine(module)
+  with ir.Context() as ctx, ir.Location.unknown():
+    module = _get_output_sparse_tensor_kernel(sparsity, type)
+    module = ir.Module.parse(module)
+    engine = compile_and_build_engine(module)
 
-    # Convert the filename to a byte stream.
-    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+  # Convert the filename to a byte stream.
+  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
 
-    arg_pointers = [
-        ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
-        ctypes.byref(c_filename)
-    ]
+  arg_pointers = [
+      ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
+      ctypes.byref(c_filename)
+  ]
 
-    # Invoke the execution engine to run the module and return the result.
-    engine.invoke(_ENTRY_NAME, *arg_pointers)
+  # Invoke the execution engine to run the module and return the result.
+  engine.invoke(_ENTRY_NAME, *arg_pointers)


        


More information about the Mlir-commits mailing list