[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)

Aart Bik llvmlistbot at llvm.org
Wed Feb 28 12:14:18 PST 2024


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/83321

>From 30c0a8fce6bb9ae325f6bb6c93441b3b14f700ea Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 28 Feb 2024 10:37:41 -0800
Subject: [PATCH 1/2] [mlir][sparse] add a sparse_tensor.print operation

This operation is mainly used for testing and debugging
purposes but provides a very convenient way to quickly
inspect the contents of a sparse tensor (all components
over all stored levels).

Example:

[ [ 1, 0, 2, 0, 0, 0, 0, 0 ],
  [ 0, 0, 0, 0, 0, 0, 0, 0 ],
  [ 0, 0, 0, 0, 0, 0, 0, 0 ],
  [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]

when stored sparse as DCSC prints as

---- Sparse Tensor ----
nse = 5
pos[0] : ( 0, 4,  )
crd[0] : ( 0, 2, 3, 5,  )
pos[1] : ( 0, 1, 3, 4, 5,  )
crd[1] : ( 0, 0, 3, 3, 3,  )
values : ( 1, 2, 3, 4, 5,  )
----
---
 .../SparseTensor/IR/SparseTensorOps.td        |  22 ++
 .../Transforms/SparseTensorRewriting.cpp      |  95 +++++++-
 .../SparseTensor/CPU/sparse_print.mlir        | 215 ++++++++++++++++++
 3 files changed, 331 insertions(+), 1 deletion(-)
 create mode 100755 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3127cf1b1bcf69..9007e4e98e3163 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1453,4 +1453,26 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Debugging Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_PrintOp : SparseTensor_Op<"print">,
+    Arguments<(ins AnySparseTensor:$tensor)> {
+  string summary = "Prints a sparse tensor (for testing and debugging)";
+  string description = [{
+    Prints the individual components of a sparse tensor (the positions,
+    coordinates, and values components) to stdout for testing and debugging
+    purposes. This operation lowers to just a few primitives in a light-weight
+    runtime support to simplify supporting this operation on new platforms.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.print %tensor : tensor<1024x1024xf64, #CSR>
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
 #endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bcc131781d34d..fa078e2c2efd75 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -21,9 +21,11 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   }
 };
 
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, poss);
+        break;
+      }
+      case SparseTensorFieldKind::CrdMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, crds);
+        break;
+      }
+      case SparseTensorFieldKind::ValMemRef: {
+        rewriter.create<vector::PrintOp>(loc,
+                                         rewriter.getStringAttr("values : "));
+        auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
+        printContents(rewriter, loc, tp, vals);
+        break;
+      }
+      }
+      return true;
+    });
+    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Helper to print contents of a single memref. Note that for the "push_back"
+  // vectors, this prints the full capacity, not just the size. This is done
+  // on purpose, so that clients see how much storage has been allocated in
+  // total. Contents of the extra capacity in the buffer may be uninitialized
+  // (unless the flag enable-buffer-initialization is set to true).
+  //
+  // Generates code to print:
+  //    ( a0, a1, ... )
+  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+                            Value vec) {
+    // Open bracket.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+    // For loop over elements.
+    auto zero = constantIndex(rewriter, loc, 0);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto step = constantIndex(rewriter, loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+    auto idx = forOp.getInductionVar();
+    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
+    rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+    rewriter.setInsertionPointAfter(forOp);
+    // Close bracket and end of line.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
 public:
@@ -1284,7 +1376,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
new file mode 100755
index 00000000000000..797a04bb9ff862
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
@@ -0,0 +1,215 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+//  do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#AllDense = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : dense
+  )
+}>
+
+#AllDenseT = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : dense,
+    i : dense
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed,
+    j : compressed
+  )
+}>
+
+#CSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : dense,
+    i : compressed
+  )
+}>
+
+#DCSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j : compressed,
+    i : compressed
+  )
+}>
+
+#BSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : compressed,
+    j floordiv 4 : compressed,
+    i mod 2 : dense,
+    j mod 4 : dense
+  )
+}>
+
+#BSRC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : compressed,
+    j floordiv 4 : compressed,
+    j mod 4 : dense,
+    i mod 2 : dense
+  )
+}>
+
+#BSC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j floordiv 4 : compressed,
+    i floordiv 2 : compressed,
+    i mod 2 : dense,
+    j mod 4 : dense
+  )
+}>
+
+#BSCC = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    j floordiv 4 : compressed,
+    i floordiv 2 : compressed,
+    j mod 4 : dense,
+    i mod 2 : dense
+  )
+}>
+
+module {
+
+  //
+  // Main driver that tests sparse tensor storage.
+  //
+  func.func @main() {
+    %x = arith.constant dense <[
+         [ 1, 0, 2, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 0, 0, 0, 0, 0, 0 ],
+         [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
+
+    %a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
+    %b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
+    %c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
+    %d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
+    %e = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR>
+    %f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
+    %g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
+    %h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
+
+    //
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %a : tensor<4x8xi32, #CSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 3,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %b : tensor<4x8xi32, #DCSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 1, 1, 3, 4, 4, 5, 5, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %c : tensor<4x8xi32, #CSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 4,
+    // CHECK-NEXT: crd[0] : ( 0, 2, 3, 5,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3, 4, 5,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
+    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %d : tensor<4x8xi32, #DCSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %e : tensor<4x8xi32, #BSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
+    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %f : tensor<4x8xi32, #BSRC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %g : tensor<4x8xi32, #BSC>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 24
+    // CHECK-NEXT: pos[0] : ( 0, 2,
+    // CHECK-NEXT: crd[0] : ( 0, 1,
+    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
+    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
+    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
+    // CHECK-NEXT: ----
+    sparse_tensor.print %h : tensor<4x8xi32, #BSCC>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
+    bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
+    bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
+    bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
+    bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
+    bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
+    bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
+    bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
+
+    return
+  }
+}

>From fa556ad8d673b98931247cb4be3fcafe26fb4639 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 28 Feb 2024 12:13:54 -0800
Subject: [PATCH 2/2] reviewer feedback

---
 .../Transforms/SparseTensorRewriting.cpp             | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa078e2c2efd75..c95b7b015b3725 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, poss);
+        auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, pos);
         break;
       }
       case SparseTensorFieldKind::CrdMemRef: {
@@ -642,15 +642,15 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, crds);
+        auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, crd);
         break;
       }
       case SparseTensorFieldKind::ValMemRef: {
         rewriter.create<vector::PrintOp>(loc,
                                          rewriter.getStringAttr("values : "));
-        auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
-        printContents(rewriter, loc, tp, vals);
+        auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
+        printContents(rewriter, loc, tp, val);
         break;
       }
       }



More information about the Mlir-commits mailing list