[Mlir-commits] [mlir] d4db528 - [mlir][sparse] extend unpack operation to support unpacking a batched COO type

Peiming Liu llvmlistbot at llvm.org
Mon May 1 11:17:36 PDT 2023


Author: Peiming Liu
Date: 2023-05-01T18:17:29Z
New Revision: d4db52893857a836940e0951daa205de1bb1d201

URL: https://github.com/llvm/llvm-project/commit/d4db52893857a836940e0951daa205de1bb1d201
DIFF: https://github.com/llvm/llvm-project/commit/d4db52893857a836940e0951daa205de1bb1d201.diff

LOG: [mlir][sparse] extend unpack operation to support unpacking a batched COO type

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149103

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir
    mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index eea58f91b583c..f29ea600e3347 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -124,9 +124,10 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
 }
 
 def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
-    Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs 1DTensorOf<[AnyType]>:$values,
-                  2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+    Arguments<(ins AnySparseTensor:$tensor,
+                   OptionalAttr<IndexAttr>:$batched_lvls)>,
+    Results<(outs TensorOf<[AnyType]>:$values,
+                  TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
                   AnySignlessIntegerOrIndex:$nse)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
@@ -159,11 +160,44 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
     // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
     // %nse = 3
     ```
+
+    If `batched_lvls` is provided, the operation unpacks each batch of the tensor
+    separately. The returned `nse` is the maximum nse over all batches; batches
+    with a smaller nse are padded with trailing zeros in the result.
+    Example:
+
+    ```mlir
+    // input BCOO format |1.1, 2.2, 3.3, 0.0|
+    //      of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
+    %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
+        : tensor<2x4xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x1xindex>, index
+    // %values      = arith.constant dense<[[ 1.1,   2.2,   3.3 ],
+    //                                      [ 1.2,   2.3,   0.0 ]]> : tensor<2x3xf64>
+    // %coordinates = arith.constant dense<[[ [0],   [1],   [2] ],
+    //                                      [ [1],   [2],   [0] ]]> : tensor<2x3x1xindex>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns the number of leading levels that are batched.
+    unsigned getNumBatchedLvls();
   }];
 
+  let builders = [
+    OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value": $tensor),
+    [{
+      build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value": $tensor),
+    [{
+      build($_builder, $_state, resultTypes, tensor, nullptr);
+    }]>
+  ];
+
+
   let assemblyFormat =
-    "$tensor attr-dict `:` type($tensor)"
-    "`to` type($values) `,` type($coordinates) `,` type($nse)";
+    "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
+    "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";
 
   let hasVerifier = 1;
 }

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b2353016079b7..42776c7d80a32 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -719,7 +719,11 @@ LogicalResult UnpackOp::verify() {
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
-                          nullptr);
+                          getBatchedLvlsAttr());
+}
+
+unsigned UnpackOp::getNumBatchedLvls() {
+  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
 }
 
 LogicalResult ConvertOp::verify() {

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 8a8b2eda5175d..f17c001308bf0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -153,9 +153,12 @@ struct UnpackOpInterface
     : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
                                                     sparse_tensor::UnpackOp> {
   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
-    // Similar to InsertOp, reallocation is not considered to allocate a new
-    // piece of memory.
-    return false;
+    // We allocate and return unpacked memory if this is a batched unpack.
+    // When the number of batched levels is zero, we reuse the
+    // coordinates/values memrefs (reallocating them if the requested output
+    // size is larger than the actual size). Similar to InsertOp, reallocation
+    // is not considered to allocate a new piece of memory.
+    return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
   }
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 3a488b311b95a..9aae52db873a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -213,6 +213,18 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
+Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
+                                  Value s) {
+  Value load = builder.create<memref::LoadOp>(loc, mem, s);
+  if (!load.getType().isa<IndexType>()) {
+    if (load.getType().getIntOrFloatBitWidth() < 64)
+      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+    load =
+        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+  }
+  return load;
+}
+
 mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);

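For reference, the `genIndexLoad` helper (moved out of LoopEmitter.cpp so it can be shared with the new unpack lowering) zero-extends narrow position/coordinate loads before casting to `index`. A minimal sketch of the IR it produces for a 32-bit coordinate buffer; the memref and SSA names below are illustrative only:

```mlir
// Load a 32-bit coordinate or position entry.
%c = memref.load %mem[%s] : memref<?xi32>
// Zero-extend sub-64-bit values first ...
%c64 = arith.extui %c : i32 to i64
// ... then cast into the index type used for looping and indexing.
%idx = arith.index_cast %c64 : i64 to index
```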
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b6e6def4e5860..3e1d0b00f825b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -75,6 +75,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 
+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index afa4828bf170a..ba6b4641408a5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -41,25 +41,6 @@ using namespace mlir::sparse_tensor;
 // File local helper functions.
 //===----------------------------------------------------------------------===//
 
-/// Generates a pointer/index load from the sparse storage scheme. Narrower
-/// data types need to be zero extended before casting the value into the
-/// index type used for looping and indexing.
-static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
-                          Value s) {
-  // For the scalar case, we simply zero extend narrower indices into 64-bit
-  // values before casting to index without a performance penalty. Here too,
-  // however, indices that already are 64-bit, in theory, cannot express the
-  // full range as explained above.
-  Value load = builder.create<memref::LoadOp>(loc, mem, s);
-  if (!load.getType().isa<IndexType>()) {
-    if (load.getType().getIntOrFloatBitWidth() < 64)
-      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
-    load =
-        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
-  }
-  return load;
-}
-
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
@@ -707,7 +688,8 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
       continue;
     }
 
-    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType);
+    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+                    isCompressedWithHiDLT(lvlType);
     // We can at most have one sparse input, otherwise, a while loop is
     // required to co-iterate multiple sparse tensors.
     assert(!isSparseCond || !isSparse);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index c1cb0926622f6..4b94392b3d19c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -602,6 +602,25 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
   return ifOp.getResult(0);
 }
 
+static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange bounds) {
+  assert(ivs.size() == bounds.size());
+  Value crd = constantIndex(builder, loc, 0);
+  for (unsigned i = 0, e = ivs.size(); i < e; i++) {
+    crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
+    if (i != ivs.size() - 1)
+      crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
+  }
+  return crd;
+}
+
+ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+  ReassociationIndices reassociation;
+  for (int i = 0, e = srcTp.getRank(); i < e; i++)
+    reassociation.push_back(i);
+  return reassociation;
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1252,12 +1271,7 @@ static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
       [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
                                                ValueRange ivs) {
         // Linearize index variables
-        Value crd = constantIndex(builder, loc, 0);
-        for (unsigned i = 0, e = ivs.size(); i < e; i++) {
-          crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
-          if (i != ivs.size() - 1)
-            crd = builder.create<arith::MulIOp>(loc, crd, ubs[i + 1]);
-        }
+        Value crd = linearize(builder, loc, ivs, ubs);
         Value len = constantIndex(builder, loc, nse);
         Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
         SmallVector<Value> indices(ivs.begin(), ivs.end());
@@ -1420,6 +1434,166 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   }
 };
 
+static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
+                                          SparseTensorDescriptor desc,
+                                          ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  const auto srcTp = getSparseTensorType(op.getTensor());
+  const Level lvlRank = srcTp.getLvlRank();
+  Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
+                               : desc.getAOSMemRef();
+  Value valuesBuf = desc.getValMemRef();
+
+  // If the frontend requests a static buffer, reallocate the
+  // values/coordinates buffers to satisfy that requirement.
+  const auto valuesTp = getRankedTensorType(op.getValues());
+  if (valuesTp.hasStaticShape()) {
+    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+    // tensor that is packed from constants.
+    valuesBuf =
+        reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
+  }
+
+  const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
+  if (coordinatesTp.hasStaticShape()) {
+    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+    // tensor that is packed from constants.
+    auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
+    flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
+  }
+
+  Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
+      loc,
+      MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
+      flatBuf, ArrayRef{ReassociationIndices{0, 1}});
+
+  // Converts MemRefs back to Tensors.
+  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
+  Value coordinates =
+      rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
+  Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                      op.getNse().getType());
+
+  rewriter.replaceOp(op, {values, coordinates, nse});
+  return success();
+}
+
+static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
+                                        SparseTensorDescriptor desc,
+                                        ConversionPatternRewriter &rewriter) {
+  assert(nBatched != 0);
+  Location loc = op.getLoc();
+  Value c0 = constantIndex(rewriter, loc, 0);
+  Value c1 = constantIndex(rewriter, loc, 1);
+  Value c2 = constantIndex(rewriter, loc, 2);
+
+  auto genZeroedAlloc = [loc,
+                         &rewriter](TensorType tt) -> TypedValue<MemRefType> {
+    auto mem = rewriter
+                   .create<memref::AllocOp>(
+                       loc, MemRefType::get(tt.getShape(), tt.getElementType()))
+                   .getMemref();
+    // TODO: Instead of filling the entire buffer, we could fill only the
+    // trailing zeros.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
+    return mem;
+  };
+  SparseTensorType stt = getSparseTensorType(op.getTensor());
+  TensorType valTensorTp = op.getValues().getType();
+  TensorType crdTensorTp = op.getCoordinates().getType();
+  TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
+  TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
+  assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
+
+  SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
+  SmallVector<Value> ubs;
+  for (unsigned i = 0; i < nBatched; i++) {
+    assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+    ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
+  }
+
+  DimLevelType dlt = stt.getLvlType(nBatched);
+  assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
+  Value posStep = isCompressedDLT(dlt) ? c1  // forward position index by 1
+                                       : c2; // forward position index by 2
+  auto loopNest = scf::buildLoopNest(
+      rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
+      [&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
+       &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
+                   ValueRange args) -> scf::ValueVector {
+        // crdMemref has shape: <... x nse x rank>
+        unsigned unBatchedRank = crdMemref.getType().getShape().back();
+        Value values = desc.getValMemRef();
+        Value flatCrds = unBatchedRank == 1
+                             ? desc.getCrdMemRefOrView(builder, loc, 0)
+                             : desc.getAOSMemRef();
+
+        Value positions = desc.getPosMemRef(nBatched);
+        Value positLo = builder.create<arith::MulIOp>(
+            loc, linearize(builder, loc, ivs, ubs), posStep);
+        Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
+
+        Value pLo = genIndexLoad(builder, loc, positions, positLo);
+        Value pHi = genIndexLoad(builder, loc, positions, positHi);
+        Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
+
+        Value crdLo = builder.create<arith::MulIOp>(
+            loc, pLo, constantIndex(builder, loc, unBatchedRank));
+        Value nCrd = builder.create<arith::MulIOp>(
+            loc, nse, constantIndex(builder, loc, unBatchedRank));
+
+        SmallVector<Value> offsets, sizes, strides;
+        for (unsigned i = 0; i < nBatched; i++) {
+          offsets.push_back(ivs[i]);
+          sizes.push_back(c1);
+          strides.push_back(c1);
+        }
+        // [0, nse, 1].
+        offsets.push_back(c0);
+        sizes.push_back(nse);
+        strides.push_back(c1);
+
+        auto valView = builder.create<memref::SubViewOp>(
+            loc, valMemref, offsets, sizes, strides);
+        auto valReass = getReassociationForFlattening(valView.getType());
+        Value valDst =
+            builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
+        Value valSrc =
+            builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
+        builder.create<memref::CopyOp>(loc, valSrc, valDst);
+
+        // [0, rank, 1].
+        offsets.push_back(c0);
+        sizes.push_back(constantIndex(builder, loc, unBatchedRank));
+        strides.push_back(c1);
+
+        auto crdView = builder.create<memref::SubViewOp>(
+            loc, crdMemref, offsets, sizes, strides);
+        auto crdReass = getReassociationForFlattening(crdView.getType());
+        Value crdDst =
+            builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
+        Value crdSrc =
+            builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
+        builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
+
+        Value pred = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ugt, nse, args[0]);
+        // Choose the larger NSE
+        return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
+      });
+
+  // Converts MemRefs back to Tensors.
+  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
+  Value coordinates =
+      rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
+  Value nse =
+      genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
+
+  rewriter.replaceOp(op, {values, coordinates, nse});
+  return success();
+}
+
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
   SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
@@ -1431,52 +1605,26 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    Location loc = op.getLoc();
     const auto srcTp = getSparseTensorType(op.getTensor());
-    const Level lvlRank = srcTp.getLvlRank();
+    const unsigned nBatched = op.getNumBatchedLvls();
+    assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
+           desc.getFields().size() == 4); // specifier + pos + crds + values
+    auto logicRes = nBatched == 0
+                        ? genUnBatchedUnpackOp(op, desc, rewriter)
+                        : genBatchedUnpackOp(op, nBatched, desc, rewriter);
+    Value posBuf = desc.getPosMemRef(nBatched);
 
-    assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4);
-
-    Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
-                                 : desc.getAOSMemRef();
-    Value valuesBuf = desc.getValMemRef();
-    Value posBuf = desc.getPosMemRef(0);
     if (createDeallocs) {
       // Unpack ends the lifetime of the sparse tensor. While the value array
       // and coordinate array are unpacked and returned, the position array
       // becomes useless and need to be freed (if user requests).
-      rewriter.create<memref::DeallocOp>(loc, posBuf);
-    }
-
-    // If frontend requests a static buffer, we reallocate the
-    // values/coordinates to ensure that we meet their need.
-    const auto valuesTp = getRankedTensorType(op.getValues());
-    if (valuesTp.hasStaticShape()) {
-      valuesBuf =
-          reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
-    }
-
-    const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
-    if (coordinatesTp.hasStaticShape()) {
-      auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
-      flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
+      // FIXME: Depending on whether the tensor being unpacked was created by
+      // PackOp, we may or may not need to free the other memref fields of
+      // the sparse tensor as well (PackOp borrows the value/coordinate buffers).
+      rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
     }
 
-    Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
-        loc,
-        MemRefType::get(coordinatesTp.getShape(),
-                        coordinatesTp.getElementType()),
-        flatBuf, ArrayRef{ReassociationIndices{0, 1}});
-
-    // Converts MemRefs back to Tensors.
-    Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
-    Value coordinates =
-        rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
-    Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
-                        op.getNse().getType());
-
-    rewriter.replaceOp(op, {values, coordinates, nse});
-    return success();
+    return logicRes;
   }
 
 private:

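The new `linearize` helper folds the batch induction variables into a single row-major offset into the position buffer; with induction variables `%i, %j` and bounds `%b0, %b1` it computes `i * b1 + j` (the leading bound is never multiplied in). A minimal sketch of the emitted arithmetic, with illustrative SSA names:

```mlir
%c0  = arith.constant 0 : index
%t0  = arith.addi %c0, %i : index   // crd = 0 + i
%t1  = arith.muli %t0, %b1 : index  // crd = i * b1
%lin = arith.addi %t1, %j : index   // crd = i * b1 + j
```

`genBatchedUnpackOp` then scales this offset by `posStep` (1 for `compressed`, 2 for `compressed-hi`) to locate the (lo, hi) position pair of the batch being copied out.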
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b6f43adaf399c..0766e906c7216 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -128,6 +128,18 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
 
 // -----
 
+#BCOO = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+
+func.func @invalid_unpack_type(%sp: tensor<2x100xf32, #BCOO>)
+                            -> (tensor<2x6xf32>, tensor<3x6x2xi32>, i32) {
+  // expected-error@+1 {{values/coordinates batched level sizes don't match statically}}
+  %values, %coordinates, %nse = sparse_tensor.unpack %sp batched_lvls=1
+     : tensor<2x100xf32, #BCOO> to tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+  return %values, %coordinates, %nse : tensor<2x6xf32>, tensor<3x6x2xi32>, i32
+}
+
+// -----
+
 func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref<?xindex>

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e3e548c993714..3bfa7c2164494 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -59,6 +59,21 @@ func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>)
 
 // -----
 
+#BatchedSparseVector = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+
+// CHECK-LABEL: func @sparse_unpack(
+//  CHECK-SAME: %[[T:.*]]: tensor<2x100xf64, #
+//       CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]] batched_lvls = 1
+//       CHECK: return %[[D]], %[[I]], %[[N]]
+func.func @sparse_unpack(%sp : tensor<2x100xf64, #BatchedSparseVector>)
+                           -> (tensor<2x6xf64>, tensor<2x6x1xi32>, i32) {
+  %data, %indices, %nnz = sparse_tensor.unpack %sp batched_lvls=1
+       : tensor<2x100xf64, #BatchedSparseVector> to tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+  return %data, %indices, %nnz : tensor<2x6xf64>, tensor<2x6x1xi32>, i32
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_dealloc(

diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 42f2f1c35c5b4..58dc1e49dcf98 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -603,19 +603,19 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
   dimLevelType = [ "dense", "compressed-hi" ],
 }>
 // CHECK-LABEL:   func.func @sub_ss_batched(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x3xf64, #{{.*}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<2x3xf64, #{{.*}}>>) -> tensor<2x3xf64, #{{.*}}>> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #{{.*}}>>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #{{.*}}>> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #{{.*}}>> to memref<?xf64>
+// CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #{{.*}}>>) {
 // CHECK:             %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
 // CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_4]] : index
@@ -628,9 +628,9 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK:               %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_18]] : index
 // CHECK:               %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_22]] : index
 // CHECK:               %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
-// CHECK:               scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #{{.*}}>>
 // CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>):
+// CHECK:             ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #{{.*}}>>):
 // CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_30]]] : memref<?xindex>
 // CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xindex>
 // CHECK:               %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_33]] : index
@@ -638,31 +638,31 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK:               %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
 // CHECK:               %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
 // CHECK:               %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1
-// CHECK:               %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:               %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #{{.*}}>>) {
 // CHECK:                 %[[VAL_41:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
 // CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
 // CHECK:                 %[[VAL_43:.*]] = arith.subf %[[VAL_41]], %[[VAL_42]] : f64
-// CHECK:                 %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:                 scf.yield %[[VAL_44]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                 %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK:                 scf.yield %[[VAL_44]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:               } else {
 // CHECK:                 %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
-// CHECK:                 %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:                 %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #{{.*}}>>) {
 // CHECK:                   %[[VAL_47:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xf64>
-// CHECK:                   %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:                   scf.yield %[[VAL_48]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                   %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK:                   scf.yield %[[VAL_48]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:                 } else {
 // CHECK:                   %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
-// CHECK:                   %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:                   %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #{{.*}}>>) {
 // CHECK:                     %[[VAL_51:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<?xf64>
 // CHECK:                     %[[VAL_52:.*]] = arith.negf %[[VAL_51]] : f64
-// CHECK:                     %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:                     scf.yield %[[VAL_53]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                     %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK:                     scf.yield %[[VAL_53]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:                   } else {
-// CHECK:                     scf.yield %[[VAL_32]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                     scf.yield %[[VAL_32]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:                   }
-// CHECK:                   scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                   scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                 scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:               }
 // CHECK:               %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index
 // CHECK:               %[[VAL_57:.*]] = arith.addi %[[VAL_30]], %[[VAL_4]] : index
@@ -670,25 +670,25 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
 // CHECK:               %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index
 // CHECK:               %[[VAL_60:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] : index
 // CHECK:               %[[VAL_61:.*]] = arith.select %[[VAL_59]], %[[VAL_60]], %[[VAL_31]] : index
-// CHECK:               scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #{{.*}}>>
 // CHECK:             } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_3]] to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_65:.*]] = %[[VAL_66:.*]]#2)
+// CHECK:             %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_65:.*]]#0 to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_66:.*]] = %[[VAL_65]]#2)
 // CHECK:               %[[VAL_67:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_64]]] : memref<?xindex>
 // CHECK:               %[[VAL_68:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_64]]] : memref<?xf64>
-// CHECK:               %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_65]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:               scf.yield %[[VAL_69]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_66]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK:               scf.yield %[[VAL_69]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_3]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_72:.*]] = %[[VAL_73:.*]])
-// CHECK:               %[[VAL_74:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref<?xindex>
-// CHECK:               %[[VAL_75:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref<?xf64>
-// CHECK:               %[[VAL_76:.*]] = arith.negf %[[VAL_75]] : f64
-// CHECK:               %[[VAL_77:.*]] = sparse_tensor.insert %[[VAL_76]] into %[[VAL_72]]{{\[}}%[[VAL_13]], %[[VAL_74]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:               scf.yield %[[VAL_77]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:             %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_72:.*]]#1 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_73:.*]] = %[[VAL_74:.*]])
+// CHECK:               %[[VAL_75:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref<?xindex>
+// CHECK:               %[[VAL_76:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref<?xf64>
+// CHECK:               %[[VAL_77:.*]] = arith.negf %[[VAL_76]] : f64
+// CHECK:               %[[VAL_78:.*]] = sparse_tensor.insert %[[VAL_77]] into %[[VAL_73]]{{\[}}%[[VAL_13]], %[[VAL_75]]] : tensor<2x3xf64, #{{.*}}>>
+// CHECK:               scf.yield %[[VAL_78]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             scf.yield %[[VAL_78:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:             scf.yield %[[VAL_79:.*]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:           return %[[VAL_79]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_80:.*]] = sparse_tensor.load %[[VAL_81:.*]] hasInserts : tensor<2x3xf64, #{{.*}}>>
+// CHECK:           return %[[VAL_80]] : tensor<2x3xf64, #{{.*}}>>
 // CHECK:         }
 func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf64, #BatchedVector>)
                            -> tensor<2x3xf64, #BatchedVector> {

diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index 57013e7715c43..3d95c86f4aa12 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -145,23 +145,25 @@ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
 }>
 
 // CHECK-LABEL:   func.func @foreach_bcoo(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4x4xf64, #{{.*}}>>) {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 4 : index
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #{{.*}}>> to memref<?xf64>
 // CHECK:           scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
 // CHECK:             %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
-// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK:               %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK:               "test.use"(%[[VAL_12]]) : (f64) -> ()
-// CHECK:             }
-// CHECK:           }
+// CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:             %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] {
+// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// CHECK:               "test.use"(%[[VAL_13]]) : (f64) -> ()
+// CHECK:             } {"Emitted from" = "sparse_tensor.foreach"}
+// CHECK:           } {"Emitted from" = "sparse_tensor.foreach"}
 // CHECK:           return
+// CHECK:         }
 func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) {
   sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do {
   ^bb0(%1: index, %2: index, %3: index,  %v: f64) :

diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4648cb3bf2983..fb0d4a73068d9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -45,7 +45,6 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK-SAME:      %[[VAL_3:.*]]
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 6 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG:       memref.dealloc %[[VAL_0]] : memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
 // CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
 // CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
@@ -69,6 +68,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
 // CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
 // CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
+// CHECK:           memref.dealloc %[[VAL_0]] : memref<?xindex>
 // CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 2b86d566ec4fd..34f0188a92720 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -31,6 +31,10 @@
   crdWidth = 32
 }>
 
+#BCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+}>
+
 module {
   //
   // Main driver.
@@ -60,6 +64,25 @@ module {
 
     %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
                                           to tensor<10x10xf64, #SortedCOO>
+    %s5 = sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
+                                           to tensor<10x10xf64, #SortedCOOI32>
+
+    %bdata = arith.constant dense<
+       [[  1.0,  2.0,  3.0],
+        [  4.0,  5.0,  0.0]]
+    > : tensor<2x3xf64>
+
+    %bindex = arith.constant dense<
+      [[[  1,  2],
+        [  5,  6],
+        [  7,  8]],
+       [[  2,  3],
+        [  4,  2],
+        [ 10, 10]]]
+    > : tensor<2x3x2xindex>
+    %bs = sparse_tensor.pack %bdata, %bindex batched_lvls = 1 :
+          tensor<2x3xf64>, tensor<2x3x2xindex> to tensor<2x10x10xf64, #BCOO>
+
     // CHECK:1
     // CHECK-NEXT:2
     // CHECK-NEXT:1
@@ -78,8 +101,6 @@ module {
         vector.print %v: f64
      }
 
-    %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
-                                          to tensor<10x10xf64, #SortedCOOI32>
     // CHECK-NEXT:1
     // CHECK-NEXT:2
     // CHECK-NEXT:1
@@ -98,11 +119,23 @@ module {
         vector.print %v: f64
      }
 
+    // CHECK-NEXT:1
+    // CHECK-NEXT:2
+    // CHECK-NEXT:3
+    //
+    // CHECK-NEXT:4
+    // CHECK-NEXT:5
+    //
+    // Make sure the trailing zeros are not traversed.
+    // CHECK-NOT: 0
+    sparse_tensor.foreach in %bs : tensor<2x10x10xf64, #BCOO> do {
+      ^bb0(%0: index, %1: index, %2: index, %v: f64) :
+        vector.print %v: f64
+     }
+
     %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                                          to tensor<3xf64>, tensor<3x2xi32>, i32
 
-
-
     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
     vector.print %vd : vector<3xf64>
@@ -114,8 +147,26 @@ module {
     // CHECK-NEXT: 3
     vector.print %n : i32
 
+
+    %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls=1 :
+       tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32
+
+    // CHECK-NEXT: ( ( 1, 2, 3 ), ( 4, 5, 0 ) )
+    %vbd = vector.transfer_read %bd[%c0, %c0], %f0 : tensor<2x3xf64>, vector<2x3xf64>
+    vector.print %vbd : vector<2x3xf64>
+
+    // CHECK-NEXT: ( ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) ), ( ( 2, 3 ), ( 4, 2 ), ( 0, 0 ) ) )
+    %vbi = vector.transfer_read %bi[%c0, %c0, %c0], %c0 : tensor<2x3x2xindex>, vector<2x3x2xindex>
+    vector.print %vbi : vector<2x3x2xindex>
+
+    // CHECK-NEXT: 3
+    vector.print %bn : i32
+
     %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
                                          to tensor<3xf64>, tensor<3x2xindex>, index
+    // FIXME: This should be freed by one-shot-bufferization.
+    bufferization.dealloc_tensor %bd : tensor<2x3xf64>
+    bufferization.dealloc_tensor %bi : tensor<2x3x2xindex>
     return
   }
 }