[Mlir-commits] [mlir] [mlir][sparse] add boilerplate code for a new reinterpret map pass (PR #70393)
Aart Bik
llvmlistbot at llvm.org
Thu Oct 26 16:39:13 PDT 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/70393
The interesting stuff is of course still coming ;-)
>From af80079360392a329241c929b3859d7f51d1564e Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 26 Oct 2023 16:21:24 -0700
Subject: [PATCH] [mlir][sparse] add boilerplate code for a new reinterpret map
pass
The interesting stuff is of course still coming ;-)
---
.../Dialect/SparseTensor/Transforms/Passes.h | 8 ++++++
.../Dialect/SparseTensor/Transforms/Passes.td | 18 +++++++++++++
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseReinterpretMap.cpp | 22 +++++++++++++++
.../Transforms/SparseTensorPasses.cpp | 27 ++++++++++++-------
.../SparseTensor/sparse_reinterpret_map.mlir | 10 +++++++
6 files changed, 77 insertions(+), 9 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
create mode 100644 mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 204bc1ec2def1bb..835c9baa2b9173c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,6 +47,14 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
#define GEN_PASS_DECL
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+//===----------------------------------------------------------------------===//
+// The SparseReinterpretMap pass.
+//===----------------------------------------------------------------------===//
+/// Populates `patterns` with the rewrite rules of the SparseReinterpretMap pass.
+void populateSparseReinterpretMap(RewritePatternSet &patterns);
+/// Creates a pass that reinterprets the mappings of sparse tensor types.
+std::unique_ptr<Pass> createSparseReinterpretMapPass();
+
//===----------------------------------------------------------------------===//
// The PreSparsificationRewriting pass.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 73ecf5061fa16ca..c23e062ef884115 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -11,6 +11,24 @@
include "mlir/Pass/PassBase.td"
+def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
+  let summary = "Reinterprets sparse tensor type mappings";
+  let description = [{
+    A pass that reinterprets the mappings in all sparse tensor types in a way that
+    enables subsequent sparsification. This involves expressing all `linalg.generic`
+    operations in terms of level coordinates (rather than the dimension coordinates
+    of the input tensors) to align the iteration space with the potentially remapped
+    level space, and resolving cycles in the resulting iteration graphs with
+    explicit sparse tensor conversions where needed.
+  }];
+  let constructor = "mlir::createSparseReinterpretMapPass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "linalg::LinalgDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 0ca6668c8c74745..b8a2ff26b6794f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
LoopEmitter.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
+ SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
new file mode 100644
index 000000000000000..881d235de7384f7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -0,0 +1,22 @@
+//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace {
+
+// TODO: the pass still needs to
+// (1) insert the zero-cost sparse_tensor.reinterpret_map ops,
+// (2) rewrite linalg.generic op traits in terms of level coordinates, and
+// (3) compute a topological sort, resolving cycles with sparse_tensor.convert.
+} // namespace
+
+// No patterns are registered yet; `patterns` is left untouched.
+void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index eaf15ff29dd721b..241232f7c75cb93 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
+#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
@@ -44,9 +45,21 @@ namespace {
// Passes implementation.
//===----------------------------------------------------------------------===//
+struct SparseReinterpretMap
+    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
+  SparseReinterpretMap() = default;
+  SparseReinterpretMap(const SparseReinterpretMap &) = default;
+
+  void runOnOperation() override {
+    // Gather the reinterpretation rewrite rules and apply them to the module.
+    RewritePatternSet rewrites(&getContext());
+    populateSparseReinterpretMap(rewrites);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rewrites));
+  }
+};
+
struct PreSparsificationRewritePass
: public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
-
PreSparsificationRewritePass() = default;
PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
default;
@@ -61,7 +74,6 @@ struct PreSparsificationRewritePass
struct SparsificationPass
: public impl::SparsificationPassBase<SparsificationPass> {
-
SparsificationPass() = default;
SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass(const SparsificationOptions &options) {
@@ -108,7 +120,6 @@ struct StageSparseOperationsPass
struct PostSparsificationRewritePass
: public impl::PostSparsificationRewriteBase<
PostSparsificationRewritePass> {
-
PostSparsificationRewritePass() = default;
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
default;
@@ -129,7 +140,6 @@ struct PostSparsificationRewritePass
struct SparseTensorConversionPass
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
-
SparseTensorConversionPass() = default;
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
@@ -200,7 +210,6 @@ struct SparseTensorConversionPass
struct SparseTensorCodegenPass
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
-
SparseTensorCodegenPass() = default;
SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
@@ -266,7 +275,6 @@ struct SparseTensorCodegenPass
struct SparseBufferRewritePass
: public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
-
SparseBufferRewritePass() = default;
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
SparseBufferRewritePass(bool enableInit) {
@@ -283,7 +291,6 @@ struct SparseBufferRewritePass
struct SparseVectorizationPass
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
-
SparseVectorizationPass() = default;
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
@@ -306,7 +313,6 @@ struct SparseVectorizationPass
struct SparseGPUCodegenPass
: public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
-
SparseGPUCodegenPass() = default;
SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }
@@ -321,7 +327,6 @@ struct SparseGPUCodegenPass
struct StorageSpecifierToLLVMPass
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
-
StorageSpecifierToLLVMPass() = default;
void runOnOperation() override {
@@ -363,6 +368,10 @@ struct StorageSpecifierToLLVMPass
// Pass creation methods.
//===----------------------------------------------------------------------===//
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
+  return std::make_unique<SparseReinterpretMap>(); // Passes.td constructor.
+}
+
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
return std::make_unique<PreSparsificationRewritePass>();
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
new file mode 100644
index 000000000000000..8517f2a27ae3fc8
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
+#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+
+// Identity-mapped sparse vector: the pass must leave `return %arg0` intact.
+// CHECK-LABEL: func @sparse_nop(
+// CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: return %[[A0]]
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}
More information about the Mlir-commits
mailing list