[Mlir-commits] [mlir] [mlir][sparse] introduce new pass to recover sparse encodings dropped… (PR #92052)
Peiming Liu
llvmlistbot at llvm.org
Mon May 13 17:09:02 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/92052
… by tensor/linalg transformations.
>From 21a9476acd6c10e17d59b2a3e16c91b7d6c243be Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 14 May 2024 00:03:34 +0000
Subject: [PATCH] [mlir][sparse] introduce new pass to recover sparse encodings
dropped by tensor/linalg transformations.
---
.../Dialect/SparseTensor/Transforms/Passes.h | 6 +++
.../Dialect/SparseTensor/Transforms/Passes.td | 38 +++++++++++++++++++
.../Transforms/SparseTensorPasses.cpp | 13 +++++++
3 files changed, 57 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3e16dd53741bb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
std::unique_ptr<Pass> createSparseAssembler();
std::unique_ptr<Pass> createSparseAssembler(bool directOut);
+//===----------------------------------------------------------------------===//
+// The SparseEncodingRecovery pass.
+//===----------------------------------------------------------------------===//
+
+/// Creates the `sparse-encoding-recovery` pass, which is meant to restore
+/// sparse tensor encodings dropped by encoding-unaware tensor/linalg
+/// transformations.
+std::unique_ptr<Pass> createSparseEncodingRecoveryPass();
+
//===----------------------------------------------------------------------===//
// The SparseReinterpretMap pass.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..3a66629921d2f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -40,6 +40,44 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
];
}
+def SparseEncodingRecovery : Pass<"sparse-encoding-recovery", "func::FuncOp"> {
+ let summary = "Recover dropped sparse tensor encodings";
+ let description = [{
+ A pass that recovers dropped sparse tensor encodings.
+
+ Background: To avoid introducing repetitive operations, sparse tensors
+ in MLIR try to reuse tensor operations whenever available. However, most
+ tensor operations are canonicalized/transformed without the knowledge
+ of sparsity. This pass tries to recover the lost sparse encodings.
+ Ideally, though, the tensor dialect should allow extensions to
+ infer/propagate tensor encodings correctly.
+
+ For example:
+ ```mlir
+ %s = tensor.extract_slice %input[0, 0] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
+
+ // After rank reducing (by tensor dialect transformation)
+ %t = tensor.extract_slice %input[0, 0] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2xf32>
+ %s = tensor.expand_shape %t [[0, 1]] output_shape [2, 1]
+ : tensor<2xf32> into tensor<2x1xf32, #sparse>
+
+ // After sparsity recovery
+ %t = tensor.extract_slice %input[0, 0] [2, 1] [1, 1]
+ : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
+ %s = tensor.expand_shape %t [[0, 1]] output_shape [2, 1]
+ : tensor<2xf32, #sparse1> into tensor<2x1xf32, #sparse>
+ ```
+ }];
+
+ let constructor = "mlir::createSparseEncodingRecoveryPass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ "tensor::TensorDialect",
+ ];
+}
+
def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
let summary = "Reinterprets sparse tensor type mappings";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..5c2579a63e840 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -23,6 +23,7 @@
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
+#define GEN_PASS_DEF_SPARSEENCODINGRECOVERY
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
@@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
}
};
+/// Pass driver for `sparse-encoding-recovery`.
+///
+/// TODO: the encoding-recovery rewrites described in Passes.td are not
+/// implemented yet, so this pass currently leaves the IR untouched.
+struct SparseEncodingRecovery
+ : public impl::SparseEncodingRecoveryBase<SparseEncodingRecovery> {
+ SparseEncodingRecovery() = default;
+ SparseEncodingRecovery(const SparseEncodingRecovery &pass) = default;
+
+ void runOnOperation() override {
+ // Nothing is rewritten yet; report every analysis as still valid so the
+ // pass manager does not invalidate (and later recompute) them for nothing.
+ // Remove this once the pass actually mutates the IR.
+ markAllAnalysesPreserved();
+ }
+};
+
struct SparseReinterpretMap
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
SparseReinterpretMap() = default;
@@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
return std::make_unique<SparseAssembler>();
}
+/// Builds a fresh, default-configured SparseEncodingRecovery pass instance.
+std::unique_ptr<Pass> mlir::createSparseEncodingRecoveryPass() {
+ auto pass = std::make_unique<SparseEncodingRecovery>();
+ return pass; // implicitly moved and upcast to std::unique_ptr<Pass>
+}
+
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
return std::make_unique<SparseReinterpretMap>();
}
More information about the Mlir-commits
mailing list