[Mlir-commits] [mlir] [mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops (PR #83343)

Peiming Liu llvmlistbot at llvm.org
Wed Feb 28 14:53:20 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/83343

>From f97cf1fe122c05491e114f3750fc9f1929166c27 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 28 Feb 2024 21:37:45 +0000
Subject: [PATCH] [mlir][sparse] infer returned type for
 sparse_tensor.to_[sparse_buffer] ops.

---
 .../SparseTensor/IR/SparseTensorOps.td        |  20 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  65 +++++++
 .../Transforms/SparseTensorCodegen.cpp        |  12 +-
 .../Transforms/SparseTensorRewriting.cpp      |  27 +--
 .../SparseTensor/CPU/sparse_insert_3d.mlir    | 171 ++++--------------
 5 files changed, 132 insertions(+), 163 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9007e4e98e3163..3a5447d29f866d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -257,9 +257,10 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
+def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th positions array of the `tensor`";
   let description = [{
     Returns the positions array of the tensor's storage at the given
@@ -283,9 +284,10 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
+def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th coordinates array of the `tensor`";
   let description = [{
     Returns the coordinates array of the tensor's storage at the given
@@ -309,9 +311,10 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [Pure]>,
+def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the linear coordinates array from a tensor";
   let description = [{
     Returns the linear coordinates array for a sparse tensor with
@@ -340,9 +343,10 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [
   let hasVerifier = 1;
 }
 
-def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
+def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
+      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+    Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts numerical values array from a tensor";
   let description = [{
     Returns the values array of the sparse storage format for the given
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 69c3413f35ea9c..232635ca84a47e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1445,6 +1445,38 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+template <typename ToBufferOp>
+static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
+                                           OpaqueProperties prop,
+                                           RegionRange region,
+                                           SmallVectorImpl<mlir::Type> &ret) {
+  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Type elemTp = nullptr;
+  bool withStride = false;
+  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
+    elemTp = stt.getPosType();
+  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
+                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
+    elemTp = stt.getCrdType();
+    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
+      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
+  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
+    elemTp = stt.getElementType();
+  }
+
+  assert(elemTp && "unhandled operation.");
+  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
+  bufShape.push_back(ShapedType::kDynamic);
+
+  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
+                                 stt.getContext(), ShapedType::kDynamic,
+                                 {ShapedType::kDynamic})
+                           : StridedLayoutAttr();
+  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
+  return success();
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1454,6 +1486,14 @@ LogicalResult ToPositionsOp::verify() {
   return success();
 }
 
+LogicalResult
+ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+                                ValueRange ops, DictionaryAttr attr,
+                                OpaqueProperties prop, RegionRange region,
+                                SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
+}
+
 LogicalResult ToCoordinatesOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1463,6 +1503,14 @@ LogicalResult ToCoordinatesOp::verify() {
   return success();
 }
 
+LogicalResult
+ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+                                  ValueRange ops, DictionaryAttr attr,
+                                  OpaqueProperties prop, RegionRange region,
+                                  SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
+}
+
 LogicalResult ToCoordinatesBufferOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   if (stt.getAoSCOOStart() >= stt.getLvlRank())
@@ -1470,6 +1518,14 @@ LogicalResult ToCoordinatesBufferOp::verify() {
   return success();
 }
 
+LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
+                                                      ret);
+}
+
 LogicalResult ToValuesOp::verify() {
   auto stt = getSparseTensorType(getTensor());
   auto mtp = getMemRefType(getResult());
@@ -1478,6 +1534,15 @@ LogicalResult ToValuesOp::verify() {
   return success();
 }
 
+LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
+                                           std::optional<Location> loc,
+                                           ValueRange ops, DictionaryAttr attr,
+                                           OpaqueProperties prop,
+                                           RegionRange region,
+                                           SmallVectorImpl<mlir::Type> &ret) {
+  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
+}
+
 LogicalResult ToSliceOffsetOp::verify() {
   auto rank = getRankedTensorType(getSlice()).getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d5eec4ae67e798..4e3393195813c3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1058,17 +1058,9 @@ class SparseToCoordinatesConverter
     // Replace the requested coordinates access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    Location loc = op.getLoc();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
-
-    // Insert a cast to bridge the actual type to the user expected type. If the
-    // actual type and the user expected type aren't compatible, the compiler or
-    // the runtime will issue an error.
-    Type resType = op.getResult().getType();
-    if (resType != field.getType())
-      field = rewriter.create<memref::CastOp>(loc, resType, field);
-    rewriter.replaceOp(op, field);
+    rewriter.replaceOp(
+        op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
 
     return success();
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c95b7b015b3725..6ff21468e05764 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -618,10 +618,10 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
     rewriter.create<vector::PrintOp>(loc, nse);
     // Use the "codegen" foreach loop construct to iterate over
     // all typical sparse tensor components for printing.
-    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
-                                            &tensor](Type tp, FieldIndex,
-                                                     SparseTensorFieldKind kind,
-                                                     Level l, LevelType) {
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
+                                            &stt](Type, FieldIndex,
+                                                  SparseTensorFieldKind kind,
+                                                  Level l, LevelType) {
       switch (kind) {
       case SparseTensorFieldKind::StorageSpec: {
         break;
@@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, pos);
+        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
+        printContents(rewriter, loc, pos);
         break;
       }
       case SparseTensorFieldKind::CrdMemRef: {
@@ -642,15 +642,20 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
         rewriter.create<vector::PrintOp>(
             loc, lvl, vector::PrintPunctuation::NoPunctuation);
         rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
-        auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
-        printContents(rewriter, loc, tp, crd);
+        Value crd = nullptr;
+        // TODO: eliminate ToCoordinatesBufferOp!
+        if (stt.getAoSCOOStart() == l)
+          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
+        else
+          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
+        printContents(rewriter, loc, crd);
         break;
       }
       case SparseTensorFieldKind::ValMemRef: {
         rewriter.create<vector::PrintOp>(loc,
                                          rewriter.getStringAttr("values : "));
-        auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
-        printContents(rewriter, loc, tp, val);
+        auto val = rewriter.create<ToValuesOp>(loc, tensor);
+        printContents(rewriter, loc, val);
         break;
       }
       }
@@ -670,7 +675,7 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   //
   // Generates code to print:
   //    ( a0, a1, ... )
-  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+  static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
index c141df64c22e76..3a32ff28527001 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_insert_3d.mlir
@@ -45,91 +45,6 @@
 
 
 module {
-
-  func.func @dump(%arg0: tensor<5x4x3xf64, #TensorCSR>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %p2 = sparse_tensor.positions %arg0 { level = 2 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #TensorCSR> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #TensorCSR> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp2 = vector.transfer_read %p2[%c0], %c0: memref<?xindex>, vector<9xindex>
-    vector.print %vp2 : vector<9xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-  }
-
-  func.func @dump_row(%arg0: tensor<5x4x3xf64, #TensorRow>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #TensorRow> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #TensorRow> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<3xindex>
-    vector.print %vp1 : vector<3xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<4xindex>
-    vector.print %vi1 : vector<4xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<12xf64>
-    vector.print %vv : vector<12xf64>
-    return
-  }
-
-func.func @dump_ccoo(%arg0: tensor<5x4x3xf64, #CCoo>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i0 = sparse_tensor.coordinates  %arg0 { level = 0 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #CCoo> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #CCoo> to memref<?xf64>
-    %vp0 = vector.transfer_read %p0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vp0 : vector<2xindex>
-    %vi0 = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<2xindex>
-    vector.print %vi0 : vector<2xindex>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<3xindex>
-    vector.print %vp1 : vector<3xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi1 : vector<5xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-  }
-
-func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
-    %c0 = arith.constant 0 : index
-    %fu = arith.constant 99.0 : f64
-    %p1 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %i1 = sparse_tensor.coordinates  %arg0 { level = 1 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %i2 = sparse_tensor.coordinates  %arg0 { level = 2 : index } : tensor<5x4x3xf64, #DCoo> to memref<?xindex>
-    %v = sparse_tensor.values %arg0 : tensor<5x4x3xf64, #DCoo> to memref<?xf64>
-    %vp1 = vector.transfer_read %p1[%c0], %c0: memref<?xindex>, vector<6xindex>
-    vector.print %vp1 : vector<6xindex>
-    %vi1 = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi1 : vector<5xindex>
-    %vi2 = vector.transfer_read %i2[%c0], %c0: memref<?xindex>, vector<5xindex>
-    vector.print %vi2 : vector<5xindex>
-    %vv = vector.transfer_read %v[%c0], %fu: memref<?xf64>, vector<5xf64>
-    vector.print %vv : vector<5xf64>
-    return
-}
-
   //
   // Main driver.
   //
@@ -145,13 +60,14 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %f4 = arith.constant 4.4 : f64
     %f5 = arith.constant 5.5 : f64
 
-    //
-    // CHECK:      ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 2, 2, 2, 3, 3, 3, 4, 5 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
-    //
+    // CHECK: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[2] : ( 0, 2, 2, 2, 3, 3, 3, 4, 5
+    // CHECK-NEXT: crd[2] : ( 1, 2, 1, 2, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ----
     %tensora = tensor.empty() : tensor<5x4x3xf64, #TensorCSR>
     %tensor1 = sparse_tensor.insert %f1 into %tensora[%c3, %c0, %c1] : tensor<5x4x3xf64, #TensorCSR>
     %tensor2 = sparse_tensor.insert %f2 into %tensor1[%c3, %c0, %c2] : tensor<5x4x3xf64, #TensorCSR>
@@ -159,15 +75,16 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %tensor4 = sparse_tensor.insert %f4 into %tensor3[%c4, %c2, %c2] : tensor<5x4x3xf64, #TensorCSR>
     %tensor5 = sparse_tensor.insert %f5 into %tensor4[%c4, %c3, %c2] : tensor<5x4x3xf64, #TensorCSR>
     %tensorm = sparse_tensor.load %tensor5 hasInserts : tensor<5x4x3xf64, #TensorCSR>
-    call @dump(%tensorm) : (tensor<5x4x3xf64, #TensorCSR>) -> ()
-
-    //
-    // CHECK-NEXT: ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 2, 4 )
-    // CHECK-NEXT: ( 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 0, 1.1, 2.2, 0, 3.3, 0, 0, 0, 4.4, 0, 0, 5.5 )
-    //
+    sparse_tensor.print %tensorm : tensor<5x4x3xf64, #TensorCSR>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 12
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[1] : ( 0, 2, 4
+    // CHECK-NEXT: crd[1] : ( 0, 3, 2, 3
+    // CHECK-NEXT: values : ( 0, 1.1, 2.2, 0, 3.3, 0, 0, 0, 4.4, 0, 0, 5.5
+    // CHECK-NEXT: ----
     %rowa = tensor.empty() : tensor<5x4x3xf64, #TensorRow>
     %row1 = sparse_tensor.insert %f1 into %rowa[%c3, %c0, %c1] : tensor<5x4x3xf64, #TensorRow>
     %row2 = sparse_tensor.insert %f2 into %row1[%c3, %c0, %c2] : tensor<5x4x3xf64, #TensorRow>
@@ -175,15 +92,16 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %row4 = sparse_tensor.insert %f4 into %row3[%c4, %c2, %c2] : tensor<5x4x3xf64, #TensorRow>
     %row5 = sparse_tensor.insert %f5 into %row4[%c4, %c3, %c2] : tensor<5x4x3xf64, #TensorRow>
     %rowm = sparse_tensor.load %row5 hasInserts : tensor<5x4x3xf64, #TensorRow>
-    call @dump_row(%rowm) : (tensor<5x4x3xf64, #TensorRow>) -> ()
-
-    //
-    // CHECK: ( 0, 2 )
-    // CHECK-NEXT: ( 3, 4 )
-    // CHECK-NEXT: ( 0, 3, 5 )
-    // CHECK-NEXT: ( 0, 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
+    sparse_tensor.print %rowm : tensor<5x4x3xf64, #TensorRow>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[0] : ( 0, 2
+    // CHECK-NEXT: crd[0] : ( 3, 4
+    // CHECK-NEXT: pos[1] : ( 0, 3, 5
+    // CHECK-NEXT: crd[1] : ( 0, 1, 0, 2, 3, 1, 2, 2, 3, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ----
     %ccoo = tensor.empty() : tensor<5x4x3xf64, #CCoo>
     %ccoo1 = sparse_tensor.insert %f1 into %ccoo[%c3, %c0, %c1] : tensor<5x4x3xf64, #CCoo>
     %ccoo2 = sparse_tensor.insert %f2 into %ccoo1[%c3, %c0, %c2] : tensor<5x4x3xf64, #CCoo>
@@ -191,13 +109,14 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %ccoo4 = sparse_tensor.insert %f4 into %ccoo3[%c4, %c2, %c2] : tensor<5x4x3xf64, #CCoo>
     %ccoo5 = sparse_tensor.insert %f5 into %ccoo4[%c4, %c3, %c2] : tensor<5x4x3xf64, #CCoo>
     %ccoom = sparse_tensor.load %ccoo5 hasInserts : tensor<5x4x3xf64, #CCoo>
-    call @dump_ccoo(%ccoom) : (tensor<5x4x3xf64, #CCoo>) -> ()
-
-    //
-    // CHECK-NEXT: ( 0, 0, 0, 0, 3, 5 )
-    // CHECK-NEXT: ( 0, 0, 3, 2, 3 )
-    // CHECK-NEXT: ( 1, 2, 1, 2, 2 )
-    // CHECK-NEXT: ( 1.1, 2.2, 3.3, 4.4, 5.5 )
+    sparse_tensor.print %ccoom : tensor<5x4x3xf64, #CCoo>
+
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 5
+    // CHECK-NEXT: pos[1] : ( 0, 0, 0, 0, 3, 5
+    // CHECK-NEXT: crd[1] : ( 0, 1, 0, 2, 3, 1, 2, 2, 3, 2
+    // CHECK-NEXT: values : ( 1.1, 2.2, 3.3, 4.4, 5.5
+    // CHECK-NEXT: ----
     %dcoo = tensor.empty() : tensor<5x4x3xf64, #DCoo>
     %dcoo1 = sparse_tensor.insert %f1 into %dcoo[%c3, %c0, %c1] : tensor<5x4x3xf64, #DCoo>
     %dcoo2 = sparse_tensor.insert %f2 into %dcoo1[%c3, %c0, %c2] : tensor<5x4x3xf64, #DCoo>
@@ -205,23 +124,7 @@ func.func @dump_dcoo(%arg0: tensor<5x4x3xf64, #DCoo>) {
     %dcoo4 = sparse_tensor.insert %f4 into %dcoo3[%c4, %c2, %c2] : tensor<5x4x3xf64, #DCoo>
     %dcoo5 = sparse_tensor.insert %f5 into %dcoo4[%c4, %c3, %c2] : tensor<5x4x3xf64, #DCoo>
     %dcoom = sparse_tensor.load %dcoo5 hasInserts : tensor<5x4x3xf64, #DCoo>
-    call @dump_dcoo(%dcoom) : (tensor<5x4x3xf64, #DCoo>) -> ()
-
-    // NOE sanity check.
-    //
-    // CHECK-NEXT: 5
-    // CHECK-NEXT: 12
-    // CHECK-NEXT: 5
-    // CHECK-NEXT: 5
-    //
-    %noe1 = sparse_tensor.number_of_entries %tensorm : tensor<5x4x3xf64, #TensorCSR>
-    vector.print %noe1 : index
-    %noe2 = sparse_tensor.number_of_entries %rowm : tensor<5x4x3xf64, #TensorRow>
-    vector.print %noe2 : index
-    %noe3 = sparse_tensor.number_of_entries %ccoom : tensor<5x4x3xf64, #CCoo>
-    vector.print %noe3 : index
-    %noe4 = sparse_tensor.number_of_entries %dcoom : tensor<5x4x3xf64, #DCoo>
-    vector.print %noe4 : index
+    sparse_tensor.print %dcoom : tensor<5x4x3xf64, #DCoo>
 
     // Release resources.
     bufferization.dealloc_tensor %tensorm : tensor<5x4x3xf64, #TensorCSR>



More information about the Mlir-commits mailing list