[Mlir-commits] [mlir] [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps (PR #68436)
Peiming Liu
llvmlistbot at llvm.org
Fri Oct 6 14:08:22 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/68436
>From 7f71eba0a72ed92be61d03284e61e003d90d25f1 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 18:01:58 +0000
Subject: [PATCH 1/2] [mlir][sparse] introduce a pass to stage complex sparse
operations into simple steps
---
.../Dialect/SparseTensor/Transforms/Passes.h | 9 +++++++++
.../Dialect/SparseTensor/Transforms/Passes.td | 12 ++++++++++++
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseTensorPasses.cpp | 17 +++++++++++++++++
.../Transforms/StageSparseOperations.cpp | 4 ++++
5 files changed, 43 insertions(+)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c1e217675020f08..c537e92a51d5333 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -87,6 +87,15 @@ std::unique_ptr<Pass> createSparsificationPass();
std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions &options);
+//===----------------------------------------------------------------------===//
+// The StageSparseOperations pass.
+//===----------------------------------------------------------------------===//
+
+/// Sets up StageSparseOperations rewriting rules.
+void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createStageSparseOperationsPass();
+
//===----------------------------------------------------------------------===//
// The PostSparsificationRewriting pass.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d8d5dbb5ad3ce75..7071c3091d33f3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
];
}
+def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
+ let summary = "Decompose a complex sparse operations into multiple stages";
+ let description = [{
+ A pass that decomposes a complex sparse operations into multiple stages.
+ E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
+ }];
+ let constructor = "mlir::createStageSparseOperationsPass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules after sparsification";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5ef9d906f0e8b7c..0ca6668c8c74745 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseVectorization.cpp
Sparsification.cpp
SparsificationAndBufferizationPass.cpp
+ StageSparseOperations.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f50d3d4606554a1..e1f88ad9c0e1140 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -30,6 +30,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
+#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
}
};
+struct StageSparseOperationsPass
+ : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
+ StageSparseOperationsPass() = default;
+ StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ populateStageSparseOperationsPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct PostSparsificationRewritePass
: public impl::PostSparsificationRewriteBase<
PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
return std::make_unique<SparsificationPass>(options);
}
+std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
+ return std::make_unique<StageSparseOperationsPass>();
+}
+
std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
return std::make_unique<PostSparsificationRewritePass>();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
new file mode 100644
index 000000000000000..4adc4d131198cc7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -0,0 +1,4 @@
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+void mlir::populateStageSparseOperationsPatterns(
+ RewritePatternSet & /*patterns*/) {}
>From bfe0b0a433f12bfbb3a0ddc3ac5f77e0e3499bab Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 21:08:09 +0000
Subject: [PATCH 2/2] address review comments: fix grammar in pass summary
---
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 7071c3091d33f3a..0a4f2e3469ad60b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -124,7 +124,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
}
def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
- let summary = "Decompose a complex sparse operations into multiple stages";
+ let summary = "Decompose a complex sparse operation into multiple stages";
let description = [{
A pass that decomposes a complex sparse operations into multiple stages.
E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
More information about the Mlir-commits
mailing list