[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)
Aart Bik
llvmlistbot at llvm.org
Wed Feb 28 12:14:18 PST 2024
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/83321
>From 30c0a8fce6bb9ae325f6bb6c93441b3b14f700ea Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 28 Feb 2024 10:37:41 -0800
Subject: [PATCH 1/2] [mlir][sparse] add a sparse_tensor.print operation
This operation is mainly used for testing and debugging
purposes but provides a very convenient way to quickly
inspect the contents of a sparse tensor (all components
over all stored levels).
Example:
[ [ 1, 0, 2, 0, 0, 0, 0, 0 ],
[ 0, 0, 0, 0, 0, 0, 0, 0 ],
[ 0, 0, 0, 0, 0, 0, 0, 0 ],
[ 0, 0, 3, 4, 0, 5, 0, 0 ] ]
when stored sparse as DCSC prints as
---- Sparse Tensor ----
nse = 5
pos[0] : ( 0, 4, )
crd[0] : ( 0, 2, 3, 5, )
pos[1] : ( 0, 1, 3, 4, 5, )
crd[1] : ( 0, 0, 3, 3, 3, )
values : ( 1, 2, 3, 4, 5, )
----
---
.../SparseTensor/IR/SparseTensorOps.td | 22 ++
.../Transforms/SparseTensorRewriting.cpp | 95 +++++++-
.../SparseTensor/CPU/sparse_print.mlir | 215 ++++++++++++++++++
3 files changed, 331 insertions(+), 1 deletion(-)
create mode 100755 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3127cf1b1bcf69..9007e4e98e3163 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1453,4 +1453,26 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Debugging Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_PrintOp : SparseTensor_Op<"print">,
+ Arguments<(ins AnySparseTensor:$tensor)> {
+ string summary = "Prints a sparse tensor (for testing and debugging)";
+ string description = [{
+ Prints the individual components of a sparse tensor (the positions,
+ coordinates, and values components) to stdout for testing and debugging
+ purposes. This operation lowers to just a few primitives in a light-weight
+ runtime support to simplify supporting this operation on new platforms.
+
+ Example:
+
+ ```mlir
+ sparse_tensor.print %tensor : tensor<1024x1024xf64, #CSR>
+ ```
+ }];
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
#endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bcc131781d34d..fa078e2c2efd75 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -21,9 +21,11 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
}
};
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrintOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto tensor = op.getTensor();
+ auto stt = getSparseTensorType(tensor);
+ // Header with NSE.
+ auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+ rewriter.create<vector::PrintOp>(
+ loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+ rewriter.create<vector::PrintOp>(loc, nse);
+ // Use the "codegen" foreach loop construct to iterate over
+ // all typical sparse tensor components for printing.
+ foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+ &tensor](Type tp, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level l, LevelType) {
+ switch (kind) {
+ case SparseTensorFieldKind::StorageSpec: {
+ break;
+ }
+ case SparseTensorFieldKind::PosMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, poss);
+ break;
+ }
+ case SparseTensorFieldKind::CrdMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, crds);
+ break;
+ }
+ case SparseTensorFieldKind::ValMemRef: {
+ rewriter.create<vector::PrintOp>(loc,
+ rewriter.getStringAttr("values : "));
+ auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
+ printContents(rewriter, loc, tp, vals);
+ break;
+ }
+ }
+ return true;
+ });
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ // Helper to print contents of a single memref. Note that for the "push_back"
+ // vectors, this prints the full capacity, not just the size. This is done
+ // on purpose, so that clients see how much storage has been allocated in
+ // total. Contents of the extra capacity in the buffer may be uninitialized
+ // (unless the flag enable-buffer-initialization is set to true).
+ //
+ // Generates code to print:
+ // ( a0, a1, ... )
+ static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+ Value vec) {
+ // Open bracket.
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ // For loop over elements.
+ auto zero = constantIndex(rewriter, loc, 0);
+ auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+ auto step = constantIndex(rewriter, loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto idx = forOp.getInductionVar();
+ auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
+ rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+ rewriter.setInsertionPointAfter(forOp);
+ // Close bracket and end of line.
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+ }
+};
+
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
@@ -1284,7 +1376,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
- GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
+ GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+ patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
new file mode 100755
index 00000000000000..797a04bb9ff862
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
@@ -0,0 +1,215 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#AllDense = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : dense
+ )
+}>
+
+#AllDenseT = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ j : dense,
+ i : dense
+ )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : dense,
+ j : compressed
+ )
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed,
+ j : compressed
+ )
+}>
+
+#CSC = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ j : dense,
+ i : compressed
+ )
+}>
+
+#DCSC = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ j : compressed,
+ i : compressed
+ )
+}>
+
+#BSR = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i floordiv 2 : compressed,
+ j floordiv 4 : compressed,
+ i mod 2 : dense,
+ j mod 4 : dense
+ )
+}>
+
+#BSRC = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i floordiv 2 : compressed,
+ j floordiv 4 : compressed,
+ j mod 4 : dense,
+ i mod 2 : dense
+ )
+}>
+
+#BSC = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ j floordiv 4 : compressed,
+ i floordiv 2 : compressed,
+ i mod 2 : dense,
+ j mod 4 : dense
+ )
+}>
+
+#BSCC = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ j floordiv 4 : compressed,
+ i floordiv 2 : compressed,
+ j mod 4 : dense,
+ i mod 2 : dense
+ )
+}>
+
+module {
+
+ //
+ // Main driver that tests sparse tensor storage.
+ //
+ func.func @main() {
+ %x = arith.constant dense <[
+ [ 1, 0, 2, 0, 0, 0, 0, 0 ],
+ [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+ [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+ [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
+
+ %a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
+ %b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
+ %c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
+ %d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
+ %e = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR>
+ %f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
+ %g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
+ %h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
+
+ //
+ // CHECK: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 5
+ // CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
+ // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+ // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %a : tensor<4x8xi32, #CSR>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 5
+ // CHECK-NEXT: pos[0] : ( 0, 2,
+ // CHECK-NEXT: crd[0] : ( 0, 3,
+ // CHECK-NEXT: pos[1] : ( 0, 2, 5,
+ // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+ // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %b : tensor<4x8xi32, #DCSR>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 5
+ // CHECK-NEXT: pos[1] : ( 0, 1, 1, 3, 4, 4, 5, 5, 5,
+ // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+ // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %c : tensor<4x8xi32, #CSC>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 5
+ // CHECK-NEXT: pos[0] : ( 0, 4,
+ // CHECK-NEXT: crd[0] : ( 0, 2, 3, 5,
+ // CHECK-NEXT: pos[1] : ( 0, 1, 3, 4, 5,
+ // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+ // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %d : tensor<4x8xi32, #DCSC>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 24
+ // CHECK-NEXT: pos[0] : ( 0, 2,
+ // CHECK-NEXT: crd[0] : ( 0, 1,
+ // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+ // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+ // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %e : tensor<4x8xi32, #BSR>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 24
+ // CHECK-NEXT: pos[0] : ( 0, 2,
+ // CHECK-NEXT: crd[0] : ( 0, 1,
+ // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+ // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+ // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %f : tensor<4x8xi32, #BSRC>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 24
+ // CHECK-NEXT: pos[0] : ( 0, 2,
+ // CHECK-NEXT: crd[0] : ( 0, 1,
+ // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+ // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+ // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %g : tensor<4x8xi32, #BSC>
+
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 24
+ // CHECK-NEXT: pos[0] : ( 0, 2,
+ // CHECK-NEXT: crd[0] : ( 0, 1,
+ // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+ // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+ // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+ // CHECK-NEXT: ----
+ sparse_tensor.print %h : tensor<4x8xi32, #BSCC>
+
+ // Release the resources.
+ bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
+ bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
+ bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
+ bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
+ bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
+ bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
+ bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
+ bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
+
+ return
+ }
+}
>From fa556ad8d673b98931247cb4be3fcafe26fb4639 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 28 Feb 2024 12:13:54 -0800
Subject: [PATCH 2/2] reviewer feedback
---
.../Transforms/SparseTensorRewriting.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa078e2c2efd75..c95b7b015b3725 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
- auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
- printContents(rewriter, loc, tp, poss);
+ auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, pos);
break;
}
case SparseTensorFieldKind::CrdMemRef: {
@@ -642,15 +642,15 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
- auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
- printContents(rewriter, loc, tp, crds);
+ auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, crd);
break;
}
case SparseTensorFieldKind::ValMemRef: {
rewriter.create<vector::PrintOp>(loc,
rewriter.getStringAttr("values : "));
- auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
- printContents(rewriter, loc, tp, vals);
+ auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
+ printContents(rewriter, loc, tp, val);
break;
}
}
More information about the Mlir-commits
mailing list