[Mlir-commits] [mlir] [mlir][sparse] add boilerplate code for a new reinterpret map pass (PR #70393)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 16:40:44 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

The interesting stuff is of course still coming ;-)

---
Full diff: https://github.com/llvm/llvm-project/pull/70393.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+8) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+18) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+22) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+18-9) 
- (added) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 204bc1ec2def1bb..835c9baa2b9173c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,6 +47,14 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
+//===----------------------------------------------------------------------===//
+// The SparseReinterpretMap pass.
+//===----------------------------------------------------------------------===//
+
+void populateSparseReinterpretMap(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseReinterpretMapPass();
+
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 73ecf5061fa16ca..c23e062ef884115 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -11,6 +11,24 @@
 
 include "mlir/Pass/PassBase.td"
 
+def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
+  let summary = "Reinterprets sparse tensor type mappings";
+  let description = [{
+    A pass that reinterprets the mappings in all sparse tensor types in a way that
+    enables subsequent sparsification. This involves expressing all `linalg.generic`
+    operations in terms of level coordinates (rather than the dimension coordinates
+    of the input tensors) to align the iteration space with the potentially remapped
+    level space as well as resolving cycles in the resulting iteration graphs with
+    explicit sparse tensor conversions where needed.
+  }];
+  let constructor = "mlir::createSparseReinterpretMapPass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "linalg::LinalgDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules prior to sparsification";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 0ca6668c8c74745..b8a2ff26b6794f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   LoopEmitter.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
new file mode 100644
index 000000000000000..881d235de7384f7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -0,0 +1,22 @@
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace {
+
+// TODO:
+//   (1) insert the zero-cost sparse_tensor.reinterpret_map ops
+//   (2) rewrite linalg.generic ops traits on level crds
+//   (3) compute topsort, and resolve cycles with sparse_tensor.convert ops
+
+} // namespace
+
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index eaf15ff29dd721b..241232f7c75cb93 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
+#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
@@ -44,9 +45,21 @@ namespace {
 // Passes implementation.
 //===----------------------------------------------------------------------===//
 
+struct SparseReinterpretMap
+    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
+  SparseReinterpretMap() = default;
+  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseReinterpretMap(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct PreSparsificationRewritePass
     : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
-
   PreSparsificationRewritePass() = default;
   PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
       default;
@@ -61,7 +74,6 @@ struct PreSparsificationRewritePass
 
 struct SparsificationPass
     : public impl::SparsificationPassBase<SparsificationPass> {
-
   SparsificationPass() = default;
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
@@ -108,7 +120,6 @@ struct StageSparseOperationsPass
 struct PostSparsificationRewritePass
     : public impl::PostSparsificationRewriteBase<
           PostSparsificationRewritePass> {
-
   PostSparsificationRewritePass() = default;
   PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
       default;
@@ -129,7 +140,6 @@ struct PostSparsificationRewritePass
 
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
-
   SparseTensorConversionPass() = default;
   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
 
@@ -200,7 +210,6 @@ struct SparseTensorConversionPass
 
 struct SparseTensorCodegenPass
     : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
-
   SparseTensorCodegenPass() = default;
   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
   SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
@@ -266,7 +275,6 @@ struct SparseTensorCodegenPass
 
 struct SparseBufferRewritePass
     : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
-
   SparseBufferRewritePass() = default;
   SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
   SparseBufferRewritePass(bool enableInit) {
@@ -283,7 +291,6 @@ struct SparseBufferRewritePass
 
 struct SparseVectorizationPass
     : public impl::SparseVectorizationBase<SparseVectorizationPass> {
-
   SparseVectorizationPass() = default;
   SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
   SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
@@ -306,7 +313,6 @@ struct SparseVectorizationPass
 
 struct SparseGPUCodegenPass
     : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
-
   SparseGPUCodegenPass() = default;
   SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
   SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }
@@ -321,7 +327,6 @@ struct SparseGPUCodegenPass
 
 struct StorageSpecifierToLLVMPass
     : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
-
   StorageSpecifierToLLVMPass() = default;
 
   void runOnOperation() override {
@@ -363,6 +368,10 @@ struct StorageSpecifierToLLVMPass
 // Pass creation methods.
 //===----------------------------------------------------------------------===//
 
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
+  return std::make_unique<SparseReinterpretMap>();
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
new file mode 100644
index 000000000000000..8517f2a27ae3fc8
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
+//       CHECK: return %[[A0]]
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/70393


More information about the Mlir-commits mailing list