[Mlir-commits] [mlir] [mlir][sparse] add ReinterpretMapScopeOption for the pass (PR #70486)
Peiming Liu
llvmlistbot at llvm.org
Fri Oct 27 12:50:08 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70486
>From 8ab7b7487bff12fdf0d747192b214778ee105fcd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 18:01:16 +0000
Subject: [PATCH 1/2] [mlir][sparse] add ReinterpretMapScopeOption for the pass
---
.../mlir/Dialect/SparseTensor/Transforms/.#Passes.h | 1 +
.../mlir/Dialect/SparseTensor/Transforms/Passes.h | 12 +++++++++++-
.../mlir/Dialect/SparseTensor/Transforms/Passes.td | 13 +++++++++++++
.../Transforms/SparseReinterpretMap.cpp | 3 ++-
.../SparseTensor/Transforms/SparseTensorPasses.cpp | 12 +++++++++++-
5 files changed, 38 insertions(+), 3 deletions(-)
create mode 120000 mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
new file mode 120000
index 000000000000000..4d45164b3318ba4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
@@ -0,0 +1 @@
+peiming at peiming.c.googlers.com.13658:1697570271
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 835c9baa2b9173c..fff0615aa4cdd49 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -40,6 +40,14 @@ enum class SparseParallelizationStrategy {
kAnyStorageAnyLoop
};
+/// Define a scope for reinterpret map pass.
+enum class ReinterpretMapScope {
+ kAll, // reinterpret all applicable operations.
+ kGenericOnly, // reinterpret only linalg.generic.
+ kExceptGeneric, // reinterpret operation other than linalg.generic (e.g.,
+ // foreach)
+};
+
/// Defines data movement strategy between host and device for GPU.
// TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
@@ -51,9 +59,14 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
// The SparseReinterpretMap pass.
//===----------------------------------------------------------------------===//
-void populateSparseReinterpretMap(RewritePatternSet &patterns);
+/// Rewrite-pattern population entry point for the reinterpret-map pass;
+/// `scope` restricts which operations are reinterpreted (see ReinterpretMapScope).
+void populateSparseReinterpretMap(RewritePatternSet &patterns,
+ ReinterpretMapScope scope);
std::unique_ptr<Pass> createSparseReinterpretMapPass();
+/// Creates the reinterpret-map pass with its `scope` option set to `scope`.
+std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
//===----------------------------------------------------------------------===//
// The PreSparsificationRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c23e062ef884115..db9cc701d253c01 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -27,6 +27,19 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
"linalg::LinalgDialect",
"sparse_tensor::SparseTensorDialect",
];
+ let options = [
+ Option<"scope", "scope", "mlir::ReinterpretMapScope",
+ "mlir::ReinterpretMapScope::kAll",
+ "Set the reinterpretation scope", [{llvm::cl::values(
+ clEnumValN(mlir::ReinterpretMapScope::kAll, "all",
+ "Run on every applicable operation."),
+ clEnumValN(mlir::ReinterpretMapScope::kGenericOnly,
+ "only-generic",
+ "Run only on linalg.generic operations."),
+ clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
+ "except-generic",
+ "Run on operations except linalg.generic (e.g., foreach)"))}]>,
+ ];
}
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 881d235de7384f7..10722ccb6eea743 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -19,4 +19,7 @@ namespace {
} // namespace
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
+ ReinterpretMapScope scope) {
+ // TODO: no rewrite patterns are registered yet, so `scope` is not consulted.
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 241232f7c75cb93..095a6ab9a508eb9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,11 +49,15 @@ struct SparseReinterpretMap
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
SparseReinterpretMap() = default;
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
+ /// Builds the pass with its `scope` option copied from `options`.
+ explicit SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
+ scope = options.scope;
+ }
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseReinterpretMap(patterns);
+ populateSparseReinterpretMap(patterns, scope);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -372,6 +376,14 @@ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
return std::make_unique<SparseReinterpretMap>();
}
+/// Creates the reinterpret-map pass with its `scope` option set to `scope`.
+std::unique_ptr<Pass>
+mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
+ SparseReinterpretMapOptions options;
+ options.scope = scope;
+ return std::make_unique<SparseReinterpretMap>(options);
+}
+
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
return std::make_unique<PreSparsificationRewritePass>();
}
>From ce2ce190495aa045f41a8f443234972b79e0ac74 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 19:49:53 +0000
Subject: [PATCH 2/2] address comments.
---
.../mlir/Dialect/SparseTensor/Transforms/.#Passes.h | 1 -
.../include/mlir/Dialect/SparseTensor/Transforms/Passes.h | 8 ++++----
2 files changed, 4 insertions(+), 5 deletions(-)
delete mode 120000 mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
deleted file mode 120000
index 4d45164b3318ba4..000000000000000
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/.#Passes.h
+++ /dev/null
@@ -1 +0,0 @@
-peiming at peiming.c.googlers.com.13658:1697570271
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index fff0615aa4cdd49..4b7eba591c4a4ee 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -40,11 +40,11 @@ enum class SparseParallelizationStrategy {
kAnyStorageAnyLoop
};
-/// Define a scope for reinterpret map pass.
+/// Defines a scope for the reinterpret map pass.
enum class ReinterpretMapScope {
- kAll, // reinterpret all applicable operations.
- kGenericOnly, // reinterpret only linalg.generic.
- kExceptGeneric, // reinterpret operation other than linalg.generic (e.g.,
+ kAll, // reinterprets all applicable operations
+ kGenericOnly, // reinterprets only linalg.generic
+ kExceptGeneric, // reinterprets operations other than linalg.generic (e.g.,
// foreach)
};
More information about the Mlir-commits
mailing list