[Mlir-commits] [mlir] [mlir][sparse] add ReinterpretMapScopeOption for the pass (PR #70486)

Peiming Liu llvmlistbot at llvm.org
Fri Oct 27 11:02:18 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70486

None

>From 8ab7b7487bff12fdf0d747192b214778ee105fcd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 18:01:16 +0000
Subject: [PATCH] [mlir][sparse] add ReinterpretMapScopeOption for the pass

---
 .../mlir/Dialect/SparseTensor/Transforms/.#Passes.h |  1 +
 .../mlir/Dialect/SparseTensor/Transforms/Passes.h   | 12 +++++++++++-
 .../mlir/Dialect/SparseTensor/Transforms/Passes.td  | 13 +++++++++++++
 .../Transforms/SparseReinterpretMap.cpp             |  3 ++-
 .../SparseTensor/Transforms/SparseTensorPasses.cpp  | 12 +++++++++++-
 5 files changed, 38 insertions(+), 3 deletions(-)
 create mode 120000 mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
new file mode 120000
index 000000000000000..4d45164b3318ba4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
@@ -0,0 +1 @@
+peiming at peiming.c.googlers.com.13658:1697570271
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 835c9baa2b9173c..fff0615aa4cdd49 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -40,6 +40,14 @@ enum class SparseParallelizationStrategy {
   kAnyStorageAnyLoop
 };
 
+/// Defines which operations the reinterpret map pass applies to.
+enum class ReinterpretMapScope {
+  kAll,           // reinterpret all applicable operations.
+  kGenericOnly,   // reinterpret only linalg.generic operations.
+  kExceptGeneric, // reinterpret operations other than linalg.generic (e.g.,
+                  // foreach).
+};
+
 /// Defines data movement strategy between host and device for GPU.
 // TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
 enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
@@ -51,9 +59,11 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseReinterpretMap(RewritePatternSet &patterns);
+void populateSparseReinterpretMap(RewritePatternSet &patterns,
+                                  ReinterpretMapScope scope);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
+std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c23e062ef884115..db9cc701d253c01 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -27,6 +27,19 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
     "linalg::LinalgDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
+  let options = [
+    Option<"scope", "scope", "mlir::ReinterpretMapScope",
+       "mlir::ReinterpretMapScope::kAll",
+       "Set the reinterpretation scope", [{llvm::cl::values(
+         clEnumValN(mlir::ReinterpretMapScope::kAll, "all",
+                    "Run on every applicable operation."),
+         clEnumValN(mlir::ReinterpretMapScope::kGenericOnly,
+                    "only-generic",
+                    "Run only on linalg.generic operations."),
+         clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
+                    "except-generic",
+                    "Run on operations except linalg.generic (e.g., foreach)"))}]>,
+  ];
 }
 
 def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 881d235de7384f7..10722ccb6eea743 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -19,4 +19,5 @@ namespace {
 
 } // namespace
 
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
+                                        ReinterpretMapScope scope) {}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 241232f7c75cb93..095a6ab9a508eb9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,11 +49,14 @@ struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
+  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
+    scope = options.scope;
+  }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns);
+    populateSparseReinterpretMap(patterns, scope);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -372,6 +375,13 @@ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
   return std::make_unique<SparseReinterpretMap>();
 }
 
+std::unique_ptr<Pass>
+mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
+  SparseReinterpretMapOptions options;
+  options.scope = scope;
+  return std::make_unique<SparseReinterpretMap>(options);
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }



More information about the Mlir-commits mailing list