[Mlir-commits] [mlir] [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps (PR #68436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 11:06:43 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

<details>
<summary>Changes</summary>

Introduce a pass to stage complex sparse operations into simple steps.

---
Full diff: https://github.com/llvm/llvm-project/pull/68436.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+9) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+12) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+17) 
- (added) mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp (+4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c1e217675020f08..c537e92a51d5333 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -87,6 +87,15 @@ std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass>
 createSparsificationPass(const SparsificationOptions &options);
 
+//===----------------------------------------------------------------------===//
+// The StageSparseOperations pass.
+//===----------------------------------------------------------------------===//
+
+/// Sets up StageSparseOperations rewriting rules.
+void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
+/// Creates a pass that applies the StageSparseOperations rewriting rules.
+std::unique_ptr<Pass> createStageSparseOperationsPass();
+
 //===----------------------------------------------------------------------===//
 // The PostSparsificationRewriting pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d8d5dbb5ad3ce75..7071c3091d33f3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
   ];
 }
 
+def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
+  let summary = "Decomposes complex sparse operations into multiple stages";
+  let description = [{
+    A pass that decomposes complex sparse operations into multiple stages.
+    E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
+  }];
+  let constructor = "mlir::createStageSparseOperationsPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 5ef9d906f0e8b7c..0ca6668c8c74745 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp
+  StageSparseOperations.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f50d3d4606554a1..e1f88ad9c0e1140 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -30,6 +30,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
 #define GEN_PASS_DEF_SPARSEGPUCODEGEN
+#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
 #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
   }
 };
 
+struct StageSparseOperationsPass
+    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
+  StageSparseOperationsPass() = default;
+  StageSparseOperationsPass(const StageSparseOperationsPass &) = default;
+  void runOnOperation() override {
+    RewritePatternSet rules(&getContext());
+    populateStageSparseOperationsPatterns(rules);
+    // Best effort: the greedy driver's LogicalResult is deliberately dropped.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rules));
+  }
+};
+
 struct PostSparsificationRewritePass
     : public impl::PostSparsificationRewriteBase<
           PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
   return std::make_unique<SparsificationPass>(options);
 }
 
+std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
+  return std::make_unique<StageSparseOperationsPass>(); // Pass has no options.
+}
+
 std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
   return std::make_unique<PostSparsificationRewritePass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
new file mode 100644
index 000000000000000..4adc4d131198cc7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -0,0 +1,4 @@
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+void mlir::populateStageSparseOperationsPatterns(
+    RewritePatternSet & /*patterns*/) {} // Currently adds no patterns.

``````````

</details>


https://github.com/llvm/llvm-project/pull/68436


More information about the Mlir-commits mailing list