[Mlir-commits] [mlir] 6a93da9 - [mlir][sparse] add ReinterpretMapScopeOption for the pass (#70486)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 27 14:14:13 PDT 2023
Author: Peiming Liu
Date: 2023-10-27T14:14:09-07:00
New Revision: 6a93da99002bfd1e12ee42be79033aa0c47374aa
URL: https://github.com/llvm/llvm-project/commit/6a93da99002bfd1e12ee42be79033aa0c47374aa
DIFF: https://github.com/llvm/llvm-project/commit/6a93da99002bfd1e12ee42be79033aa0c47374aa.diff
LOG: [mlir][sparse] add ReinterpretMapScopeOption for the pass (#70486)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 835c9baa2b9173c..b1979f032393bab 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -40,6 +40,13 @@ enum class SparseParallelizationStrategy {
kAnyStorageAnyLoop
};
+/// Defines which operations the reinterpret map pass reinterprets.
+enum class ReinterpretMapScope {
+  kAll,           // reinterprets all applicable operations
+  kGenericOnly,   // reinterprets only linalg.generic
+  kExceptGeneric, // reinterprets operations other than linalg.generic
+};
+
/// Defines data movement strategy between host and device for GPU.
// TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
@@ -51,9 +58,16 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
-void populateSparseReinterpretMap(RewritePatternSet &patterns);
+/// Populates `patterns` for the reinterpret map pass; `scope` selects which
+/// operations are eligible (see ReinterpretMapScope). It defaults to `kAll`
+/// so callers of the earlier single-argument form keep compiling.
+void populateSparseReinterpretMap(
+    RewritePatternSet &patterns,
+    ReinterpretMapScope scope = ReinterpretMapScope::kAll);
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
+/// Creates the pass with its `scope` option set to the given value.
+std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c23e062ef884115..db9cc701d253c01 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -27,6 +27,19 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
"linalg::LinalgDialect",
"sparse_tensor::SparseTensorDialect",
];
+  let options = [
+    Option<"scope", "scope", "mlir::ReinterpretMapScope",
+           "mlir::ReinterpretMapScope::kAll",
+           "Set the reinterpretation scope", [{llvm::cl::values(
+             clEnumValN(mlir::ReinterpretMapScope::kAll, "all",
+                        "Run on every applicable operation."),
+             clEnumValN(mlir::ReinterpretMapScope::kGenericOnly,
+                        "only-generic",
+                        "Run only on linalg.generic operations."),
+             clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
+                        "except-generic",
+                        "Run on operations except linalg.generic (e.g., foreach)."))}]>,
+  ];
}
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 881d235de7384f7..10722ccb6eea743 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -19,4 +19,6 @@ namespace {
 } // namespace
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
+// TODO: no rewrite patterns are added yet, so `scope` is currently unused.
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
+                                        ReinterpretMapScope scope) {}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 241232f7c75cb93..095a6ab9a508eb9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,11 +49,14 @@ struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
+  explicit SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
+    scope = options.scope;
+  }
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns);
+    populateSparseReinterpretMap(patterns, scope);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -372,6 +375,13 @@ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
return std::make_unique<SparseReinterpretMap>();
}
+std::unique_ptr<Pass>
+mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
+  SparseReinterpretMapOptions passOptions;
+  passOptions.scope = scope;
+  return std::make_unique<SparseReinterpretMap>(passOptions);
+}
+
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
return std::make_unique<PreSparsificationRewritePass>();
}
More information about the Mlir-commits
mailing list