[Mlir-commits] [mlir] c1864ab - [mlir][sparse] Add options to sparse-tensor-rewrite to disable rewriting rules for operators foreach and convert.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 18 09:27:37 PDT 2022


Author: bixia1
Date: 2022-10-18T09:27:32-07:00
New Revision: c1864ab9534080ee77a016bca24dd9a318bc6d7e

URL: https://github.com/llvm/llvm-project/commit/c1864ab9534080ee77a016bca24dd9a318bc6d7e
DIFF: https://github.com/llvm/llvm-project/commit/c1864ab9534080ee77a016bca24dd9a318bc6d7e.diff

LOG: [mlir][sparse] Add options to sparse-tensor-rewrite to disable the rewriting rules for the foreach and convert operators.

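This helps simplify the FileCheck tests for sparse-tensor-rewrite. Note that the enable-convert option is currently only plumbed through to populateSparseTensorRewriting, which does not yet consult it (the parameter is unused there).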

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136093

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index fd99e4f57af0e..2230f433e8526 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -158,14 +158,21 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
 //===----------------------------------------------------------------------===//
-// Other rewriting rules and passes.
+// The SparseTensorRewriting pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
+void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
+                                   bool enableForeach, bool enableConvert);
 
 std::unique_ptr<Pass> createSparseTensorRewritePass();
 std::unique_ptr<Pass>
-createSparseTensorRewritePass(const SparsificationOptions &options);
+createSparseTensorRewritePass(const SparsificationOptions &options,
+                              bool enableForeach = true,
+                              bool enableConvert = true);
+
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
 
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 26c78aea50a81..eee33b08af7eb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -28,7 +28,11 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
   ];
   let options = [
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
-           "true", "Enable runtime library for manipulating sparse tensors">
+           "true", "Enable runtime library for manipulating sparse tensors">,
+    Option<"enableForeach", "enable-foreach", "bool",
+           "true", "Enable rewriting rules for the foreach operator">,
+    Option<"enableConvert", "enable-convert", "bool",
+           "true", "Enable rewriting rules for the convert operator">,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 5ea55c822634b..b524ac1f0796c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -43,14 +43,18 @@ struct SparseTensorRewritePass
 
   SparseTensorRewritePass() = default;
   SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
-  SparseTensorRewritePass(const SparsificationOptions &options) {
+  SparseTensorRewritePass(const SparsificationOptions &options, bool foreach,
+                          bool convert) {
     enableRuntimeLibrary = options.enableRuntimeLibrary;
+    enableForeach = foreach;
+    enableConvert = convert;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseTensorRewriting(patterns, enableRuntimeLibrary);
+    populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
+                                  enableConvert);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -255,8 +259,10 @@ std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
 }
 
 std::unique_ptr<Pass>
-mlir::createSparseTensorRewritePass(const SparsificationOptions &options) {
-  return std::make_unique<SparseTensorRewritePass>(options);
+mlir::createSparseTensorRewritePass(const SparsificationOptions &options,
+                                    bool enableForeach, bool enableConvert) {
+  return std::make_unique<SparseTensorRewritePass>(options, enableForeach,
+                                                   enableConvert);
 }
 
 std::unique_ptr<Pass> mlir::createSparsificationPass() {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 26548871be508..36a564be934ae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -612,11 +612,14 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool enableRT) {
+                                         bool enableRT, bool enableForeach,
+                                         bool /*enableConvert*/) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
                ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
-      patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  if (enableForeach)
+    patterns.add<ForeachRewriter>(patterns.getContext());
+
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
     patterns.add<ConcatenateRewriter, NewRewriter,


        


More information about the Mlir-commits mailing list