[Mlir-commits] [mlir] [mlir][sparse] enable Python BSR test (PR #72325)
Aart Bik
llvmlistbot at llvm.org
Tue Nov 14 15:27:11 PST 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/72325
None
>From 9c736adfd7486a542cd03b913771b207ab06f954 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 14 Nov 2023 15:24:12 -0800
Subject: [PATCH] [mlir][sparse] enable Python BSR test
---
.../SparseTensor/python/test_output.py | 61 +++++++++++++++----
1 file changed, 49 insertions(+), 12 deletions(-)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index c9efadb60480c54..216922207b7820a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -32,11 +32,28 @@ def boilerplate(attr: st.EncodingAttr):
def expected(id_map):
"""Returns expected contents of output.
+ +-----+-----+-----+-----+-----+
+ | 1 0 | . . | . . | . . | 0 3 |
+ | 0 2 | . . | . . | . . | 0 0 |
+ +-----+-----+-----+-----+-----+
+ | . . | . . | . . | . . | . . |
+ | . . | . . | . . | . . | . . |
+ +-----+-----+-----+-----+-----+
+ | . . | . . | 5 0 | . . | . . |
+ | . . | . . | 0 0 | . . | . . |
+ +-----+-----+-----+-----+-----+
+ | . . | . . | . . | . . | . . |
+ | . . | . . | . . | . . | . . |
+ +-----+-----+-----+-----+-----+
+ | 0 0 | . . | . . | . . | . . |
+ | 4 0 | . . | . . | . . | . . |
+ +-----+-----+-----+-----+-----+
+
Output appears as dimension coordinates but lexicographically
- sorted by level coordinates.
+ sorted by level coordinates. For BSR (id_map 2), the blocks are filled.
"""
- return (
- f"""# extended FROSTT format
+ if id_map == 0:
+ return """# extended FROSTT format
2 5
10 10
1 1 1
@@ -45,8 +62,8 @@ def expected(id_map):
5 5 5
10 1 4
"""
- if id_map
- else f"""# extended FROSTT format
+ if id_map == 1:
+ return """# extended FROSTT format
2 5
10 10
1 1 1
@@ -55,7 +72,28 @@ def expected(id_map):
5 5 5
1 10 3
"""
- )
+ if id_map == 2:
+ return """# extended FROSTT format
+2 16
+10 10
+1 1 1
+1 2 0
+2 1 0
+2 2 2
+1 9 0
+1 10 3
+2 9 0
+2 10 0
+5 5 5
+5 6 0
+6 5 0
+6 6 0
+9 1 0
+9 2 0
+10 1 4
+10 2 0
+"""
+ raise AssertionError(f"unexpected id_map: {id_map}")
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
@@ -93,10 +131,10 @@ def main():
[st.DimLevelType.compressed, st.DimLevelType.compressed],
]
orderings = [
- (ir.AffineMap.get_permutation([0, 1]), True),
- (ir.AffineMap.get_permutation([1, 0]), False),
+ (ir.AffineMap.get_permutation([0, 1]), 0),
+ (ir.AffineMap.get_permutation([1, 0]), 1),
]
- bitwidths = [8, 16, 32, 64]
+ bitwidths = [8, 64]
compiler = sparse_compiler.SparseCompiler(
options="", opt_level=2, shared_libs=[support_lib]
)
@@ -135,11 +173,10 @@ def main():
l3 = ir.AffineDimExpr.get(3)
lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
- # TODO: enable this one CONVERSION on BSR is working
- # build_compile_and_run_output(attr, compiler, block_expected())
+ build_compile_and_run_output(attr, compiler, expected(2))
count = count + 1
- # CHECK: Passed 33 tests
+ # CHECK: Passed 17 tests
print("Passed", count, "tests")
More information about the Mlir-commits
mailing list