[Mlir-commits] [mlir] [mlir][sparse] support 'batch' dimensions in sparse_tensor.print (PR #91411)

Aart Bik llvmlistbot at llvm.org
Tue May 7 16:42:40 PDT 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/91411

None

From b1a1d322e8bd31653a879060145aef42325a1fda Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 7 May 2024 16:32:04 -0700
Subject: [PATCH] [mlir][sparse] support 'batch' dimensions in
 sparse_tensor.print

---
 .../Transforms/SparseTensorCodegen.cpp        | 12 ++-
 .../Transforms/SparseTensorRewriting.cpp      | 66 ++++++++++-------
 .../SparseTensor/CPU/sparse_pack_d.mlir       | 12 +--
 .../SparseTensor/CPU/sparse_print_3d.mlir     | 74 +++++++++++++++++++
 4 files changed, 130 insertions(+), 34 deletions(-)
 create mode 100755 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d9b203a886488..164e722c45dba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 /// Generates a subview into the sizes.
 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                             Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  auto memTp = llvm::cast<MemRefType>(mem.getType());
+  // For higher-dimensional memrefs, we assume that the innermost
+  // dimension always has the right size.
+  // TODO: generate a more complex truncating view here too?
+  if (memTp.getRank() > 1)
+    return mem;
+  // Truncate linear memrefs to given size.
   return builder
       .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
+          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
           ArrayRef<int64_t>{0},                    // static offset
           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
           ArrayRef<int64_t>{1})                    // static stride
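
In the rank-1 case this builds a plain truncating view. A minimal sketch of
the IR the subview builder call above produces, assuming a dynamically sized
f64 buffer %mem and an in-use size %sz (names are illustrative):

  // Truncate the rank-1 buffer to its in-use size; buffers of rank > 1
  // are returned unchanged (innermost dimension assumed exact).
  %view = memref.subview %mem[0] [%sz] [1]
      : memref<?xf64> to memref<?xf64>
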
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7d469198a653c..025fd3331ba89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -785,45 +785,61 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   }
 
 private:
-  // Helper to print contents of a single memref. Note that for the "push_back"
-  // vectors, this prints the full capacity, not just the size. This is done
-  // on purpose, so that clients see how much storage has been allocated in
-  // total. Contents of the extra capacity in the buffer may be uninitialized
-  // (unless the flag enable-buffer-initialization is set to true).
+  // Helper to print contents of a single memref. For "push_back" vectors,
+  // we assume the preceding pos/crd/val getters have already applied a
+  // slice-to-size view, so that only the in-use size is printed, not the
+  // full capacity.
   //
-  // Generates code to print:
+  // Generates code to print (1-dim or higher):
   //    ( a0, a1, ... )
   static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
+    auto shape = cast<ShapedType>(vec.getType()).getShape();
+    SmallVector<Value> idxs;
+    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+
+  // Recursive helper for printContents() that prints a single level.
+  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
+                                 SmallVectorImpl<Value> &idxs) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-    // For loop over elements.
+    // Generate for loop.
     auto zero = constantIndex(rewriter, loc, 0);
-    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto index = constantIndex(rewriter, loc, i);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
     auto step = constantIndex(rewriter, loc, 1);
     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    idxs.push_back(forOp.getInductionVar());
     rewriter.setInsertionPointToStart(forOp.getBody());
-    auto idx = forOp.getInductionVar();
-    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    if (llvm::isa<ComplexType>(val.getType())) {
-      // Since the vector dialect does not support complex types in any op,
-      // we split those into (real, imag) pairs here.
-      Value real = rewriter.create<complex::ReOp>(loc, val);
-      Value imag = rewriter.create<complex::ImOp>(loc, val);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-      rewriter.create<vector::PrintOp>(loc, real,
-                                       vector::PrintPunctuation::Comma);
-      rewriter.create<vector::PrintOp>(loc, imag,
-                                       vector::PrintPunctuation::Close);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    if (i < shape.size() - 1) {
+      // Enter deeper loop nest.
+      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
     } else {
-      rewriter.create<vector::PrintOp>(loc, val,
-                                       vector::PrintPunctuation::Comma);
+      // Print the actual contents at the innermost level.
+      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+      if (llvm::isa<ComplexType>(val.getType())) {
+        // Since the vector dialect does not support complex types in any op,
+        // we split those into (real, imag) pairs here.
+        Value real = rewriter.create<complex::ReOp>(loc, val);
+        Value imag = rewriter.create<complex::ImOp>(loc, val);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+        rewriter.create<vector::PrintOp>(loc, real,
+                                         vector::PrintPunctuation::Comma);
+        rewriter.create<vector::PrintOp>(loc, imag,
+                                         vector::PrintPunctuation::Close);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+      } else {
+        rewriter.create<vector::PrintOp>(loc, val,
+                                         vector::PrintPunctuation::Comma);
+      }
     }
+    idxs.pop_back();
     rewriter.setInsertionPointAfter(forOp);
-    // Close bracket and end of line.
+    // Close bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
-    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
   }
 
   // Helper method to print run-time lvl/dim sizes.
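
The recursion above emits one scf.for per memref dimension and loads with the
accumulated induction variables at the innermost level. A rough sketch of the
IR generated for a 2-d (batched) buffer, with the punctuation prints shown as
comments and all names illustrative:

  // print "("                           (level-0 open bracket)
  scf.for %i = %c0 to %d0 step %c1 {
    // print "("                         (level-1 open bracket)
    scf.for %j = %c0 to %d1 step %c1 {
      %v = memref.load %vec[%i, %j] : memref<?x?xf64>
      vector.print %v : f64              // each value followed by ", "
    }
    // print ")"
  }
  // print ")", then printContents() appends a trailing newline
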
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index 20ae7e86285cc..467a77f30777a 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -29,7 +29,7 @@
   crdWidth = 32
 }>
 
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
   posWidth = 64,
   crdWidth = 32
@@ -42,7 +42,7 @@
 }>
 
 //
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
 //
 module {
   //
@@ -77,7 +77,7 @@ module {
         tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
-    // Setup BatchedCSR.
+    // Setup DenseCSR.
     //
 
     %data1 = arith.constant dense<
@@ -88,7 +88,7 @@ module {
     %crd1 = arith.constant dense<
        [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
 
     //
     // Setup CSRDense.
@@ -137,7 +137,7 @@ module {
     // CHECK-NEXT: ----
     //
     sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
-    sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
     sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
 
     // TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@ module {
       // sparse_tensor.assemble copies buffers when running with the runtime
       // library. Deallocations are not needed when running in codegen mode.
       bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
-      bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+      bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
       bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
     }
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
new file mode 100755
index 0000000000000..98dee304fa511
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users
+// that do not use these LIT config files, which is why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// TODO: make this work with libgen
+
+// Run with direct IR generation only for now (see TODO above).
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#BatchedCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
+}>
+
+module {
+
+  //
+  // Main driver that tests 3-D sparse tensor printing.
+  //
+  func.func @main() {
+
+    %pos = arith.constant dense<
+      [[ 0, 8, 16, 24, 32],
+       [ 0, 8, 16, 24, 32]]
+    > : tensor<2x5xindex>
+
+    %crd = arith.constant dense<
+      [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
+       [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
+    > : tensor<2x32xindex>
+
+    %val = arith.constant dense<
+      [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.,
+        12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
+        23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
+       [33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
+        44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
+        55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
+    > : tensor<2x32xf64>
+
+    %X = sparse_tensor.assemble (%pos, %crd), %val
+      : (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>
+
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 32
+    // CHECK-NEXT: dim = ( 2, 4, 8 )
+    // CHECK-NEXT: lvl = ( 2, 4, 8 )
+    // CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32,  )( 0, 8, 16, 24, 32,  ) )
+    // CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  )
+    // CHECK-SAME:            ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  ) )
+    // CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,  )
+    // CHECK-SAME:            ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,  ) )
+    // CHECK-NEXT: ----
+    sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>
+
+    return
+  }
+}


