[Mlir-commits] [mlir] [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps (PR #68436)

Peiming Liu llvmlistbot at llvm.org
Fri Oct 6 14:09:57 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/68436

>From 7f71eba0a72ed92be61d03284e61e003d90d25f1 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 18:01:58 +0000
Subject: [PATCH 1/2] [mlir][sparse] introduce a pass to stage complex sparse
 operations into simple steps

---
 .../Dialect/SparseTensor/Transforms/Passes.h    |  9 +++++++++
 .../Dialect/SparseTensor/Transforms/Passes.td   | 12 ++++++++++++
 .../SparseTensor/Transforms/CMakeLists.txt      |  1 +
 .../Transforms/SparseTensorPasses.cpp           | 17 +++++++++++++++++
 .../Transforms/StageSparseOperations.cpp        |  4 ++++
 5 files changed, 43 insertions(+)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c1e217675020f08..c537e92a51d5333 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -87,6 +87,15 @@ std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass>
 createSparsificationPass(const SparsificationOptions &options);
 
+//===----------------------------------------------------------------------===//
+// The StageSparseOperations pass.
+//===----------------------------------------------------------------------===//
+
+/// Sets up StageSparseOperation rewriting rules.
+void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createStageSparseOperationsPass();
+
 //===----------------------------------------------------------------------===//
 // The PostSparsificationRewriting pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d8d5dbb5ad3ce75..7071c3091d33f3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
   ];
 }
 
+def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
+  let summary = "Decompose a complex sparse operations into multiple stages";
+  let description = [{
+    A pass that decomposes a complex sparse operations into multiple stages.
+    E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
+  }];
+  let constructor = "mlir::createStageSparseOperationsPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5ef9d906f0e8b7c..0ca6668c8c74745 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp
+  StageSparseOperations.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f50d3d4606554a1..e1f88ad9c0e1140 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -30,6 +30,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
+#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
   }
 };
 
+struct StageSparseOperationsPass
+    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
+  StageSparseOperationsPass() = default;
+  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateStageSparseOperationsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct PostSparsificationRewritePass
     : public impl::PostSparsificationRewriteBase<
           PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
   return std::make_unique<SparsificationPass>(options);
 }
 
+std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
+  return std::make_unique<StageSparseOperationsPass>();
+}
+
 std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
   return std::make_unique<PostSparsificationRewritePass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
new file mode 100644
index 000000000000000..4adc4d131198cc7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -0,0 +1,4 @@
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+void mlir::populateStageSparseOperationsPatterns(
+    RewritePatternSet & /*patterns*/) {}

>From b1e8f710bd2e7f4df6dd0d4be9f28b13388a0e6b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 21:08:09 +0000
Subject: [PATCH 2/2] address comments

---
 mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 7071c3091d33f3a..8f116bff9b185a3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -124,10 +124,10 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
 }
 
 def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
-  let summary = "Decompose a complex sparse operations into multiple stages";
+  let summary = "Decompose a complex sparse operation into multiple stages";
   let description = [{
-    A pass that decomposes a complex sparse operations into multiple stages.
-    E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
+    A pass that decomposes a complex sparse operation into multiple stages.
+    E.g., CSR -> CSC is staged into CSR -> COO (unordered) -> sort -> CSC.
   }];
   let constructor = "mlir::createStageSparseOperationsPass()";
   let dependentDialects = [



More information about the Mlir-commits mailing list