[Mlir-commits] [mlir] c1864ab - [mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for operators foreach and convert.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 18 09:27:37 PDT 2022
Author: bixia1
Date: 2022-10-18T09:27:32-07:00
New Revision: c1864ab9534080ee77a016bca24dd9a318bc6d7e
URL: https://github.com/llvm/llvm-project/commit/c1864ab9534080ee77a016bca24dd9a318bc6d7e
DIFF: https://github.com/llvm/llvm-project/commit/c1864ab9534080ee77a016bca24dd9a318bc6d7e.diff
LOG: [mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for operators foreach and convert.
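This is to help simplify FileCheck tests for sparse-tensor-rewrite. Note: the enable-convert option is plumbed through the pass and populateSparseTensorRewriting, but the latter does not consult it yet (its parameter is unnamed), so disabling it currently has no effect.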
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136093
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index fd99e4f57af0e..2230f433e8526 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -158,14 +158,21 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
std::unique_ptr<Pass> createSparseTensorCodegenPass();
//===----------------------------------------------------------------------===//
-// Other rewriting rules and passes.
+// The SparseTensorRewriting pass.
//===----------------------------------------------------------------------===//
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
+void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
+ bool enableForeach, bool enableConvert);
std::unique_ptr<Pass> createSparseTensorRewritePass();
std::unique_ptr<Pass>
-createSparseTensorRewritePass(const SparsificationOptions &options);
+createSparseTensorRewritePass(const SparsificationOptions &options,
+ bool enableForeach = true,
+ bool enableConvert = true);
+
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
std::unique_ptr<Pass> createDenseBufferizationPass(
const bufferization::OneShotBufferizationOptions &options);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 26c78aea50a81..eee33b08af7eb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -28,7 +28,11 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
];
let options = [
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
- "true", "Enable runtime library for manipulating sparse tensors">
+ "true", "Enable runtime library for manipulating sparse tensors">,
+ Option<"enableForeach", "enable-foreach", "bool",
+ "true", "Enable rewriting rules for the foreach operator">,
+ Option<"enableConvert", "enable-convert", "bool",
+ "true", "Enable rewriting rules for the convert operator">,
];
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 5ea55c822634b..b524ac1f0796c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -43,14 +43,18 @@ struct SparseTensorRewritePass
SparseTensorRewritePass() = default;
SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
- SparseTensorRewritePass(const SparsificationOptions &options) {
+ SparseTensorRewritePass(const SparsificationOptions &options, bool foreach,
+ bool convert) {
enableRuntimeLibrary = options.enableRuntimeLibrary;
+ enableForeach = foreach;
+ enableConvert = convert;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
+ populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
+ enableConvert);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -255,8 +259,10 @@ std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
}
std::unique_ptr<Pass>
-mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
- return std::make_unique<SparseTensorRewritePass>(options);
+mlir::createSparseTensorRewritePass(const SparsificationOptions &options,
+ bool enableForeach, bool enableConvert) {
+ return std::make_unique<SparseTensorRewritePass>(options, enableForeach,
+ enableConvert);
}
std::unique_ptr<Pass> mlir::createSparsificationPass() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 26548871be508..36a564be934ae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -612,11 +612,14 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
- bool enableRT) {
+ bool enableRT, bool enableForeach,
+ bool /*enableConvert*/) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
- patterns.getContext());
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+ if (enableForeach)
+ patterns.add<ForeachRewriter>(patterns.getContext());
+
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT)
patterns.add<ConcatenateRewriter, NewRewriter,
More information about the Mlir-commits mailing list