[Mlir-commits] [mlir] 779dcd2 - [mlir][sparse] move sparse tensor rewriting into its own pass

Aart Bik llvmlistbot at llvm.org
Wed Oct 5 14:53:09 PDT 2022


Author: Aart Bik
Date: 2022-10-05T14:52:55-07:00
New Revision: 779dcd2ecce84983fcec0a23dee66f85772cb17a

URL: https://github.com/llvm/llvm-project/commit/779dcd2ecce84983fcec0a23dee66f85772cb17a
DIFF: https://github.com/llvm/llvm-project/commit/779dcd2ecce84983fcec0a23dee66f85772cb17a.diff

LOG: [mlir][sparse] move sparse tensor rewriting into its own pass

Moving the rewriting rules into a separate pass makes them easier to test and debug individually.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D135319

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/rewriting.mlir
    mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
    mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
    mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index e6e65b7b0d682..fd99e4f57af0e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -163,6 +163,10 @@ std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
 void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
 
+std::unique_ptr<Pass> createSparseTensorRewritePass();
+std::unique_ptr<Pass>
+createSparseTensorRewritePass(const SparsificationOptions &options);
+
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d97bead8d296c..26c78aea50a81 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -11,6 +11,27 @@
 
 include "mlir/Pass/PassBase.td"
 
+def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
+  let summary = "Applies sparse tensor rewriting rules prior to sparsification";
+  let description = [{
+    A pass that applies rewriting rules to sparse tensor operations prior
+    to running the actual sparsification pass.
+  }];
+  let constructor = "mlir::createSparseTensorRewritePass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+  let options = [
+    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+           "true", "Enable runtime library for manipulating sparse tensors">
+  ];
+}
+
 def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
   let summary = "Automatically generate sparse tensor code from sparse tensor types";
   let description = [{
@@ -57,6 +78,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
     "LLVM::LLVMDialect",
+    "linalg::LinalgDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
@@ -193,4 +215,5 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
     "sparse_tensor::SparseTensorDialect",
   ];
 }
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index abecf4679e4e9..0cd17998e81bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -58,6 +58,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
   if (options.testBufferizationAnalysisOnly)
     return;
+  pm.addPass(createSparseTensorRewritePass(options.sparsificationOptions()));
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
   if (options.enableRuntimeLibrary)
     pm.addPass(createSparseTensorConversionPass(

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b208dfeb5558b..5ea55c822634b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
+#define GEN_PASS_DEF_SPARSETENSORREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
@@ -37,6 +38,23 @@ namespace {
 // Passes implementation.
 //===----------------------------------------------------------------------===//
 
+struct SparseTensorRewritePass
+    : public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
+
+  SparseTensorRewritePass() = default;
+  SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
+  SparseTensorRewritePass(const SparsificationOptions &options) {
+    enableRuntimeLibrary = options.enableRuntimeLibrary;
+  }
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct SparsificationPass
     : public impl::SparsificationPassBase<SparsificationPass> {
 
@@ -53,14 +71,10 @@ struct SparsificationPass
 
   void runOnOperation() override {
     auto *ctx = &getContext();
-    RewritePatternSet prePatterns(ctx);
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelization, vectorization, vectorLength,
                                   enableSIMDIndex32, enableVLAVectorization,
                                   enableRuntimeLibrary);
-    // Apply pre-rewriting.
-    populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
     // Apply sparsification and vector cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
@@ -236,6 +250,15 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
 // Pass creation methods.
 //===----------------------------------------------------------------------===//
 
+std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
+  return std::make_unique<SparseTensorRewritePass>();
+}
+
+std::unique_ptr<Pass>
+mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
+  return std::make_unique<SparseTensorRewritePass>(options);
+}
+
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
   return std::make_unique<SparsificationPass>();
 }

diff  --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
index 000c3560f1e0f..f142ecf7ff341 100755
--- a/mlir/test/Dialect/SparseTensor/rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparsification | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 018f391122079..a46da498b3871 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparsification=enable-runtime-library=false | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 132566653c972..2b388bafc2cfd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 1a521800f3336..ad1ad1d524be3 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s  --tensor-copy-insertion --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s  --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
 
 #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 


        


More information about the Mlir-commits mailing list