[Mlir-commits] [mlir] [mlir][sparse] enable rt path for transpose COO (PR #76747)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 13 10:53:39 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {darker}-->


:warning: The Python code formatter darker found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
darker --check --diff -r 0351dc522a25df0473a63b414a5bfde5814d3dc3...e7d51a434d6e88ef0918a014e4245fb119665ea0 mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py mlir/test/Integration/Dialect/SparseTensor/python/test_output.py mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py
``````````

</details>

<details>
<summary>
View the diff from darker here.
</summary>

``````````diff
--- test_SDDMM.py	2024-09-13 17:44:14.000000 +0000
+++ test_SDDMM.py	2024-09-13 17:53:13.144940 +0000
@@ -161,12 +161,14 @@
                             attr = st.EncodingAttr.get(
                                 level, ordering, ordering, pwidth, iwidth
                             )
                             opt = f"parallelization-strategy=none"
                             compiler = sparsifier.Sparsifier(
-                                extras="", options=opt, opt_level=0,
-                                shared_libs=[support_lib]
+                                extras="",
+                                options=opt,
+                                opt_level=0,
+                                shared_libs=[support_lib],
                             )
                             build_compile_and_run_SDDMMM(attr, compiler)
                             count = count + 1
     # CHECK: Passed 10 tests
     print("Passed ", count, "tests")
--- test_all_dense.py	2024-09-13 17:44:14.000000 +0000
+++ test_all_dense.py	2024-09-13 17:53:13.229420 +0000
@@ -16,12 +16,12 @@
 sys.path.append(_SCRIPT_PATH)
 from tools import sparsifier
 
 
 def boilerplate():
-  """Returns boilerplate main method."""
-  return """
+    """Returns boilerplate main method."""
+    return """
 #Dense = #sparse_tensor.encoding<{
   map = (i, j) -> (i: dense, j: dense)
 }>
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
@@ -54,43 +54,41 @@
 }
 """
 
 
 def main():
-  support_lib = os.getenv("SUPPORT_LIB")
-  assert support_lib is not None, "SUPPORT_LIB is undefined"
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(
-        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
-    )
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
 
-  # CHECK-LABEL: TEST: all dense
-  # CHECK: ---- Sparse Tensor ----
-  # CHECK: nse = 12
-  # CHECK: dim = ( 3, 4 )
-  # CHECK: lvl = ( 3, 4 )
-  # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
-  # CHECK: ----
-  print("\nTEST: all dense")
-  with ir.Context() as ctx, ir.Location.unknown():
-    compiler = sparsifier.Sparsifier(
-        extras="sparse-assembler,",
-        options="enable-runtime-library=false",
-        opt_level=2,
-        shared_libs=[support_lib],
-    )
-    module = ir.Module.parse(boilerplate())
-    engine = compiler.compile_and_jit(module)
-    print(module)
+    # CHECK-LABEL: TEST: all dense
+    # CHECK: ---- Sparse Tensor ----
+    # CHECK: nse = 12
+    # CHECK: dim = ( 3, 4 )
+    # CHECK: lvl = ( 3, 4 )
+    # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
+    # CHECK: ----
+    print("\nTEST: all dense")
+    with ir.Context() as ctx, ir.Location.unknown():
+        compiler = sparsifier.Sparsifier(
+            extras="sparse-assembler,",
+            options="enable-runtime-library=false",
+            opt_level=2,
+            shared_libs=[support_lib],
+        )
+        module = ir.Module.parse(boilerplate())
+        engine = compiler.compile_and_jit(module)
+        print(module)
 
-    a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
-    b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
-    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+        a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
+        b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
+        mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+        mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
 
-    # Invoke the kernel and get numpy output.
-    # Built-in bufferization uses in-out buffers.
-    engine.invoke("add", mem_a, mem_b)
+        # Invoke the kernel and get numpy output.
+        # Built-in bufferization uses in-out buffers.
+        engine.invoke("add", mem_a, mem_b)
 
 
 if __name__ == "__main__":
-  main()
+    main()
--- test_stress.py	2024-09-13 17:44:14.000000 +0000
+++ test_stress.py	2024-09-13 17:53:13.403620 +0000
@@ -193,12 +193,14 @@
     # CHECK-LABEL: TEST: test_stress
     print("\nTEST: test_stress")
     with ir.Context() as ctx, ir.Location.unknown():
         sparsification_options = f"parallelization-strategy=none "
         compiler = sparsifier.Sparsifier(
-            extras="", options=sparsification_options, opt_level=0,
-            shared_libs=[support_lib]
+            extras="",
+            options=sparsification_options,
+            opt_level=0,
+            shared_libs=[support_lib],
         )
         f64 = ir.F64Type.get()
         # Be careful about increasing this because
         #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
         shape = range(2, 3)

``````````

</details>


https://github.com/llvm/llvm-project/pull/76747


More information about the Mlir-commits mailing list