[Mlir-commits] [mlir] [mlir][sparse] fix bug with all-dense assembler (PR #108615)
Aart Bik
llvmlistbot at llvm.org
Fri Sep 13 10:56:47 PDT 2024
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/108615
>From e7d51a434d6e88ef0918a014e4245fb119665ea0 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Sep 2024 10:44:14 -0700
Subject: [PATCH 1/2] [mlir][sparse] fix bug with all-dense assembler
When a function prototype contains only all-dense "sparse" tensors,
the assembler skipped the method conversion, because it decided purely
on the input/output type counts (an all-dense tensor converts to a
single type, so the counts never change). Instead, it should rewrite
whenever any sparse tensor annotation is present.
---
.../Transforms/SparseAssembler.cpp | 14 ++-
.../Dialect/SparseTensor/python/test_SDDMM.py | 3 +-
.../Dialect/SparseTensor/python/test_SpMM.py | 2 +-
.../SparseTensor/python/test_all_dense.py | 96 +++++++++++++++++++
.../SparseTensor/python/test_output.py | 2 +-
.../SparseTensor/python/test_stress.py | 3 +-
.../SparseTensor/python/tools/sparsifier.py | 13 ++-
7 files changed, 122 insertions(+), 11 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a53bce16dad860..5461987fb49d93 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -24,7 +24,8 @@ using namespace sparse_tensor;
//===----------------------------------------------------------------------===//
// Convert type range to new types range, with sparse tensors externalized.
-static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+static void convTypes(bool &hasAnnotation, TypeRange types,
+ SmallVectorImpl<Type> &convTypes,
SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
@@ -32,6 +33,7 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
convTypes.push_back(type);
continue;
}
+ hasAnnotation = true;
// Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
@@ -176,12 +178,14 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<Type> extraTypes;
- convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
- convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
+ bool hasAnnotation = false;
+ convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
+ false);
+ convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
+ directOut);
// Only sparse inputs or outputs need a wrapper method.
- if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
- outputTypes.size() == funcOp.getResultTypes().size())
+ if (!hasAnnotation)
return failure();
// Modify the original method into an internal, private method.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index e2050b98728f21..5ffb910e02d46d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -163,7 +163,8 @@ def main():
)
opt = f"parallelization-strategy=none"
compiler = sparsifier.Sparsifier(
- options=opt, opt_level=0, shared_libs=[support_lib]
+ extras="", options=opt, opt_level=0,
+ shared_libs=[support_lib]
)
build_compile_and_run_SDDMMM(attr, compiler)
count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index e7354c24d619e0..65fc6a0bdbe46b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -141,7 +141,7 @@ def main():
]
bitwidths = [0]
compiler = sparsifier.Sparsifier(
- options=opt, opt_level=0, shared_libs=[support_lib]
+ extras="", options=opt, opt_level=0, shared_libs=[support_lib]
)
for level in levels:
for ordering in orderings:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
new file mode 100644
index 00000000000000..eebed3afd7084b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
@@ -0,0 +1,96 @@
+# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
+# RUN: %PYTHON %s | FileCheck %s
+
+import ctypes
+import os
+import sys
+import tempfile
+
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects import sparse_tensor as st
+import numpy as np
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparsifier
+
+
+def boilerplate():
+ """Returns boilerplate main method."""
+ return """
+#Dense = #sparse_tensor.encoding<{
+ map = (i, j) -> (i: dense, j: dense)
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @add(%st_0 : tensor<3x4xf64, #Dense>,
+ %st_1 : tensor<3x4xf64, #Dense>) attributes { llvm.emit_c_interface } {
+ %out_st = tensor.empty() : tensor<3x4xf64, #Dense>
+ %res = linalg.generic {indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%st_0, %st_1 : tensor<3x4xf64, #Dense>, tensor<3x4xf64, #Dense>)
+ outs(%out_st : tensor<3x4xf64, #Dense>) {
+ ^bb0(%in_0: f64, %in_1: f64, %out: f64):
+ %2 = sparse_tensor.binary %in_0, %in_1 : f64, f64 to f64
+ overlap = {
+ ^bb0(%arg1: f64, %arg2: f64):
+ %3 = arith.addf %arg1, %arg2 : f64
+ sparse_tensor.yield %3 : f64
+ }
+ left = {
+ ^bb0(%arg1: f64):
+ sparse_tensor.yield %arg1 : f64
+ }
+ right = {
+ ^bb0(%arg1: f64):
+ sparse_tensor.yield %arg1 : f64
+ }
+ linalg.yield %2 : f64
+ } -> tensor<3x4xf64, #Dense>
+ sparse_tensor.print %res : tensor<3x4xf64, #Dense>
+ return
+}
+"""
+
+
+def main():
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+
+ # CHECK-LABEL: TEST: all dense
+ # CHECK: ---- Sparse Tensor ----
+ # CHECK: nse = 12
+ # CHECK: dim = ( 3, 4 )
+ # CHECK: lvl = ( 3, 4 )
+ # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
+ # CHECK: ----
+ print("\nTEST: all dense")
+ with ir.Context() as ctx, ir.Location.unknown():
+ compiler = sparsifier.Sparsifier(
+ extras="sparse-assembler,",
+ options="enable-runtime-library=false",
+ opt_level=2,
+ shared_libs=[support_lib],
+ )
+ module = ir.Module.parse(boilerplate())
+ engine = compiler.compile_and_jit(module)
+ print(module)
+
+ a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
+ b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+
+ # Invoke the kernel and get numpy output.
+ # Built-in bufferization uses in-out buffers.
+ engine.invoke("add", mem_a, mem_b)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 7da05303c7e1e1..544273eb18835e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -139,7 +139,7 @@ def main():
]
bitwidths = [8, 64]
compiler = sparsifier.Sparsifier(
- options="", opt_level=2, shared_libs=[support_lib]
+ extras="", options="", opt_level=2, shared_libs=[support_lib]
)
for level in levels:
for ordering, id_map in orderings:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index ce3516e2edaf03..9db00454053481 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -195,7 +195,8 @@ def main():
with ir.Context() as ctx, ir.Location.unknown():
sparsification_options = f"parallelization-strategy=none "
compiler = sparsifier.Sparsifier(
- options=sparsification_options, opt_level=0, shared_libs=[support_lib]
+ extras="", options=sparsification_options, opt_level=0,
+ shared_libs=[support_lib]
)
f64 = ir.F64Type.get()
# Be careful about increasing this because
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py
index ab7208f23f61b6..91d1fb22542d4a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py
@@ -13,8 +13,17 @@
class Sparsifier:
"""Sparsifier class for compiling and building MLIR modules."""
- def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
- pipeline = f"builtin.module(sparsifier{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+ def __init__(
+ self,
+ extras: str,
+ options: str,
+ opt_level: int,
+ shared_libs: Sequence[str],
+ ):
+ pipeline = (
+ f"builtin.module({extras}sparsifier{{{options} reassociate-fp-reductions=1"
+ " enable-index-optimizations=1})"
+ )
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
>From d6660dd4309500e2b7c33e3b3e24dae382d77e59 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 13 Sep 2024 10:56:28 -0700
Subject: [PATCH 2/2] lint
---
.../Dialect/SparseTensor/python/test_SDDMM.py | 6 +-
.../SparseTensor/python/test_all_dense.py | 66 +++++++++----------
.../SparseTensor/python/test_stress.py | 6 +-
3 files changed, 40 insertions(+), 38 deletions(-)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 5ffb910e02d46d..b6f61a47dec1ec 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -163,8 +163,10 @@ def main():
)
opt = f"parallelization-strategy=none"
compiler = sparsifier.Sparsifier(
- extras="", options=opt, opt_level=0,
- shared_libs=[support_lib]
+ extras="",
+ options=opt,
+ opt_level=0,
+ shared_libs=[support_lib],
)
build_compile_and_run_SDDMMM(attr, compiler)
count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
index eebed3afd7084b..9ab374ac820714 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_all_dense.py
@@ -18,8 +18,8 @@
def boilerplate():
- """Returns boilerplate main method."""
- return """
+ """Returns boilerplate main method."""
+ return """
#Dense = #sparse_tensor.encoding<{
map = (i, j) -> (i: dense, j: dense)
}>
@@ -56,41 +56,39 @@ def boilerplate():
def main():
- support_lib = os.getenv("SUPPORT_LIB")
- assert support_lib is not None, "SUPPORT_LIB is undefined"
- if not os.path.exists(support_lib):
- raise FileNotFoundError(
- errno.ENOENT, os.strerror(errno.ENOENT), support_lib
- )
+ support_lib = os.getenv("SUPPORT_LIB")
+ assert support_lib is not None, "SUPPORT_LIB is undefined"
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(f"SUPPORT_LIB not found: {support_lib}")
- # CHECK-LABEL: TEST: all dense
- # CHECK: ---- Sparse Tensor ----
- # CHECK: nse = 12
- # CHECK: dim = ( 3, 4 )
- # CHECK: lvl = ( 3, 4 )
- # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
- # CHECK: ----
- print("\nTEST: all dense")
- with ir.Context() as ctx, ir.Location.unknown():
- compiler = sparsifier.Sparsifier(
- extras="sparse-assembler,",
- options="enable-runtime-library=false",
- opt_level=2,
- shared_libs=[support_lib],
- )
- module = ir.Module.parse(boilerplate())
- engine = compiler.compile_and_jit(module)
- print(module)
+ # CHECK-LABEL: TEST: all dense
+ # CHECK: ---- Sparse Tensor ----
+ # CHECK: nse = 12
+ # CHECK: dim = ( 3, 4 )
+ # CHECK: lvl = ( 3, 4 )
+ # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
+ # CHECK: ----
+ print("\nTEST: all dense")
+ with ir.Context() as ctx, ir.Location.unknown():
+ compiler = sparsifier.Sparsifier(
+ extras="sparse-assembler,",
+ options="enable-runtime-library=false",
+ opt_level=2,
+ shared_libs=[support_lib],
+ )
+ module = ir.Module.parse(boilerplate())
+ engine = compiler.compile_and_jit(module)
+ print(module)
- a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
- b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
- mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
- mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
+ b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
- # Invoke the kernel and get numpy output.
- # Built-in bufferization uses in-out buffers.
- engine.invoke("add", mem_a, mem_b)
+ # Invoke the kernel; it prints the element-wise sum of a and b itself,
+ # and FileCheck verifies that printout against the CHECK lines above.
+ engine.invoke("add", mem_a, mem_b)
if __name__ == "__main__":
- main()
+ main()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index 9db00454053481..dfaa4f462660b7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -195,8 +195,10 @@ def main():
with ir.Context() as ctx, ir.Location.unknown():
sparsification_options = f"parallelization-strategy=none "
compiler = sparsifier.Sparsifier(
- extras="", options=sparsification_options, opt_level=0,
- shared_libs=[support_lib]
+ extras="",
+ options=sparsification_options,
+ opt_level=0,
+ shared_libs=[support_lib],
)
f64 = ir.F64Type.get()
# Be careful about increasing this because
More information about the Mlir-commits
mailing list