[Mlir-commits] [mlir] a63d6a0 - [mlir][sparse] make UnpackOp return the actual filled length of unpacked memory

Peiming Liu llvmlistbot at llvm.org
Fri Jun 30 14:35:21 PDT 2023


Author: Peiming Liu
Date: 2023-06-30T21:35:15Z
New Revision: a63d6a00140d25289e0cba294c9228dbb74b06fa

URL: https://github.com/llvm/llvm-project/commit/a63d6a00140d25289e0cba294c9228dbb74b06fa
DIFF: https://github.com/llvm/llvm-project/commit/a63d6a00140d25289e0cba294c9228dbb74b06fa.diff

LOG: [mlir][sparse] make UnpackOp return the actual filled length of unpacked memory

This might simplify the frontend implementation by avoiding recomputation of the same values.
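
For illustration, a minimal sketch of what a frontend-facing function can now do with the new results, assuming a #SparseVector encoding like the one in the tests below (buffer sizes and names here are made up, spelled in the 2023-era encoding syntax): the filled length comes back from the op itself, so the caller can slice the returned buffer directly.

```mlir
#SparseVector = #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>

// Unpacks a sparse vector and forwards the filled value count to the
// caller, which previously had to be recomputed on the frontend side.
func.func @unpack_values(%sp: tensor<100xf64, #SparseVector>,
                         %od: tensor<6xf64>,
                         %op: tensor<2xindex>,
                         %oi: tensor<6x1xindex>) -> (tensor<6xf64>, index) {
  %v, %p, %c, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
      outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xindex>)
      -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xindex>), index, (index, index)
  return %v, %vl : tensor<6xf64>, index
}
```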

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D154244

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9fd81b98a2dcd6..01320a2e4c5acb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -102,18 +102,21 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
                    TensorOf<[AnyType]>:$out_values,
                    Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels)> {
+                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  Index:$val_len,
+                  Variadic<Index>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
     The unpack operation is the inverse of `sparse_tensor::pack`.  It returns
     the values and per-level position and coordinate array to the user
-    from the sparse tensor. This operation can be used for returning an
+    from the sparse tensor along with the actual length of the memory used in
+    each returned buffer. This operation can be used for returning an
    unpacked MLIR sparse tensor to the frontend; e.g., returning two numpy arrays to Python.
 
    Disclaimer: It is the user's responsibility to allocate large enough buffers
@@ -128,18 +131,22 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
     // input COO format |1.1, 0.0, 0.0, 0.0|
     //    of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
     //                  |0.0, 0.0, 0.0, 0.0|
-    %values, %pos, %coords = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector>
-                             outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-                             -> tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
-    // %values      = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
-    // %pos         = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
-    // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
+    %v, %p, %c, %v_len, %p_len, %c_len = sparse_tensor.unpack %sp : tensor<3x4xf64, #SparseVector>
+                                         outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
+                                      -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+    // %v = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
+    // %p = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
+    // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
+    // %v_len = 3
+    // %p_len = 2
+    // %c_len = 6 (3x2)
     ```
   }];
 
   let assemblyFormat =
-    "$tensor `:` type($tensor) `outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)`"
-    "attr-dict `->` type($ret_values) `,` type($ret_levels)";
+    "$tensor `:` type($tensor) "
+    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
+    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
 
   let hasVerifier = 1;
 }
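
Since the op now carries two variadic result groups (`$ret_levels` and `$lvl_lens`), the newly added `SameVariadicResultSize` trait is what lets ODS apportion results between them; both groups hold exactly one entry per level buffer. As a hedged sketch, the new grouped type lists on a two-level tensor look as follows (the #CSR encoding and buffer sizes are illustrative assumptions, not taken from this commit):

```mlir
#CSR = #sparse_tensor.encoding<{ lvlTypes = ["dense", "compressed"] }>

func.func @unpack_csr(%sp: tensor<2x100xf64, #CSR>,
                      %ov: tensor<6xf64>,
                      %opos: tensor<3xindex>,
                      %ocrd: tensor<6xindex>)
    -> (tensor<6xf64>, index, index, index) {
  // One (tensor, length) result pair per buffer: values, positions, coordinates.
  %v, %pos, %crd, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
      outs(%ov, %opos, %ocrd : tensor<6xf64>, tensor<3xindex>, tensor<6xindex>)
      -> tensor<6xf64>, (tensor<3xindex>, tensor<6xindex>), index, (index, index)
  return %v, %vl, %pl, %cl : tensor<6xf64>, index, index, index
}
```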

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e712c9396466bb..cd49205627ea7c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -166,13 +166,13 @@ struct UnpackOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     // We write into the output operand.
-    assert(op->getNumOperands() == op->getNumResults() + 1);
+    assert(2 * (op->getNumOperands() - 1) == op->getNumResults());
     return opOperand.getOperandNumber() > 0;
   }
 
   AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
-    assert(op->getNumOperands() == op->getNumResults() + 1);
+    assert(2 * (op->getNumOperands() - 1) == op->getNumResults());
 
     if (opOperand.getOperandNumber() == 0)
       return {};
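
The updated assertion restates the op's new shape: an unpack that fills k output buffers takes k + 1 operands (the sparse tensor plus the k buffers) and now produces 2k results (the k filled tensors plus their k lengths), hence `2 * (numOperands - 1) == numResults`.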

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a7f37e8189ea03..73840ba1327572 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1311,7 +1311,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
-    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
+    SmallVector<Value> retLen;
+    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
                                       FieldIndex fid,
                                       SparseTensorFieldKind fKind, Level lvl,
                                       DimLevelType dlt) -> bool {
@@ -1329,6 +1330,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // TODO: maybe change unpack/pack operation instead to be
         // consistent.
         retMem.insert(retMem.begin(), dst);
+        retLen.insert(retLen.begin(), sz);
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1339,6 +1341,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         src = desc.getMemRefField(fid);
         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
         retMem.push_back(dst);
+        retLen.push_back(sz);
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
@@ -1352,12 +1355,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     });
 
     // Converts MemRefs back to Tensors.
-    SmallVector<Value> retTensor = llvm::to_vector(
+    SmallVector<Value> retValues = llvm::to_vector(
         llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
           return rewriter.create<bufferization::ToTensorOp>(loc, v);
         }));
-
-    rewriter.replaceOp(op, retTensor);
+    // Appends the actual memory length used by each returned buffer.
+    retValues.append(retLen.begin(), retLen.end());
+    rewriter.replaceOp(op, retValues);
     return success();
   }
 };
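
Note the ordering the converter establishes here: the values buffer is inserted at the front of `retMem`/`retLen` while the per-level buffers are appended in field order, so once the lengths are appended the replacement values line up with the op's result list (`ret_values`, `ret_levels`..., `val_len`, `lvl_lens`...).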

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7a6c4824aabed6..1510c3d65b57a5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -60,9 +60,9 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
 
 func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
   // expected-error at +1 {{input/output element-types don't match}}
-  %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
+  %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
                   outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
-                  -> tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+                  -> tensor<6xf64>, (tensor<2xi32>, tensor<6x1xi32>), index, (index, index)
   return
 }
 
@@ -72,9 +72,9 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten
 
 func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
   // expected-error at +1 {{input/output trailing COO level-ranks don't match}}
-  %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
+  %rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
                   outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
-                  -> tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>
+                  -> tensor<6xf64>, (tensor<2xi32>, tensor<6x3xi32>), index, (index, index)
   return
 }
 
@@ -84,9 +84,9 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t
 
 func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
   // expected-error at +1 {{inconsistent number of fields between input/output}}
-  %rv, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
+  %rv, %rc, %vl, %pl = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
              outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
-             -> tensor<6xf64>, tensor<6xi32>
+             -> tensor<6xf64>, (tensor<6xi32>), index, (index)
   return
 }
 

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 43429f454e1225..ace180369da7db 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -36,16 +36,16 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
 //  CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
 //  CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
 //  CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
-//       CHECK: %[[D:.*]], %[[P:.*]]:2 = sparse_tensor.unpack %[[T]]
+//       CHECK: %[[D:.*]], %[[P:.*]]:2, %[[DL:.*]], %[[PL:.*]]:2 = sparse_tensor.unpack %[[T]]
 //       CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
 func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
                          %od : tensor<6xf64>,
                          %op : tensor<2xindex>,
                          %oi : tensor<6x1xi32>)
                        -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
-  %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
+  %rd, %rp, %ri, %vl, %pl, %cl = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
                   outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
-                  -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+                  -> tensor<6xf64>, (tensor<2xindex>, tensor<6x1xi32>), index, (index, index)
   return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
 }
 

diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 5d7305dac54cc6..0377ecc3090a17 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -70,8 +70,8 @@ func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
                          %op : tensor<2xindex>,
                          %oi : tensor<6x2xi32>)
                        -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
-  %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
-                  outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
-                  -> tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
+  %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+                                 outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
+                                 -> tensor<6xf64>, (tensor<2xindex>, tensor<6x2xi32>), index, (index, index)
   return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
 }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 3014407a95a1a6..acb5dca06863d9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -171,9 +171,9 @@ module {
     %d_csr = tensor.empty() : tensor<4xf64>
     %p_csr = tensor.empty() : tensor<3xi32>
     %i_csr = tensor.empty() : tensor<3xi32>
-    %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
+    %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
                  outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
 
     // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
     %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -196,9 +196,9 @@ module {
     %od = tensor.empty() : tensor<3xf64>
     %op = tensor.empty() : tensor<2xi32>
     %oi = tensor.empty() : tensor<3x2xi32>
-    %d, %p, %i = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+    %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                  outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-                 -> tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -212,17 +212,21 @@ module {
     %bod = tensor.empty() : tensor<6xf64>
     %bop = tensor.empty() : tensor<4xindex>
     %boi = tensor.empty() : tensor<6x2xindex>
-    %bd, %bp, %bi = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
+    %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>
+                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
     vector.print %vbd : vector<6xf64>
+    // CHECK-NEXT: 5
+    vector.print %ld : index
 
     // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
+    // CHECK-NEXT: 10
+    vector.print %li : index
 
     return
   }
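
The two added CHECK lines show the payoff concretely: the BCOO tensor stores only 5 entries, so %ld prints 5 even though the values buffer has 6 slots, and %li prints 10 (5 coordinate pairs) against a 6x2 buffer; these are exactly the filled lengths a frontend previously had to recompute.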


        

