[Mlir-commits] [mlir] f81f0cb - [mlir][sparse] Split SparseTensorRewrite into PreSparsificationRewrite and PostSparsificationRewrite.
llvmlistbot at llvm.org
Thu Nov 17 07:14:00 PST 2022
Author: bixia1
Date: 2022-11-17T07:13:55-08:00
New Revision: f81f0cb75a2808a67d2662f044ad07628fc9d900
URL: https://github.com/llvm/llvm-project/commit/f81f0cb75a2808a67d2662f044ad07628fc9d900
DIFF: https://github.com/llvm/llvm-project/commit/f81f0cb75a2808a67d2662f044ad07628fc9d900.diff
LOG: [mlir][sparse] Split SparseTensorRewrite into PreSparsificationRewrite and PostSparsificationRewrite.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D138153
Added: (none)
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
mlir/test/Dialect/SparseTensor/rewriting.mlir
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
Removed: (none)
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index badc3d0dfd876..0961b5e000868 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -138,16 +138,25 @@ std::unique_ptr<Pass>
createSparseTensorCodegenPass(bool enableBufferInitialization);
//===----------------------------------------------------------------------===//
-// The SparseTensorRewriting pass.
+// The PreSparsificationRewriting pass.
//===----------------------------------------------------------------------===//
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
- bool enableForeach, bool enableConvert);
+void populatePreSparsificationRewriting(RewritePatternSet &patterns);
-std::unique_ptr<Pass> createSparseTensorRewritePass();
-std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
- bool enableForeach = true,
- bool enableConvert = true);
+std::unique_ptr<Pass> createPreSparsificationRewritePass();
+
+//===----------------------------------------------------------------------===//
+// The PostSparsificationRewriting pass.
+//===----------------------------------------------------------------------===//
+
+void populatePostSparsificationRewriting(RewritePatternSet &patterns,
+ bool enableRT, bool enableForeach,
+ bool enableConvert);
+
+std::unique_ptr<Pass> createPostSparsificationRewritePass();
+std::unique_ptr<Pass>
+createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
+ bool enableConvert = true);
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 74784afbee3f2..32bba3a1552e4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -11,13 +11,13 @@
include "mlir/Pass/PassBase.td"
-def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
+def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
let description = [{
A pass that applies rewriting rules to sparse tensor operations prior
to running the actual sparsification pass.
}];
- let constructor = "mlir::createSparseTensorRewritePass()";
+ let constructor = "mlir::createPreSparsificationRewritePass()";
let dependentDialects = [
"arith::ArithDialect",
"bufferization::BufferizationDialect",
@@ -26,14 +26,6 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
- let options = [
- Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
- "true", "Enable runtime library for manipulating sparse tensors">,
- Option<"enableForeach", "enable-foreach", "bool",
- "true", "Enable rewriting rules for the foreach operator">,
- Option<"enableConvert", "enable-convert", "bool",
- "true", "Enable rewriting rules for the convert operator">,
- ];
}
def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
@@ -109,6 +101,31 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
];
}
+def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
+ let summary = "Applies sparse tensor rewriting rules after sparsification";
+ let description = [{
+ A pass that applies rewriting rules to sparse tensor operations after
+ running the actual sparsification pass.
+ }];
+ let constructor = "mlir::createPostSparsificationRewritePass()";
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "bufferization::BufferizationDialect",
+ "linalg::LinalgDialect",
+ "memref::MemRefDialect",
+ "scf::SCFDialect",
+ "sparse_tensor::SparseTensorDialect",
+ ];
+ let options = [
+ Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+ "true", "Enable runtime library for manipulating sparse tensors">,
+ Option<"enableForeach", "enable-foreach", "bool",
+ "true", "Enable rewriting rules for the foreach operator">,
+ Option<"enableConvert", "enable-convert", "bool",
+ "true", "Enable rewriting rules for the convert operator">,
+ ];
+}
+
def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
let summary = "Convert sparse tensors and primitives to library calls";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index e48a760df0fa2..b816fad8c2477 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -57,8 +57,9 @@ void mlir::sparse_tensor::buildSparseCompiler(
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
if (options.testBufferizationAnalysisOnly)
return;
- pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary));
+ pm.addPass(createPreSparsificationRewritePass());
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
+ pm.addPass(createPostSparsificationRewritePass(options.enableRuntimeLibrary));
if (options.enableRuntimeLibrary) {
pm.addPass(createSparseTensorConversionPass(
options.sparseTensorConversionOptions()));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index da7c6ffda2080..d1491dfffcb7d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -21,8 +21,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
-#define GEN_PASS_DEF_SPARSETENSORREWRITE
+#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -38,22 +39,17 @@ namespace {
// Passes implementation.
//===----------------------------------------------------------------------===//
-struct SparseTensorRewritePass
- : public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
+struct PreSparsificationRewritePass
+ : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
- SparseTensorRewritePass() = default;
- SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
- SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) {
- enableRuntimeLibrary = enableRT;
- enableForeach = foreach;
- enableConvert = convert;
- }
+ PreSparsificationRewritePass() = default;
+ PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
+ default;
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
- enableConvert);
+ populatePreSparsificationRewriting(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -80,6 +76,28 @@ struct SparsificationPass
}
};
+struct PostSparsificationRewritePass
+ : public impl::PostSparsificationRewriteBase<
+ PostSparsificationRewritePass> {
+
+ PostSparsificationRewritePass() = default;
+ PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
+ default;
+ PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
+ enableRuntimeLibrary = enableRT;
+ enableForeach = foreach;
+ enableConvert = convert;
+ }
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
+ enableForeach, enableConvert);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct SparseTensorConversionPass
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
@@ -254,15 +272,8 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
// Pass creation methods.
//===----------------------------------------------------------------------===//
-std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
- return std::make_unique<SparseTensorRewritePass>();
-}
-
-std::unique_ptr<Pass> mlir::createSparseTensorRewritePass(bool enableRT,
- bool enableForeach,
- bool enableConvert) {
- return std::make_unique<SparseTensorRewritePass>(enableRT, enableForeach,
- enableConvert);
+std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
+ return std::make_unique<PreSparsificationRewritePass>();
}
std::unique_ptr<Pass> mlir::createSparsificationPass() {
@@ -274,6 +285,17 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
return std::make_unique<SparsificationPass>(options);
}
+std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
+ return std::make_unique<PostSparsificationRewritePass>();
+}
+
+std::unique_ptr<Pass>
+mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
+ bool enableConvert) {
+ return std::make_unique<PostSparsificationRewritePass>(
+ enableRT, enableForeach, enableConvert);
+}
+
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6f38796719662..29430d35f7108 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1021,11 +1021,17 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
- bool enableRT, bool enableForeach,
- bool enableConvert) {
- patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
- ReshapeRewriter<tensor::ExpandShapeOp>,
+
+void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
+ patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
+ patterns.getContext());
+}
+
+void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
+ bool enableRT,
+ bool enableForeach,
+ bool enableConvert) {
+ patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 5c9d19b1f4c7d..8336a142528bf 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index 26019866f04ff..2c5de95d775ef 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 61047500027f9..496d5941580cd 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -6,7 +6,7 @@
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector64 = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
index f142ecf7ff341..1744861e40b93 100755
--- a/mlir/test/Dialect/SparseTensor/rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
+// RUN: mlir-opt %s -post-sparsification-rewrite | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"]
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 3a6cf999df90a..94c373bab9972 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" |\
// RUN: FileCheck %s
#CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 717819bd0cb16..c2a1fd813b1e8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --sparsification | FileCheck %s
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 240a94028c93b..7571e3dbb9087 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 94ee50197fa9c..a6790458b0902 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index d2ec5cafd59d2..b55f8cbe33ae3 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s --tensor-copy-insertion --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
More information about the Mlir-commits mailing list