[Mlir-commits] [mlir] b2e6b73 - [mlir][sparse] extend unpack operation to unpack arbitrary encodings.

Peiming Liu llvmlistbot at llvm.org
Tue May 23 15:34:06 PDT 2023


Author: Peiming Liu
Date: 2023-05-23T22:34:01Z
New Revision: b2e6b7354452c10ed6f38958253fd76aca0877de

URL: https://github.com/llvm/llvm-project/commit/b2e6b7354452c10ed6f38958253fd76aca0877de
DIFF: https://github.com/llvm/llvm-project/commit/b2e6b7354452c10ed6f38958253fd76aca0877de.diff

LOG: [mlir][sparse] extend unpack operation to unpack arbitrary encodings.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151174
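
For quick orientation, a minimal before/after sketch of the interface change
(value names such as %sp, %od, %op, %oi are placeholders; the authoritative
forms are in the op definition and tests below):

```mlir
// Before: limited to (possibly batched) COO; produced fresh buffers
// plus a number-of-stored-entries scalar.
%values, %coordinates, %nse = sparse_tensor.unpack %sp
    : tensor<100xf64, #SparseVector> to tensor<6xf64>, tensor<6x1xi32>, index

// After: destination-passing style; the caller supplies one buffer per
// storage field (values plus per-level positions/coordinates), so any
// encoding can be unpacked.
%rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
                outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
                -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
```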

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 865c1aa38f61f..e37062f5f8104 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -73,7 +73,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
     This operation can be used to materialize a sparse tensor from external
     sources; e.g., when passing two numpy arrays from Python.
 
-    Disclaimer: This is users' responsibility to provide input that can be
+    Disclaimer: It is the user's responsibility to provide input that can be
     correctly interpreted by the sparse compiler, which does not perform
     any sanity test during runtime to verify data integrity.
 
@@ -102,29 +102,25 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   OptionalAttr<IndexAttr>:$batched_lvls)>,
-    Results<(outs TensorOf<[AnyType]>:$values,
-                  TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
-                  AnySignlessIntegerOrIndex:$nse)> {
+                   TensorOf<[AnyType]>:$out_values,
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
+    Results<(outs TensorOf<[AnyType]>:$ret_values,
+                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
     The unpack operation is the inverse of `sparse_tensor::pack`.  It returns
-    the values, level-coordinates, and number-of-stored-entries extracted
-    from the sparse tensor.  The source tensor is allowed (in principle)
-    to have non-identity dimOrdering/higherOrdering mappings.  Regardless
-    of the mappings, the returned `coordinates` are always level-coordinates,
-    because this is what we mean by "unpacking" as opposed to other forms
-    of exposing sparse tensors to external clients.  This operation can be
-    used for returning an unpacked MLIR sparse tensor to frontend; e.g.,
-    returning two numpy arrays to Python.
+    the values and per-level position and coordinate arrays extracted from
+    the sparse tensor. This operation can be used for returning an
+    unpacked MLIR sparse tensor to a frontend, e.g., returning two numpy arrays to Python.
 
-    TODO: the current implementation does not yet support non-identity mappings.
+    Disclaimer: It is the user's responsibility to allocate buffers large enough
+    to hold the sparse tensor. The sparse compiler simply copies each field
+    of the sparse tensor into the user-supplied buffers without bounds checking.
 
-    This operation ends the lifetime of the sparse tensor, and using
-    the tensor after the unpack is undefined behavior.
+    TODO: the current implementation does not yet support non-identity mappings.
 
     Example:
 
@@ -132,51 +128,18 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
     // input COO format |1.1, 0.0, 0.0, 0.0|
     //    of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
     //                  |0.0, 0.0, 0.0, 0.0|
-    %values, %coordinates, %nse
-      = sparse_tensor.unpack %st
-      : tensor<3x4xf64, #COO> to tensor<2xf64>, tensor<2x2xindex>, index
+    %values, %pos, %coords = sparse_tensor.unpack %sp : tensor<3x4xf64, #COO>
+                             outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
+                             -> tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
     // %values      = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
+    // %pos         = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
     // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
-    // %nse = 3
     ```
-
-    If `batched_lvls` is provided, the operation unpacks each batch of the tensors
-    separately. The returned `nse` is the maximum nse of all batches. For a batch with
-    a smaller nse, trailing zeros are appended in the result.
-    Example:
-
-    ```mlir
-    // input BCOO format |1.1, 2.2, 3.3, 0.0|
-    //      of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
-    %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
-        : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
-    // %values      = arith.constant dense<[[ 1.1,   2.2,   3.3 ],
-    //                                      [ 1.2,   2.3,   0.0 ]]> : tensor<2x3xf64>
-    // %coordinates = arith.constant dense<[[ [0],   [1],   [2] ],
-    //                                      [ [1],   [2],   [0] ]> : tensor<2x3x1xindex>
-    ```
-  }];
-
-  let extraClassDeclaration = [{
-    /// Returns the number of leading levels that are batched.
-    unsigned getNumBatchedLvls();
   }];
 
-  let builders = [
-    OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value": $tensor),
-    [{
-      build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
-    }]>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value": $tensor),
-    [{
-      build($_builder, $_state, resultTypes, tensor, nullptr);
-    }]>
-  ];
-
-
   let assemblyFormat =
-    "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
-    "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";
+    "$tensor `:` type($tensor) `outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)`"
+    "attr-dict `->` type($ret_values) `,` type($ret_levels)";
 
   let hasVerifier = 1;
 }
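
Since unpack is no longer restricted to COO, here is a hypothetical sketch of
the new form on a CSR matrix (the encoding, function name, and buffer sizes
are illustrative and not taken from this patch):

```mlir
#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"]}>

// CSR stores three fields: the positions and coordinates of the compressed
// level, plus the values; the caller passes one outs buffer per field.
func.func @unpack_csr(%sp: tensor<2x100xf64, #CSR>,
                      %vals: tensor<6xf64>,
                      %pos: tensor<3xindex>,
                      %crd: tensor<6xindex>)
    -> (tensor<6xf64>, tensor<3xindex>, tensor<6xindex>) {
  %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
                  outs(%vals, %pos, %crd : tensor<6xf64>, tensor<3xindex>, tensor<6xindex>)
                  -> tensor<6xf64>, tensor<3xindex>, tensor<6xindex>
  return %rv, %rp, %rc : tensor<6xf64>, tensor<3xindex>, tensor<6xindex>
}
```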

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 41805107588a4..0ecc77f228e42 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -786,65 +786,6 @@ static LogicalResult verifySparsifierGetterSetter(
   return success();
 }
 
-// DEPRECATED: This function is deprecated! Remove it after unpack supports
-// arbitrary sparse encoding.
-static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
-                                      SparseTensorType tensorTp,
-                                      RankedTensorType valuesTp,
-                                      RankedTensorType coordinatesTp,
-                                      IntegerAttr batchedLvls) {
-  unsigned nBatched = batchedLvls ? batchedLvls.getValue().getZExtValue() : 0;
-  if (requiresStaticShape && !tensorTp.hasStaticDimShape())
-    return op->emitError("the sparse-tensor must have static shape");
-  if (!tensorTp.hasEncoding())
-    return op->emitError("the sparse-tensor must have an encoding attribute");
-  if (!tensorTp.isIdentity())
-    return op->emitError("the sparse-tensor must have the identity mapping");
-  if (!isCOOType(tensorTp.getEncoding(), nBatched, true))
-    return op->emitError("the sparse-tensor must have a COO type");
-
-  if (coordinatesTp.getRank() != 2 + nBatched)
-    return op->emitError("coordinates must have rank 2 + batched_lvls");
-  if (requiresStaticShape && !coordinatesTp.hasStaticShape())
-    return op->emitError("coordinates must have static shape");
-  if (coordinatesTp.getElementType() != tensorTp.getCrdType())
-    return op->emitError("input/output coordinate-types don't match");
-
-  if (valuesTp.getRank() != 1 + nBatched)
-    return op->emitError("values must have rank 1 + batched_lvls");
-  if (requiresStaticShape && !valuesTp.hasStaticShape())
-    return op->emitError("values must have static shape");
-  if (valuesTp.getElementType() != tensorTp.getElementType())
-    return op->emitError("input/output element-types don't match");
-
-  for (unsigned i = 0; i < nBatched; i++) {
-    const auto valBatch = valuesTp.getShape()[i];
-    const auto crdBatch = coordinatesTp.getShape()[i];
-    if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) ||
-        crdBatch != valBatch) {
-      return op->emitError(
-          "values/coordinates batched level sizes don't match statically");
-    }
-  }
-
-  const auto valuesNSE = valuesTp.getShape()[nBatched];
-  const auto coordsNSE = coordinatesTp.getShape()[nBatched];
-  if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) &&
-      valuesNSE != coordsNSE)
-    return op->emitError("values/coordinates number-of-elements don't match");
-
-  // NOTE: We use `getLvlRank` because the `coordinatesTp` is for
-  // level-coordinates (cf., the op documentation).
-  const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched];
-  const Level tensorRank = tensorTp.getLvlRank();
-  // FIXME: replace the `operator!=` with our backported `safelyNE`.
-  if (!ShapedType::isDynamic(coordsRank) &&
-      coordsRank != static_cast<DynSize>(tensorRank) - nBatched)
-    return op->emitError("input/output level-ranks don't match");
-
-  return success();
-}
-
 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
   switch (kind) {
   case SparseTensorFieldKind::CrdMemRef:
@@ -925,15 +866,17 @@ LogicalResult PackOp::verify() {
 }
 
 LogicalResult UnpackOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
-  const auto coordinatesTp = getRankedTensorType(getCoordinates());
-  const auto srcTp = getSparseTensorType(getTensor());
-  return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
-                          getBatchedLvlsAttr());
-}
+  if (getOutValues().getType() != getRetValues().getType())
+    return emitError("output values and return value type mismatch");
 
-unsigned UnpackOp::getNumBatchedLvls() {
-  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
+  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
+    if (ot.getType() != rt.getType())
+      return emitError("output levels and return levels type mismatch");
+
+  const auto valuesTp = getRankedTensorType(getRetValues());
+  const auto lvlsTp = getRetLevels().getTypes();
+  const auto srcTp = getSparseTensorType(getTensor());
+  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }
 
 LogicalResult ConvertOp::verify() {
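
The rewritten verifier first requires every `outs` operand type to match the
corresponding result type exactly before delegating the rest to
verifyPackUnPack; a hypothetical snippet like the following would therefore be
rejected with "output values and return value type mismatch":

```mlir
%rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
                outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
                -> tensor<6xf32>, tensor<2xindex>, tensor<6x1xi32> // f32 vs. f64 values
```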

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f17c001308bf0..e712c9396466b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -153,28 +153,32 @@ struct UnpackOpInterface
     : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
                                                     sparse_tensor::UnpackOp> {
   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
-    // We allocate and return unpacked memory if this is a batched unpack.
-    // When the number of batched levels equals to zero, we reuse the
-    // coordinates/values memref (and reallocation if the requested output size
-    // is larger than the actual size). Similar to InsertOp, reallocation is
-    // not considered to allocate a new piece of memory.
-    return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
+    // The output buffers are pre-allocated by the user.
+    return false;
   }
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    return true;
+    // The first operand is the sparse tensor that we are unpacking.
+    return opOperand.getOperandNumber() == 0;
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
-    return false;
+    // We write into the output operands.
+    assert(op->getNumOperands() == op->getNumResults() + 1);
+    return opOperand.getOperandNumber() > 0;
   }
 
   AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
-    // Conceptually, UnpackOp equals to a list of toCoordinates/toValueOp
-    return {};
+    assert(op->getNumOperands() == op->getNumResults() + 1);
+
+    if (opOperand.getOperandNumber() == 0)
+      return {};
+    // We write directly into the output tensors and return them.
+    return {{op->getResult(opOperand.getOperandNumber() - 1),
+             BufferRelation::Equivalent}};
   }
 };
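
These hooks rely on a fixed numbering: operand 0 is the sparse tensor and is
only read, while operand i+1 is written in place and aliases result i
(BufferRelation::Equivalent). Annotated on a hypothetical instance:

```mlir
//  results 0..2                                 operand 0 (read only)
%rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
//              operands 1..3 (written; operand i+1 aliases result i)
                outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
                -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
```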
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4f2e18f43c117..7d4efa8961eb5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -539,48 +539,27 @@ static void genEndInsert(OpBuilder &builder, Location loc,
   }
 }
 
-/// Returns a memref that fits the requested length (reallocates if requested
-/// length is larger, or creates a subview if it is smaller).
-static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
-                              Value buffer) {
-  MemRefType memTp = getMemRefType(buffer);
-  auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
-
-  Value targetLen = constantIndex(builder, loc, len);
-  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
-  // Reallocates if target length is greater than the actual buffer len.
-  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
-                                                 targetLen, bufferLen);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
-  // If targetLen > bufferLen, reallocate to get enough sparse to return.
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
-  builder.create<scf::YieldOp>(loc, reallocBuf);
-  // Else, return a subview to fit the size.
-  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  Value subViewBuf = builder.create<memref::SubViewOp>(
-      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
-      /*size=*/ArrayRef<int64_t>{len},
-      /*stride=*/ArrayRef<int64_t>{1});
-  builder.create<scf::YieldOp>(loc, subViewBuf);
-  // Resets insertion point.
-  builder.setInsertionPointAfter(ifOp);
-  return ifOp.getResult(0);
+static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
+                                              Value tensor) {
+  auto tTp = tensor.getType().cast<TensorType>();
+  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
+  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
+      .getResult();
 }
 
-static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
-                       ValueRange bounds) {
-  assert(ivs.size() == bounds.size());
-  Value crd = constantIndex(builder, loc, 0);
-  for (unsigned i = 0, e = ivs.size(); i < e; i++) {
-    crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
-    if (i != ivs.size() - 1)
-      crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
-  }
-  return crd;
+Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
+  auto elemTp = mem.getType().cast<MemRefType>().getElementType();
+  return builder
+      .create<memref::SubViewOp>(
+          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
+          ValueRange{}, ValueRange{sz}, ValueRange{},
+          ArrayRef<int64_t>{0},                    // static offset
+          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
+          ArrayRef<int64_t>{1})                    // static stride
+      .getResult();
 }
 
-ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   ReassociationIndices reassociation;
   for (int i = 0, e = srcTp.getRank(); i < e; i++)
     reassociation.push_back(i);
@@ -1243,23 +1222,21 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
                 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
           } else {
             // Else simply takes the inputs.
-            Value field = fKind == SparseTensorFieldKind::ValMemRef
-                              ? op.getValues()
-                              : op.getLevels()[fIdx];
-
-            auto tensorType = field.getType().cast<RankedTensorType>();
-            auto memrefType = MemRefType::get(tensorType.getShape(),
-                                              tensorType.getElementType());
-            field = rewriter.create<bufferization::ToMemrefOp>(
-                op->getLoc(), memrefType, field);
-            if (memrefType.getRank() > 1) {
+            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
+                               ? op.getValues()
+                               : op.getLevels()[fIdx];
+
+            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
+            if (mem.getType().getRank() > 1) {
               // Flattens the buffer to rank 1.
-              auto reassoc = getReassociationForFlattening(memrefType);
-              field =
-                  rewriter.create<memref::CollapseShapeOp>(loc, field, reassoc);
+              auto reassoc = getReassociationForFlattening(mem.getType());
+              mem = rewriter.create<memref::CastOp>(
+                  loc, fType,
+                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
+            } else {
+              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
             }
-            field = rewriter.create<memref::CastOp>(loc, fType, field);
-            fields.push_back(field);
+            fields.push_back(mem);
           }
           return true;
         });
@@ -1269,6 +1246,9 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
     Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c1; // index to the last value in the position array
     Value memSize = c2; // memory size for current array
+
+    Level trailCOOStart = getCOOStart(stt.getEncoding());
+    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
     // Sets up SparseTensorSpecifier.
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
       assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
@@ -1277,6 +1257,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
       // Sets up the level size.
       auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
       desc.setLvlSize(rewriter, loc, lvl, lvlSize);
+      // We use a single AOS array to store the trailing COO, so there is only
+      // one memory size to set for the entire COO section.
+      if (lvl > trailCOOStart)
+        continue;
 
       // Sets up the memory size by reading the last value in position array.
       DimLevelType dlt = stt.getLvlType(lvl);
@@ -1298,8 +1282,15 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
         memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
         posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
       }
-      assert(isDLTWithCrd(dlt));
-      desc.setCrdMemSize(rewriter, loc, lvl, memSize);
+      assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
+      // FIXME: This seems unnecessarily complex; can we simplify it?
+      if (lvl == trailCOOStart) {
+        Value cooSz = rewriter.create<arith::MulIOp>(
+            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
+        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
+      } else {
+        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
+      }
     }
     desc.setValMemSize(rewriter, loc, memSize);
 
@@ -1308,166 +1299,6 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   }
 };
 
-static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
-                                          SparseTensorDescriptor desc,
-                                          ConversionPatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  const auto srcTp = getSparseTensorType(op.getTensor());
-  const Level lvlRank = srcTp.getLvlRank();
-  Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
-                               : desc.getAOSMemRef();
-  Value valuesBuf = desc.getValMemRef();
-
-  // If frontend requests a static buffer, we reallocate the
-  // values/coordinates to ensure that we meet their need.
-  const auto valuesTp = getRankedTensorType(op.getValues());
-  if (valuesTp.hasStaticShape()) {
-    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
-    // tensor that is packed from constants.
-    valuesBuf =
-        reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
-  }
-
-  const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
-  if (coordinatesTp.hasStaticShape()) {
-    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
-    // tensor that is packed from constants.
-    auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
-    flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
-  }
-
-  Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
-      loc,
-      MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
-      flatBuf, ArrayRef{ReassociationIndices{0, 1}});
-
-  // Converts MemRefs back to Tensors.
-  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
-  Value coordinates =
-      rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
-  Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
-                      op.getNse().getType());
-
-  rewriter.replaceOp(op, {values, coordinates, nse});
-  return success();
-}
-
-static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
-                                        SparseTensorDescriptor desc,
-                                        ConversionPatternRewriter &rewriter) {
-  assert(nBatched != 0);
-  Location loc = op.getLoc();
-  Value c0 = constantIndex(rewriter, loc, 0);
-  Value c1 = constantIndex(rewriter, loc, 1);
-  Value c2 = constantIndex(rewriter, loc, 2);
-
-  auto genZeroedAlloc = [loc,
-                         &rewriter](TensorType tt) -> TypedValue<MemRefType> {
-    auto mem = rewriter
-                   .create<memref::AllocOp>(
-                       loc, MemRefType::get(tt.getShape(), tt.getElementType()))
-                   .getMemref();
-    // TODO: Instead of filling the entire buffer, we can only fill the
-    // trailing zeros.
-    rewriter.create<linalg::FillOp>(
-        loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
-    return mem;
-  };
-  SparseTensorType stt = getSparseTensorType(op.getTensor());
-  TensorType valTensorTp = op.getValues().getType();
-  TensorType crdTensorTp = op.getCoordinates().getType();
-  TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
-  TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
-  assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
-
-  SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
-  SmallVector<Value> ubs;
-  for (unsigned i = 0; i < nBatched; i++) {
-    assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
-    ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
-  }
-
-  DimLevelType dlt = stt.getLvlType(nBatched);
-  assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
-  Value posStep = isCompressedDLT(dlt) ? c1  // forward position index by 1
-                                       : c2; // forward position index by 2
-  auto loopNest = scf::buildLoopNest(
-      rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
-      [&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
-       &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
-                   ValueRange args) -> scf::ValueVector {
-        // crdMemref has shape: <... x nse x rank>
-        unsigned unBatchedRank = crdMemref.getType().getShape().back();
-        Value values = desc.getValMemRef();
-        Value flatCrds = unBatchedRank == 1
-                             ? desc.getCrdMemRefOrView(builder, loc, 0)
-                             : desc.getAOSMemRef();
-
-        Value positions = desc.getPosMemRef(nBatched);
-        Value positLo = builder.create<arith::MulIOp>(
-            loc, linearize(builder, loc, ivs, ubs), posStep);
-        Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
-
-        Value pLo = genIndexLoad(builder, loc, positions, positLo);
-        Value pHi = genIndexLoad(builder, loc, positions, positHi);
-        Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
-
-        Value crdLo = builder.create<arith::MulIOp>(
-            loc, pLo, constantIndex(builder, loc, unBatchedRank));
-        Value nCrd = builder.create<arith::MulIOp>(
-            loc, nse, constantIndex(builder, loc, unBatchedRank));
-
-        SmallVector<Value> offsets, sizes, strides;
-        for (unsigned i = 0; i < nBatched; i++) {
-          offsets.push_back(ivs[i]);
-          sizes.push_back(c1);
-          strides.push_back(c1);
-        }
-        // [0, nse, 1].
-        offsets.push_back(c0);
-        sizes.push_back(nse);
-        strides.push_back(c1);
-
-        auto valView = builder.create<memref::SubViewOp>(
-            loc, valMemref, offsets, sizes, strides);
-        auto valReass = getReassociationForFlattening(valView.getType());
-        Value valDst =
-            builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
-        Value valSrc =
-            builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
-        builder.create<memref::CopyOp>(loc, valSrc, valDst);
-
-        // [0, rank, 1].
-        offsets.push_back(c0);
-        sizes.push_back(constantIndex(builder, loc, unBatchedRank));
-        strides.push_back(c1);
-
-        auto crdView = builder.create<memref::SubViewOp>(
-            loc, crdMemref, offsets, sizes, strides);
-        auto crdReass = getReassociationForFlattening(crdView.getType());
-        Value crdDst =
-            builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
-        Value crdSrc =
-            builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
-        builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
-
-        Value pred = builder.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::ugt, nse, args[0]);
-        // Choose the larger NSE
-        return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
-      });
-
-  // Converts MemRefs back to Tensors.
-  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
-  Value coordinates =
-      rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
-  Value nse =
-      genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
-
-  rewriter.replaceOp(op, {values, coordinates, nse});
-  return success();
-}
-
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
   SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context)
@@ -1477,13 +1308,56 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    const auto srcTp = getSparseTensorType(op.getTensor());
-    const unsigned nBatched = op.getNumBatchedLvls();
-    assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
-           desc.getFields().size() == 4); // specifier + pos + crds + values
-    (void)srcTp;
-    return nBatched == 0 ? genUnBatchedUnpackOp(op, desc, rewriter)
-                         : genBatchedUnpackOp(op, nBatched, desc, rewriter);
+    Location loc = op.getLoc();
+    SmallVector<Value> retMem;
+    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem](
+                                      FieldIndex fid,
+                                      SparseTensorFieldKind fKind, Level lvl,
+                                      DimLevelType dlt) -> bool {
+      if (fKind == SparseTensorFieldKind::StorageSpec)
+        return true;
+      SparseTensorType stt(desc.getRankedTensorType());
+      Value sz, src;
+      TypedValue<BaseMemRefType> dst;
+      if (fKind == SparseTensorFieldKind::ValMemRef) {
+        sz = desc.getValMemSize(rewriter, loc);
+        src = desc.getValMemRef();
+        dst = genToMemref(rewriter, loc, op.getOutValues());
+        // Values is the last field in the descriptor, but it is the first
+        // operand of the unpack operation.
+        // TODO: maybe change the unpack/pack operations instead to be
+        // consistent.
+        retMem.insert(retMem.begin(), dst);
+      } else {
+        assert(fKind == SparseTensorFieldKind::PosMemRef ||
+               fKind == SparseTensorFieldKind::CrdMemRef);
+
+        sz = fKind == SparseTensorFieldKind::PosMemRef
+                 ? desc.getPosMemSize(rewriter, loc, lvl)
+                 : desc.getCrdMemSize(rewriter, loc, lvl);
+        src = desc.getMemRefField(fid);
+        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
+        retMem.push_back(dst);
+      }
+      Value flatOut = dst;
+      if (dst.getType().getRank() != 1) {
+        auto reassoc = getReassociationForFlattening(dst.getType());
+        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
+      }
+      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
+      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
+      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
+      return true;
+    });
+
+    // Converts MemRefs back to Tensors.
+    SmallVector<Value> retTensor = llvm::to_vector(
+        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
+          return rewriter.create<bufferization::ToTensorOp>(loc, v);
+        }));
+
+    rewriter.replaceOp(op, retTensor);
+    return success();
   }
 };
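
Putting genToMemref and genSliceToSize together, each field lowers to a
to_memref / subview / copy sequence on the user-supplied buffer; a sketch for
the values field, mirroring the CHECK lines of the codegen test further below
(%spec, %vals, and %od stand in for the descriptor fields and the outs
operand; specifier types are elided as in the test):

```mlir
%sz   = sparse_tensor.storage_specifier.get %spec  val_mem_sz
%dst  = bufferization.to_memref %od : memref<6xf64>
%dstv = memref.subview %dst[0] [%sz] [1] : memref<6xf64> to memref<?xf64>
%srcv = memref.subview %vals[0] [%sz] [1] : memref<?xf64> to memref<?xf64>
memref.copy %srcv, %dstv : memref<?xf64> to memref<?xf64>
%rd   = bufferization.to_tensor %dst : memref<6xf64>
```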
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index 9b18394dee7e2..cf7532a522fa7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -156,6 +156,7 @@ class SparseTensorDescriptorImpl {
 
   RankedTensorType getRankedTensorType() const { return rType; }
   ValueArrayRef getFields() const { return fields; }
+  StorageLayout getLayout() const { return layout; }
 
 protected:
   SparseTensorType rType;

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3531400f75e81..c1e8afd9206ba 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -56,50 +56,38 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], posWidth=32, crdWidth=32}>
 
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
-                            -> (tensor<6xf64>, tensor<6x1xi32>, i32) {
+func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
   // expected-error at +1 {{input/output element-types don't match}}
-  %values, %coordinates, %nse = sparse_tensor.unpack %sp
-     : tensor<100xf32, #SparseVector> to tensor<6xf64>, tensor<6x1xi32>, i32
-  return %values, %coordinates, %nse : tensor<6xf64>, tensor<6x1xi32>, i32
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
-
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
-                            -> (tensor<5xf32>, tensor<6x1xi32>, i32) {
-  // expected-error at +1 {{values/coordinates number-of-elements don't match}}
-  %values, %coordinates, %nse = sparse_tensor.unpack %sp
-     : tensor<100xf32, #SparseVector> to tensor<5xf32>, tensor<6x1xi32>, i32
-  return %values, %coordinates, %nse : tensor<5xf32>, tensor<6x1xi32>, i32
+  %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100xf32, #SparseVector>
+                  outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>)
+                  -> tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+  return
 }
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed-nu", "singleton"], posWidth=32, crdWidth=32}>
 
-func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
-                            -> (tensor<6xf32>, tensor<6x2xi32>, i32) {
-  // expected-error at +1 {{input/output level-ranks don't match}}
-  %values, %coordinates, %nse = sparse_tensor.unpack %sp
-     : tensor<100xf32, #SparseVector> to tensor<6xf32>, tensor<6x2xi32>, i32
-  return %values, %coordinates, %nse : tensor<6xf32>, tensor<6x2xi32>, i32
+func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
+  // expected-error at +1 {{input/output trailing COO level-ranks don't match}}
+  %rv, %rp, %rc = sparse_tensor.unpack %sp : tensor<100x2xf64, #SparseVector>
+                  outs(%values, %pos, %coordinates : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>)
+                  -> tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32>
+  return
 }
 
 // -----
 
-#BCOO = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"], posWidth=32, crdWidth=32}>
 
-func.func @invalid_unpack_type(%sp: tensor<2x100xf32, #BCOO>)
-                            -> (tensor<2x6xf32>, tensor<3x6x2xi32>, i32) {
-  // expected-error at +1 {{values/coordinates batched level sizes don't match statically}}
-  %values, %coordinates, %nse = sparse_tensor.unpack %sp batched_lvls=1
-     : tensor<2x100xf32, #BCOO> to tensor<2x6xf32>, tensor<3x6x2xi32>, i32
-  return %values, %coordinates, %nse : tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
+  // expected-error at +1 {{inconsistent number of fields between input/output}}
+  %rv, %rc = sparse_tensor.unpack %sp : tensor<2x100xf64, #CSR>
+             outs(%values, %coordinates : tensor<6xf64>, tensor<6xi32>)
+             -> tensor<6xf64>, tensor<6xi32>
+  return
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 41cc5e775c98c..57dff1e53edc3 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -33,28 +33,20 @@ func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor
 #SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
 // CHECK-LABEL: func @sparse_unpack(
 //  CHECK-SAME: %[[T:.*]]: tensor<100xf64, #
-//       CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]]
-//       CHECK: return %[[D]], %[[I]], %[[N]]
-func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>)
-                       -> (tensor<6xf64>, tensor<6x1xi32>, i32) {
-  %data, %indices, %nnz = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
-                                                  to tensor<6xf64>, tensor<6x1xi32>, i32
-  return %data, %indices, %nnz : tensor<6xf64>, tensor<6x1xi32>, i32
-}
-
-// -----
-
-#BatchedSparseVector = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
-
-// CHECK-LABEL: func @sparse_unpack(
-//  CHECK-SAME: %[[T:.*]]: tensor<2x100xf64, #
-//       CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]] batched_lvls = 1
-//       CHECK: return %[[D]], %[[I]], %[[N]]
-func.func @sparse_unpack(%sp : tensor<2x100xf64, #BatchedSparseVector>)
-                           -> (tensor<2x6xf64>, tensor<2x6x1xi32>, i32) {
-  %data, %indices, %nnz = sparse_tensor.unpack %sp batched_lvls=1
-       : tensor<2x100xf64, #BatchedSparseVector> to tensor<2x6xf64>, tensor<2x6x1xi32>, i32
-  return %data, %indices, %nnz : tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+//  CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
+//  CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
+//  CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
+//       CHECK: %[[D:.*]], %[[P:.*]]:2 = sparse_tensor.unpack %[[T]]
+//       CHECK: return %[[D]], %[[P]]#0, %[[P]]#1
+func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
+                         %od : tensor<6xf64>,
+                         %op : tensor<2xindex>,
+                         %oi : tensor<6x1xi32>)
+                       -> (tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>) {
+  %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100xf64, #SparseVector>
+                  outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>)
+                  -> tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
+  return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x1xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 1d948cbd604fd..09ba910fc3cfc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -23,9 +23,9 @@
 // CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  lvl_sz at 0 with %[[VAL_13]]
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_12]]
 // CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_16]]
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  lvl_sz at 1 with %[[VAL_13]]
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  crd_mem_sz at 1 with %[[VAL_16]]
+// CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_13]]
 // CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  val_mem_sz with %[[VAL_16]]
 // CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
 // CHECK:         }
@@ -40,36 +40,38 @@ func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinate
 // CHECK-SAME:      %[[VAL_0:.*]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
-// CHECK-SAME:      %[[VAL_3:.*]]
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 6 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
-// CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
-// CHECK:             %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK:             scf.yield %[[VAL_9]] : memref<6xf64>
-// CHECK:           } else {
-// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
-// CHECK:             scf.yield %[[VAL_10]] : memref<6xf64>
-// CHECK:           }
-// CHECK:           %[[VAL_11:.*]] = arith.constant 12 : index
-// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
-// CHECK:           %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
-// CHECK:           %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
-// CHECK:             %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK:             scf.yield %[[VAL_15]] : memref<12xi32>
-// CHECK:           } else {
-// CHECK:             %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
-// CHECK:             scf.yield %[[VAL_16]] : memref<12xi32>
-// CHECK:           }
-// CHECK:           %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
-// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
-// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
-// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ lvlTypes = [ "compressed", "singleton" ] }>>,
+// CHECK-SAME:      %[[VAL_4:.*]]: tensor<6xf64>,
+// CHECK-SAME:      %[[VAL_5:.*]]: tensor<2xindex>,
+// CHECK-SAME:      %[[VAL_6:.*]]: tensor<6x2xi32>) -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  pos_mem_sz at 0
+// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_5]] : memref<2xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_7]]] [1] : memref<2xindex> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_7]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK:           memref.copy %[[VAL_10]], %[[VAL_9]] : memref<?xindex> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  crd_mem_sz at 0
+// CHECK:           %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_6]] : memref<6x2xi32>
+// CHECK:           %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK:           %[[VAL_14:.*]] = memref.subview %[[VAL_13]][0] {{\[}}%[[VAL_11]]] [1] : memref<12xi32> to memref<?xi32>
+// CHECK:           %[[VAL_15:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xi32> to memref<?xi32>
+// CHECK:           memref.copy %[[VAL_15]], %[[VAL_14]] : memref<?xi32> to memref<?xi32>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  val_mem_sz
+// CHECK:           %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_4]] : memref<6xf64>
+// CHECK:           %[[VAL_18:.*]] = memref.subview %[[VAL_17]][0] {{\[}}%[[VAL_16]]] [1] : memref<6xf64> to memref<?xf64>
+// CHECK:           %[[VAL_19:.*]] = memref.subview %[[VAL_2]][0] {{\[}}%[[VAL_16]]] [1] : memref<?xf64> to memref<?xf64>
+// CHECK:           memref.copy %[[VAL_19]], %[[VAL_18]] : memref<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6xf64>
+// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<2xindex>
+// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x2xi32>
+// CHECK:           return %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
 // CHECK:         }
-func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
-  %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
-                                         to tensor<6xf64>, tensor<6x2xi32>, index
-  return %d, %i, %nnz : tensor<6xf64>, tensor<6x2xi32>, index
+func.func @sparse_unpack(%sp : tensor<100x100xf64, #COO>,
+                         %od : tensor<6xf64>,
+                         %op : tensor<2xindex>,
+                         %oi : tensor<6x2xi32>)
+                       -> (tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>) {
+  %rd, %rp, %ri = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+                  outs(%od, %op, %oi : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>)
+                  -> tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
+  return %rd, %rp, %ri : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32>
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 4b44436b6da54..4c541a6b61a0f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -179,8 +179,12 @@ module {
         vector.print %v: f64
      }
 
-    %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
-                                         to tensor<3xf64>, tensor<3x2xi32>, i32
+    %od = tensor.empty() : tensor<3xf64>
+    %op = tensor.empty() : tensor<2xi32>
+    %oi = tensor.empty() : tensor<3x2xi32>
+    %d, %p, %i = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+                 outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
+                 -> tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
 
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -190,30 +194,22 @@ module {
     %vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
     vector.print %vi : vector<3x2xi32>
 
-    // CHECK-NEXT: 3
-    vector.print %n : i32
 
+    %bod = tensor.empty() : tensor<6xf64>
+    %bop = tensor.empty() : tensor<4xindex>
+    %boi = tensor.empty() : tensor<6x2xindex>
+    %bd, %bp, %bi = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
+                    outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
+                    -> tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>
 
-    %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls=1 :
-       tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
+    %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
+    vector.print %vbd : vector<6xf64>
 
-    // CHECK-NEXT: ( ( 1, 2, 3 ), ( 4, 5, 0 ) )
-    %vbd = vector.transfer_read %bd[%c0, %c0], %f0 : tensor<2x3xf64>, vector<2x3xf64>
-    vector.print %vbd : vector<2x3xf64>
+    // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
+    %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
+    vector.print %vbi : vector<6x2xindex>
 
-    // CHECK-NEXT: ( ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) ), ( ( 2, 3 ), ( 4, 2 ), ( 0, 0 ) ) )
-    %vbi = vector.transfer_read %bi[%c0, %c0, %c0], %c0 : tensor<2x3x2xindex>, vector<2x3x2xindex>
-    vector.print %vbi : vector<2x3x2xindex>
-
-    // CHECK-NEXT: 3
-    vector.print %bn : i32
-
-    %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
-                                         to tensor<3xf64>, tensor<3x2xindex>, index
-
-    // FIXME: This should be freed by one-shot-bufferization.
-    bufferization.dealloc_tensor %bd : tensor<2x3xf64>
-    bufferization.dealloc_tensor %bi : tensor<2x3x2xindex>
     return
   }
 }
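
When the number of stored entries is not known statically, the destination
buffers could in principle be sized at runtime before unpacking; a
hypothetical sketch (not part of this patch) using
sparse_tensor.number_of_entries:

```mlir
%nse = sparse_tensor.number_of_entries %s5 : tensor<10x10xf64, #SortedCOOI32>
%od  = tensor.empty(%nse) : tensor<?xf64>
```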
