[Mlir-commits] [mlir] 6a93da9 - [mlir][sparse] add ReinterpretMapScopeOption for the pass (#70486)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 27 14:14:13 PDT 2023


Author: Peiming Liu
Date: 2023-10-27T14:14:09-07:00
New Revision: 6a93da99002bfd1e12ee42be79033aa0c47374aa

URL: https://github.com/llvm/llvm-project/commit/6a93da99002bfd1e12ee42be79033aa0c47374aa
DIFF: https://github.com/llvm/llvm-project/commit/6a93da99002bfd1e12ee42be79033aa0c47374aa.diff

LOG: [mlir][sparse] add ReinterpretMapScopeOption for the pass (#70486)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 835c9baa2b9173c..b1979f032393bab 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -40,6 +40,13 @@ enum class SparseParallelizationStrategy {
   kAnyStorageAnyLoop
 };
 
+/// Selects which operations the sparse reinterpret-map pass applies to.
+enum class ReinterpretMapScope {
+  kAll,           // reinterprets all applicable operations
+  kGenericOnly,   // reinterprets only linalg.generic
+  kExceptGeneric, // reinterprets operations other than linalg.generic
+};
+
 /// Defines data movement strategy between host and device for GPU.
 // TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
 enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
@@ -51,9 +58,11 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseReinterpretMap(RewritePatternSet &patterns);
+void populateSparseReinterpretMap(RewritePatternSet &patterns,
+                                  ReinterpretMapScope scope);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
+std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c23e062ef884115..db9cc701d253c01 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -27,6 +27,19 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
     "linalg::LinalgDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"scope", "scope", "mlir::ReinterpretMapScope",
+       "mlir::ReinterpretMapScope::kAll",
+       "Set the reinterpretation scope", [{llvm::cl::values(
+         clEnumValN(mlir::ReinterpretMapScope::kAll, "all",
+                    "Run on every applicable operation."),
+         clEnumValN(mlir::ReinterpretMapScope::kGenericOnly,
+                    "only-generic",
+                    "Run only on linalg.generic operations."),
+         clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
+                    "except-generic",
+                    "Run on operations except linalg.generic (e.g., foreach)"))}]>,
+  ];
 }
 
 def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 881d235de7384f7..10722ccb6eea743 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -19,4 +19,7 @@ namespace {
 
 } // namespace
 
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
+                                        ReinterpretMapScope scope) {
+  // No rewrite patterns are registered yet; `scope` is currently unused.
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 241232f7c75cb93..095a6ab9a508eb9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,11 +49,14 @@ struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
+  explicit SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
+    scope = options.scope;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns);
+    populateSparseReinterpretMap(patterns, scope);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -372,6 +375,13 @@ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
   return std::make_unique<SparseReinterpretMap>();
 }
 
+std::unique_ptr<Pass>
+mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
+  SparseReinterpretMapOptions options;
+  options.scope = scope;
+  return std::make_unique<SparseReinterpretMap>(options);
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }


        


More information about the Mlir-commits mailing list