[Mlir-commits] [mlir] f82bee1 - [mlir][sparse] split post-sparsification-rewriting into two passes. (#70727)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 30 15:22:26 PDT 2023


Author: Peiming Liu
Date: 2023-10-30T15:22:21-07:00
New Revision: f82bee1367a1d612d688790b81c1c146ce99f2ea

URL: https://github.com/llvm/llvm-project/commit/f82bee1367a1d612d688790b81c1c146ce99f2ea
DIFF: https://github.com/llvm/llvm-project/commit/f82bee1367a1d612d688790b81c1c146ce99f2ea.diff

LOG: [mlir][sparse] split post-sparsification-rewriting into two passes. (#70727)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
    mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
    mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
    mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
    mlir/test/Dialect/SparseTensor/sparse_concat.mlir
    mlir/test/Dialect/SparseTensor/sparse_expand.mlir
    mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
    mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index b1979f032393bab..a8d4d752dff8882 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -114,17 +114,23 @@ void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<Pass> createStageSparseOperationsPass();
 
 //===----------------------------------------------------------------------===//
-// The PostSparsificationRewriting pass.
+// The LowerSparseOpsToForeach pass.
 //===----------------------------------------------------------------------===//
 
-void populatePostSparsificationRewriting(RewritePatternSet &patterns,
-                                         bool enableRT, bool enableForeach,
-                                         bool enableConvert);
+void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
+                                             bool enableRT, bool enableConvert);
 
-std::unique_ptr<Pass> createPostSparsificationRewritePass();
-std::unique_ptr<Pass>
-createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
-                                    bool enableConvert = true);
+std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
+std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
+                                                        bool enableConvert);
+
+//===----------------------------------------------------------------------===//
+// The LowerForeachToSCF pass.
+//===----------------------------------------------------------------------===//
+
+void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 99dbd7ab3677e74..995e842289035b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -167,13 +167,12 @@ def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
   ];
 }
 
-def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
+def LowerSparseOpsToForeach : Pass<"lower-sparse-ops-to-foreach", "ModuleOp"> {
-  let summary = "Applies sparse tensor rewriting rules after sparsification";
+  let summary = "Lower high-level sparse operations to sparse_tensor.foreach";
   let description = [{
-    A pass that applies rewriting rules to sparse tensor operations after
-    running the actual sparsification pass.
+    A pass that lowers high-level sparse operations to sparse_tensor.foreach.
   }];
-  let constructor = "mlir::createPostSparsificationRewritePass()";
+  let constructor = "mlir::createLowerSparseOpsToForeachPass()";
   let dependentDialects = [
     "affine::AffineDialect",
     "arith::ArithDialect",
@@ -186,13 +185,25 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
   let options = [
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
            "true", "Enable runtime library for manipulating sparse tensors">,
-    Option<"enableForeach", "enable-foreach", "bool",
-           "true", "Enable rewriting rules for the foreach operator">,
     Option<"enableConvert", "enable-convert", "bool",
            "true", "Enable rewriting rules for the convert operator">,
   ];
 }
 
+def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
+  let summary = "Lower sparse_tensor.foreach operations to the SCF dialect";
+  let description = [{
+    A pass that lowers sparse_tensor.foreach operations to the SCF dialect.
+  }];
+  let constructor = "mlir::createLowerForeachToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
+
 def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 095a6ab9a508eb9..c5fd19a811d6bb0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -25,7 +25,8 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
-#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
+#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
+#define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -120,23 +121,34 @@ struct StageSparseOperationsPass
   }
 };
 
-struct PostSparsificationRewritePass
-    : public impl::PostSparsificationRewriteBase<
-          PostSparsificationRewritePass> {
-  PostSparsificationRewritePass() = default;
-  PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
+struct LowerSparseOpsToForeachPass
+    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
+  LowerSparseOpsToForeachPass() = default;
+  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
       default;
-  PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
+  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
     enableRuntimeLibrary = enableRT;
-    enableForeach = foreach;
     enableConvert = convert;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
-                                        enableForeach, enableConvert);
+    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
+                                            enableConvert);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+struct LowerForeachToSCFPass
+    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
+  LowerForeachToSCFPass() = default;
+  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateLowerForeachToSCFPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -399,15 +411,17 @@ std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
   return std::make_unique<StageSparseOperationsPass>();
 }
 
-std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
-  return std::make_unique<PostSparsificationRewritePass>();
+std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
+  return std::make_unique<LowerSparseOpsToForeachPass>();
 }
 
 std::unique_ptr<Pass>
-mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
-                                          bool enableConvert) {
-  return std::make_unique<PostSparsificationRewritePass>(
-      enableRT, enableForeach, enableConvert);
+mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
+  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
+}
+
+std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
+  return std::make_unique<LowerForeachToSCFPass>();
 }
 
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index e9bcb5dc070ade9..528e70bd3b1ef5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1303,10 +1303,9 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
                GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
 }
 
-void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
-                                               bool enableRT,
-                                               bool enableForeach,
-                                               bool enableConvert) {
+void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
+                                                   bool enableRT,
+                                                   bool enableConvert) {
   patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
                ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
@@ -1314,10 +1313,13 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
                SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
       patterns.getContext());
-  if (enableForeach)
-    patterns.add<ForeachRewriter>(patterns.getContext());
+
   if (enableConvert)
     patterns.add<DirectConvertRewriter>(patterns.getContext());
   if (!enableRT)
     patterns.add<NewRewriter>(patterns.getContext());
 }
+
+void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
+  patterns.add<ForeachRewriter>(patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index d8a24ea3527b199..f3f3828e0c5bdff 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -141,7 +141,10 @@ class SparsificationAndBufferizationPass
       OpPassManager pm("builtin.module");
       pm.addPass(createSparsificationPass(sparsificationOptions));
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
-      pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
+      pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
+                                                   /*enableConvert=*/true));
+      // TODO: DemapPass here!
+      pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
       if (vectorLength > 0) {
         pm.addPass(mlir::createLoopInvariantCodeMotionPass());
         pm.addPass(createSparseVectorizationPass(

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 8993333d6e5333d..c53ec7408bc3b8a 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 092ba6b8358b598..27d8f296c9ad0ce 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)

diff  --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 4dba16df39f5c65..4f37ae9207be9cc 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)

diff  --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index e2dcb068e11851e..730a5452df39449 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)

diff  --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 0280e27b4e312a0..896bc02212971f0 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector64 = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed),

diff  --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 1e72f059baec294..93e802bc6065e42 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
-// RUN: FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed)

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index f3d3dd28563e891..e4e2748112d78c4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" --lower-sparse-foreach-to-scf \
 // RUN: | FileCheck %s
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" --lower-sparse-foreach-to-scf \
 // RUN: | FileCheck %s
 
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index 3ee6e84a2382a9e..0f367f12483f63a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -4,7 +4,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-SPARSE
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
-// RUN:             --sparsification --post-sparsification-rewrite \
+// RUN:             --sparsification --lower-sparse-ops-to-foreach \
+// RUN:             --lower-sparse-foreach-to-scf \
 // RUN:             --sparse-tensor-conversion --cse | \
 // RUN:   FileCheck %s --check-prefix=CHECK-CONVERT
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index bbce42c100641ab..5983289c752efc1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-foreach-to-scf --canonicalize | FileCheck %s
 
 // CHECK-LABEL: func.func @sparse_foreach_constant
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 9af998be2f68297..80cfa3c635f3613 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --canonicalize --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
 
 #COO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 4f105f3e19b3e75..d3d6d8c91fa45a1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
-// RUN: --cse --canonicalize  | FileCheck %s
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
-// RUN: --cse --canonicalize  | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize  | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize  | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
index a1578eb20b8ba3b..339d65ce5716fab 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
-// RUN: --cse --canonicalize  | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize  | FileCheck %s
 
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 


        


More information about the Mlir-commits mailing list