[Mlir-commits] [mlir] 0637440 - [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps (#68436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 14:23:23 PDT 2023


Author: Peiming Liu
Date: 2023-10-06T14:23:18-07:00
New Revision: 06374400025d50460eafd32581a22b293b850cc8

URL: https://github.com/llvm/llvm-project/commit/06374400025d50460eafd32581a22b293b850cc8
DIFF: https://github.com/llvm/llvm-project/commit/06374400025d50460eafd32581a22b293b850cc8.diff

LOG: [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps (#68436)

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c1e217675020f08..c537e92a51d5333 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -87,6 +87,16 @@ std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass>
 createSparsificationPass(const SparsificationOptions &options);
 
+//===----------------------------------------------------------------------===//
+// The StageSparseOperations pass.
+//===----------------------------------------------------------------------===//
+
+/// Sets up StageSparseOperations rewriting rules.
+void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass that greedily applies the StageSparseOperations rules.
+std::unique_ptr<Pass> createStageSparseOperationsPass();
+
 //===----------------------------------------------------------------------===//
 // The PostSparsificationRewriting pass.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d8d5dbb5ad3ce75..8f116bff9b185a3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
   ];
 }
 
+def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
+  // Anchored on func::FuncOp, so staging runs function by function.
+  let summary = "Decompose a complex sparse operation into multiple stages";
+  let description = [{
+    A pass that decomposes a complex sparse operation into multiple stages.
+    E.g., CSR -> CSC is staged into CSR -> COO (unordered) -> sort -> CSC.
+  }];
+  let dependentDialects = ["sparse_tensor::SparseTensorDialect"];
+  // Factory function defined in SparseTensorPasses.cpp.
+  let constructor = "mlir::createStageSparseOperationsPass()";
+}
+
 def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5ef9d906f0e8b7c..0ca6668c8c74745 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp
+  StageSparseOperations.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f50d3d4606554a1..e1f88ad9c0e1140 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -30,6 +30,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
+#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
   }
 };
 
+struct StageSparseOperationsPass
+    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
+  StageSparseOperationsPass() = default;
+  StageSparseOperationsPass(const StageSparseOperationsPass &) = default;
+  void runOnOperation() override {
+    RewritePatternSet stageRules(&getContext());
+    populateStageSparseOperationsPatterns(stageRules);
+    // Result ignored: failing to converge is not treated as a pass failure.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(stageRules));
+  }
+};
+
 struct PostSparsificationRewritePass
     : public impl::PostSparsificationRewriteBase<
           PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
   return std::make_unique<SparsificationPass>(options);
 }
 
+std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
+  return std::make_unique<StageSparseOperationsPass>(); // --stage-sparse-ops
+}
+
 std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
   return std::make_unique<PostSparsificationRewritePass>();
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
new file mode 100644
index 000000000000000..4adc4d131198cc7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -0,0 +1,19 @@
+//===- StageSparseOperations.cpp - Stage sparse operations ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rewriting rules that decompose complex sparse tensor operations into
+// sequences of simpler stages (see the StageSparseOperations pass).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+/// Populates `patterns` with the StageSparseOperations rewriting rules.
+/// NOTE: no rules are registered yet, so this is currently a no-op.
+void mlir::populateStageSparseOperationsPatterns(
+    RewritePatternSet & /*patterns*/) {}


        


More information about the Mlir-commits mailing list