[Mlir-commits] [mlir] de56088 - [mlir][sparse] Support packing external data into arbitrary sparse tensor encoding.

Peiming Liu llvmlistbot at llvm.org
Fri May 19 10:41:55 PDT 2023


Author: Peiming Liu
Date: 2023-05-19T17:41:49Z
New Revision: de56088866be46ec6644c2c6c5f3d90b1ceda31c

URL: https://github.com/llvm/llvm-project/commit/de56088866be46ec6644c2c6c5f3d90b1ceda31c
DIFF: https://github.com/llvm/llvm-project/commit/de56088866be46ec6644c2c6c5f3d90b1ceda31c.diff

LOG: [mlir][sparse] Support packing external data into arbitrary sparse tensor encoding.

We previously only supported packing two arrays (values and coordinates) into COO tensors.
This patch allows packing inputs into an arbitrary sparse tensor format.

It also deletes the "implicit" data canonicalization performed inside the sparse compiler;
users are now required to canonicalize the data before passing it to the sparse compiler.
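
For illustration, below is a minimal sketch of the new usage, adapted from the
integration test updated in this patch (the #CSR encoding and the constant data
come from sparse_pack.mlir; the SSA names are illustrative). The per-level
position/coordinate arrays must follow the storage layout of the target
encoding and are assumed to already be canonical (sorted, deduplicated):

```mlir
#CSR = #sparse_tensor.encoding<{
  lvlTypes = [ "dense", "compressed" ],
  posWidth = 32,
  crdWidth = 32
}>

// Pre-canonicalized input; the compiler no longer canonicalizes this data
// on behalf of the caller.
%values = arith.constant dense<[ 1.0, 2.0, 3.0 ]> : tensor<3xf64>
%pos    = arith.constant dense<[ 0, 1, 3 ]> : tensor<3xi32>
%crd    = arith.constant dense<[ 1, 0, 1 ]> : tensor<3xi32>

// Packs one position array and one coordinate array (per the CSR storage
// layout) together with the values into a 2x2 CSR tensor.
%csr = sparse_tensor.pack %values, %pos, %crd
     : tensor<3xf64>, tensor<3xi32>, tensor<3xi32> to tensor<2x2xf64, #CSR>
```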

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150916

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 72c1da8714b54..41fd19e71cb76 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -110,6 +110,16 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
+/// Convenience methods to query whether a given DLT needs both a position
+/// and a coordinates array, or only a coordinates array.
+constexpr inline bool isDLTWithPos(DimLevelType dlt) {
+  return isCompressedWithHiDLT(dlt) || isCompressedDLT(dlt);
+}
+constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
+  return isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt) ||
+         isCompressedDLT(dlt);
+}
+
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 180bd8bfd1f52..954c9cb49f8c4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,34 +55,32 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 
 def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
     Arguments<(ins TensorOf<[AnyType]>:$values,
-                   TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
-                   OptionalAttr<IndexAttr>:$batched_lvls)>,
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
     Results<(outs AnySparseTensor: $result)> {
-  let summary = "Returns a sparse tensor from the given (values, coordinates) pair";
+  let summary = "Returns a sparse tensor from the given values, levels";
 
   let description = [{
-    Packs the values/coordinates into a COO sparse tensor.  The length
-    of `values` must match the outer-length of `coordinates`, since these
-    two tensors are "zipped" together.  The `coordinates` argument provides
-    level-coords for each value, therefore, the inner-length of `coordinates`
-    must match the level-rank of the returned tensor, and each level-coords
-    must be valid for the level-sizes of the returned tensor.  Note that
-    the returned tensor must be statically shaped because it is impossible
-    to infer the dimension-shape from level-coordinates alone.
+    Packs the values and per-level coordinate or position arrays into a sparse tensor.
+    The order and types of the provided levels must be consistent with the actual
+    storage layout of the returned sparse tensor, as described below:
+
+    - `values : tensor<? x V>`
+      supplies the value for each stored element in the sparse tensor.
+    - `levels: [tensor<? x iType>, ...]`
+      each supplies the coordinate or position scheme of the corresponding level
+      in the sparse tensor, as specified by `sparse_tensor::StorageLayout`.
+
+    This operation can be used to materialize a sparse tensor from external
+    sources; e.g., when passing numpy arrays from Python.
+
+    Disclaimer: It is the user's responsibility to provide input that can be
+    correctly interpreted by the sparse compiler, which does not perform
+    any runtime sanity checks to verify data integrity.
 
     TODO: The returned tensor is allowed (in principle) to have non-identity
     dimOrdering/higherOrdering mappings.  However, the current implementation
     does not yet support them.
 
-    - `coordinates : tensor<NSE x lvlRank x iType>`
-      supplies the level-coords for each element in `values`.
-    - `values : tensor<NSE x V>`
-      supplies the corresponding values for each entry in `coordinates`.
-    - `batched_lvls : optional<index>`
-      supplies the number of leading levels that are batched.
-
-    This operation can be used to materialize a sparse tensor from external
-    sources; e.g., when passing two numpy arrays from Python.
 
     Example:
 
@@ -95,30 +93,12 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
     //     of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
     //                   |0.0, 0.0, 0.0, 0.0|
     ```
-
-    If `batched_lvls` is provided, the operation materializes a batched sparse tensor.
-    Example:
-
-    ```mlir
-    %values      = arith.constant dense<[[ 1.1,   2.2,   3.3 ],
-                                         [ 1.2,   2.3,   0.0 ]]> : tensor<2x3xf64>
-    %coordinates = arith.constant dense<[[ [0],   [1],   [2] ],
-                                         [ [1],   [2],   [3] ]> : tensor<2x3x1xindex>
-    %st = sparse_tensor.pack %values, %coordinates batched_lvls=1
-        : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
-    // yields BCOO format |1.1, 2.2, 3.3, 0.0|
-    //      of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
     ```
   }];
 
-  let extraClassDeclaration = [{
-    /// Returns the number of leading levels that are batched.
-    unsigned getNumBatchedLvls();
-  }];
-
   let assemblyFormat =
-    "$values `,` $coordinates (`batched_lvls` `=` $batched_lvls^)? attr-dict"
-    "`:` type($values) `,` type($coordinates) `to` type($result)";
+    "$values `,` $levels attr-dict"
+    "`:` type($values) `,` type($levels) `to` type($result)";
 
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16f270fc464ee..e3f8d4cb9f96a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -66,13 +66,11 @@ void StorageLayout::foreachField(
   // Per-level storage.
   for (Level l = 0; l < end; l++) {
     const auto dlt = lvlTypes[l];
-    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
+    if (isDLTWithPos(dlt)) {
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
+    }
+    if (isDLTWithCrd(dlt)) {
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
-    } else if (isSingletonDLT(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
-    } else {
-      assert(isDenseDLT(dlt)); // no fields
     }
   }
   // The values array.
@@ -786,6 +784,8 @@ static LogicalResult verifySparsifierGetterSetter(
   return success();
 }
 
+// DEPRECATED: Remove this function once the unpack operation supports
+// arbitrary sparse encodings.
 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                       SparseTensorType tensorTp,
                                       RankedTensorType valuesTp,
@@ -843,16 +843,83 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
   return success();
 }
 
+static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
+  switch (kind) {
+  case SparseTensorFieldKind::CrdMemRef:
+    return stt.getCrdType();
+  case SparseTensorFieldKind::PosMemRef:
+    return stt.getPosType();
+  case SparseTensorFieldKind::ValMemRef:
+    return stt.getElementType();
+  case SparseTensorFieldKind::StorageSpec:
+    return nullptr;
+  }
+  llvm_unreachable("Unrecognizable FieldKind");
+}
+
+static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
+                                      SparseTensorType stt,
+                                      RankedTensorType valTp,
+                                      TypeRange lvlTps) {
+  if (requiresStaticShape && !stt.hasStaticDimShape())
+    return op->emitError("the sparse-tensor must have static shape");
+  if (!stt.hasEncoding())
+    return op->emitError("the sparse-tensor must have an encoding attribute");
+  if (!stt.isIdentity())
+    return op->emitError("the sparse-tensor must have the identity mapping");
+
+  // Verifies the trailing COO.
+  Level cooStartLvl = getCOOStart(stt.getEncoding());
+  if (cooStartLvl < stt.getLvlRank()) {
+    // We only support trailing COO for now; it must be the last input.
+    auto cooTp = lvlTps.back().cast<ShapedType>();
+    // The coordinates should be in the shape of <? x rank>.
+    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
+    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
+      return op->emitError("input/output trailing COO level-ranks don't match");
+    }
+  }
+
+  // Verifies that all types match.
+  StorageLayout layout(stt.getEncoding());
+  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
+    return op->emitError("inconsistent number of fields between input/output");
+
+  unsigned idx = 0;
+  bool misMatch = false;
+  layout.foreachField([&idx, &misMatch, stt, valTp,
+                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
+                               Level lvl, DimLevelType dlt) -> bool {
+    if (fKind == SparseTensorFieldKind::StorageSpec)
+      return true;
+
+    Type inputTp = nullptr;
+    if (fKind == SparseTensorFieldKind::ValMemRef) {
+      inputTp = valTp;
+    } else {
+      assert(fid == idx && stt.getLvlType(lvl) == dlt);
+      inputTp = lvlTps[idx++];
+    }
+    // The input element type and expected element type should match.
+    Type inpElemTp = inputTp.cast<TensorType>().getElementType();
+    Type expElemTp = getFieldElemType(stt, fKind);
+    if (inpElemTp != expElemTp) {
+      misMatch = true;
+      return false; // to terminate the iteration
+    }
+    return true;
+  });
+
+  if (misMatch)
+    return op->emitError("input/output element-types don't match");
+  return success();
+}
+
 LogicalResult PackOp::verify() {
   const auto valuesTp = getRankedTensorType(getValues());
-  const auto coordinatesTp = getRankedTensorType(getCoordinates());
+  const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
-  return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp,
-                          getBatchedLvlsAttr());
-}
-
-unsigned PackOp::getNumBatchedLvls() {
-  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
+  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
 }
 
 LogicalResult UnpackOp::verify() {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6d747c00910fc..4f2e18f43c117 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1214,192 +1214,94 @@ class SparseNumberOfEntriesConverter
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values.
+    // FIXME: the nse value computed in this way might be wrong when there is
+    // any "compressed-hi" level.
     rewriter.replaceOp(
         op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
     return success();
   }
 };
 
-static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
-                                             ArrayRef<unsigned> batchDimSzs,
-                                             Value posMemRef, unsigned nse,
-                                             PackOp op) {
-  SmallVector<Value> lbs, ubs, steps;
-  Value c0 = constantIndex(builder, loc, 0);
-  Value c1 = constantIndex(builder, loc, 1);
-  Value c2 = constantIndex(builder, loc, 2);
-  for (unsigned dimSz : batchDimSzs) {
-    lbs.push_back(c0);
-    ubs.push_back(constantIndex(builder, loc, dimSz));
-    steps.push_back(c1);
-  }
-  auto tensorType = op.getValues().getType();
-  auto memrefType =
-      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-  Value batV = builder.create<bufferization::ToMemrefOp>(loc, memrefType,
-                                                         op.getValues());
-  scf::buildLoopNest(
-      builder, loc, lbs, ubs, steps,
-      [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
-                                               ValueRange ivs) {
-        // Linearize index variables
-        Value crd = linearize(builder, loc, ivs, ubs);
-        Value len = constantIndex(builder, loc, nse);
-        Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
-        SmallVector<Value> indices(ivs.begin(), ivs.end());
-        auto whileOp = builder.create<scf::WhileOp>(
-            loc, TypeRange{builder.getIndexType()}, ValueRange{len},
-            [&indices, c0, c1, batV](OpBuilder &builder, Location loc,
-                                     ValueRange vs) {
-              Value curLen = vs.front();
-              Value pred = builder.create<arith::CmpIOp>(
-                  loc, arith::CmpIPredicate::eq, curLen, c0);
-              auto ifOp = builder.create<scf::IfOp>(
-                  loc, TypeRange{builder.getI1Type()}, pred, true);
-              {
-                OpBuilder::InsertionGuard guard(builder);
-                // if len == 0.
-                builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-                builder.create<scf::YieldOp>(loc,
-                                             constantI1(builder, loc, false));
-                // Else branch.
-                builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-                indices.push_back(
-                    builder.create<arith::SubIOp>(loc, curLen, c1));
-                Value val = builder.create<memref::LoadOp>(loc, batV, indices);
-                indices.pop_back();
-                Value cont = builder.create<arith::CmpFOp>(
-                    loc, arith::CmpFPredicate::OEQ, val,
-                    constantZero(builder, loc, val.getType()));
-                builder.create<scf::YieldOp>(loc, cont);
-              }
-              builder.create<scf::ConditionOp>(loc, ifOp.getResults()[0], vs);
-            },
-            [c1](OpBuilder &builder, Location loc, ValueRange vs) {
-              // len --;
-              Value nxLen = builder.create<arith::SubIOp>(loc, vs.front(), c1);
-              builder.create<scf::YieldOp>(loc, nxLen);
-            });
-        len = whileOp.getResults()[0];
-        Value pHi = builder.create<arith::AddIOp>(loc, pLo, len);
-        // Stores position lower bound.
-        Value idx = builder.create<arith::MulIOp>(loc, crd, c2);
-        genStore(builder, loc, pLo, posMemRef, idx);
-        // Stores position upper bound.
-        idx = builder.create<arith::AddIOp>(loc, idx, c1);
-        genStore(builder, loc, pHi, posMemRef, idx);
-      });
-}
-
 struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(PackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    const unsigned batchedLvls = op.getNumBatchedLvls();
-    unsigned nse = op.getValues().getType().getDimSize(batchedLvls);
+    Location loc = op.getLoc();
     const auto stt = getSparseTensorType(op.getResult());
-    assert(isCOOType(stt.getEncoding(), batchedLvls, true));
-
-    unsigned batchedCount = 1;
-    SmallVector<unsigned> batchDimSzs;
-    batchDimSzs.reserve(batchedLvls);
-    for (unsigned i = 0; i < batchedLvls; i++) {
-      // Should already be guaranteed by verifier.
-      assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
-      batchedCount *= stt.getDimShape()[i];
-      batchDimSzs.push_back(stt.getDimShape()[i]);
-    }
 
     SmallVector<Value> fields;
-    Location loc = op.getLoc();
 
     foreachFieldAndTypeInSparseTensor(
         stt,
-        [&rewriter, &fields, &op, &batchDimSzs, nse, batchedCount, stt,
+        [&rewriter, &fields, &op, &stt,
          loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
               Level /*lvl*/, DimLevelType dlt) -> bool {
           assert(fields.size() == fIdx);
-          Value field;
-          switch (fKind) {
-          case SparseTensorFieldKind::StorageSpec:
-            field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
-            break;
-          case SparseTensorFieldKind::PosMemRef: {
-            // TACO-style COO starts with a PosBuffer
-            const auto posTp = stt.getPosType();
-            if (isCompressedDLT(dlt)) {
-              auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
-              field = rewriter.create<memref::AllocOp>(loc, memrefType);
-              Value c0 = constantIndex(rewriter, loc, 0);
-              genStore(rewriter, loc, c0, field, c0);
-              for (unsigned i = 1; i <= batchedCount; i++) {
-                // The postion memref will have values as
-                // [0, nse, 2 * nse, ..., batchedCount * nse]
-                Value idx = constantIndex(rewriter, loc, i);
-                Value val = constantIndex(rewriter, loc, nse * i);
-                genStore(rewriter, loc, val, field, idx);
-              }
-            } else {
-              assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
-              MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
-              field = rewriter.create<memref::AllocOp>(loc, posMemTp);
-              populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
-                                               field, nse, op);
-            }
-            break;
-          }
-          case SparseTensorFieldKind::CrdMemRef: {
-            auto tensorType = op.getCoordinates().getType();
-            auto memrefType = MemRefType::get(tensorType.getShape(),
-                                              tensorType.getElementType());
-            field = rewriter.create<bufferization::ToMemrefOp>(
-                op->getLoc(), memrefType, op.getCoordinates());
+          if (fKind == SparseTensorFieldKind::StorageSpec) {
+            fields.push_back(
+                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
+          } else {
+            // Otherwise, simply take the inputs as-is.
+            Value field = fKind == SparseTensorFieldKind::ValMemRef
+                              ? op.getValues()
+                              : op.getLevels()[fIdx];
 
-            break;
-          }
-          case SparseTensorFieldKind::ValMemRef: {
-            auto tensorType = op.getValues().getType();
+            auto tensorType = field.getType().cast<RankedTensorType>();
             auto memrefType = MemRefType::get(tensorType.getShape(),
                                               tensorType.getElementType());
             field = rewriter.create<bufferization::ToMemrefOp>(
-                op->getLoc(), memrefType, op.getValues());
-            break;
-          }
-          }
-
-          assert(field);
-          if (auto memrefTp = dyn_cast<MemRefType>(field.getType());
-              memrefTp && memrefTp.getRank() > 1) {
-            ReassociationIndices reassociation;
-            for (int i = 0, e = memrefTp.getRank(); i < e; i++)
-              reassociation.push_back(i);
-            // Flattens the buffer to rank 1. The value buffer might need be
-            // collapsed as well due to batching.
-            field = rewriter.create<memref::CollapseShapeOp>(
-                loc, field, ArrayRef<ReassociationIndices>(reassociation));
-          }
-
-          if (fType != field.getType())
+                op->getLoc(), memrefType, field);
+            if (memrefType.getRank() > 1) {
+              // Flattens the buffer to rank 1.
+              auto reassoc = getReassociationForFlattening(memrefType);
+              field =
+                  rewriter.create<memref::CollapseShapeOp>(loc, field, reassoc);
+            }
             field = rewriter.create<memref::CastOp>(loc, fType, field);
-          fields.push_back(field);
-          // Returns true to continue the iteration.
+            fields.push_back(field);
+          }
           return true;
         });
 
     MutSparseTensorDescriptor desc(stt, fields);
-    auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
+    Value c1 = constantIndex(rewriter, loc, 1);
+    Value c2 = constantIndex(rewriter, loc, 2);
+    Value posBack = c1; // index of the last value in the position array
+    Value memSize = c2; // memory size for the current array
+    // Sets up SparseTensorSpecifier.
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
+      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
+
       // FIXME: dim/lvl confusion!
-      const auto sh = stt.getDimShape()[lvl];
-      assert(!ShapedType::isDynamic(sh));
-      desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
-      if (lvl == 0)
-        desc.setPosMemSize(rewriter, loc, lvl, constantIndex(rewriter, loc, 2));
+      // Sets up the level size.
+      auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
+
+      // Sets up the memory size by reading the last value in the position array.
+      DimLevelType dlt = stt.getLvlType(lvl);
+      // Simply forwards the position index when this is a dense level.
+      if (isDenseDLT(dlt)) {
+        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
+        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
+        continue;
+      }
 
-      desc.setCrdMemSize(rewriter, loc, lvl, noe);
+      if (isDLTWithPos(dlt)) {
+        assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
+        if (isCompressedWithHiDLT(dlt)) {
+          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
+          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
+        }
+        desc.setPosMemSize(rewriter, loc, lvl, memSize);
+        // The last position value is the memory size for the next level.
+        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
+        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
+      }
+      assert(isDLTWithCrd(dlt));
+      desc.setCrdMemSize(rewriter, loc, lvl, memSize);
     }
-    desc.setValMemSize(rewriter, loc, noe);
+    desc.setValMemSize(rewriter, loc, memSize);
 
     rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
     return success();
@@ -1568,10 +1470,8 @@ static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
 
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
-  SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
-                          bool createDeallocs)
-      : OpConversionPattern(typeConverter, context),
-        createDeallocs(createDeallocs) {}
+  SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern(typeConverter, context) {}
 
   LogicalResult
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
@@ -1582,26 +1482,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
            desc.getFields().size() == 4); // specifier + pos + crds + values
     (void)srcTp;
-    auto logicRes = nBatched == 0
-                        ? genUnBatchedUnpackOp(op, desc, rewriter)
-                        : genBatchedUnpackOp(op, nBatched, desc, rewriter);
-    Value posBuf = desc.getPosMemRef(nBatched);
-
-    if (createDeallocs) {
-      // Unpack ends the lifetime of the sparse tensor. While the value array
-      // and coordinate array are unpacked and returned, the position array
-      // becomes useless and need to be freed (if user requests).
-      // FIXME: Depending on whether the tensor being unpacked is created by
-      // PackOp or not, we may or may not need to free other memref fields of
-      // the sparse tensor too (PackOp borrows value/coordinate buffer).
-      rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
-    }
-
-    return logicRes;
+    return nBatched == 0 ? genUnBatchedUnpackOp(op, desc, rewriter)
+                         : genBatchedUnpackOp(op, nBatched, desc, rewriter);
   }
-
-private:
-  const bool createDeallocs;
 };
 
 struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1755,11 +1638,11 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool createSparseDeallocs, bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseReturnConverter,
-               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseExtractSliceConverter, SparseTensorLoadConverter,
-               SparseExpandConverter, SparseCompressConverter,
-               SparseInsertConverter,
+  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
+               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseExtractSliceConverter,
+               SparseTensorLoadConverter, SparseExpandConverter,
+               SparseCompressConverter, SparseInsertConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1769,7 +1652,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseConvertConverter, SparseNewOpConverter,
                SparseNumberOfEntriesConverter>(typeConverter,
                                                patterns.getContext());
-  patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
+  patterns.add<SparseTensorDeallocConverter>(
       typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
index a9ed5751ab67e..9b18394dee7e2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 namespace sparse_tensor {

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 27aee6b961970..3531400f75e81 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -8,86 +8,50 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], posWidth=32, crdWidth=32}>
 
-func.func @non_static_pack_ret(%values: tensor<6xf64>, %coordinates: tensor<6x1xi32>)
+func.func @non_static_pack_ret(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
                             -> tensor<?xf64, #SparseVector> {
   // expected-error@+1 {{the sparse-tensor must have static shape}}
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6xf64>, tensor<6x1xi32> to tensor<?xf64, #SparseVector>
+  %0 = sparse_tensor.pack %values, %pos, %coordinates
+     : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<?xf64, #SparseVector>
   return %0 : tensor<?xf64, #SparseVector>
 }
 
 // -----
 
-#DenseVector = #sparse_tensor.encoding<{lvlTypes = ["dense"], crdWidth=32}>
-
-func.func @invalid_pack_dense(%values: tensor<6xf64>, %coordinates: tensor<6x1xi32>)
-                            -> tensor<100xf64, #DenseVector> {
-  // expected-error@+1 {{the sparse-tensor must have a COO type}}
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6xf64>, tensor<6x1xi32> to tensor<100xf64, #DenseVector>
-  return %0 : tensor<100xf64, #DenseVector>
-}
-
-// -----
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], posWidth=32, crdWidth=32}>
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
-
-func.func @invalid_pack_data(%values: tensor<6x1xf64>, %coordinates: tensor<6x1xi32>)
-                            -> tensor<100xf64, #SparseVector> {
-  // expected-error@+1 {{values must have rank 1 + batched_lvls}}
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6x1xf64>, tensor<6x1xi32> to tensor<100xf64, #SparseVector>
-  return %0 : tensor<100xf64, #SparseVector>
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
-
-func.func @invalid_pack_type(%values: tensor<6xf64>, %coordinates: tensor<6x1xi32>)
+func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>)
                             -> tensor<100xf32, #SparseVector> {
   // expected-error@+1 {{input/output element-types don't match}}
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6xf64>, tensor<6x1xi32> to tensor<100xf32, #SparseVector>
+  %0 = sparse_tensor.pack %values, %pos, %coordinates
+     : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32> to tensor<100xf32, #SparseVector>
   return %0 : tensor<100xf32, #SparseVector>
 }
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed-nu", "singleton"], posWidth=32, crdWidth=32}>
 
-func.func @invalid_pack_type(%values: tensor<5xf64>, %coordinates: tensor<6x1xi32>)
-                            -> tensor<100xf64, #SparseVector> {
-  // expected-error@+1 {{values/coordinates number-of-elements don't match}}
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<5xf64>, tensor<6x1xi32> to tensor<100xf64, #SparseVector>
-  return %0 : tensor<100xf64, #SparseVector>
+func.func @invalid_pack_type(%values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>)
+                            -> tensor<100x2xf64, #SparseVector> {
+  // expected-error@+1 {{input/output trailing COO level-ranks don't match}}
+  %0 = sparse_tensor.pack %values, %pos, %coordinates
+     : tensor<6xf64>, tensor<2xi32>, tensor<6x3xi32> to tensor<100x2xf64, #SparseVector>
+  return %0 : tensor<100x2xf64, #SparseVector>
 }
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"], posWidth=32, crdWidth=32}>
 
-func.func @invalid_pack_type(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
-                            -> tensor<100xf64, #SparseVector> {
-  // expected-error@+1 {{input/output level-ranks don't match}}
+func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tensor<6xi32>)
+                                     -> tensor<2x100xf64, #CSR> {
+  // expected-error@+1 {{inconsistent number of fields between input/output}}
   %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6xf64>, tensor<6x2xi32> to tensor<100xf64, #SparseVector>
-  return %0 : tensor<100xf64, #SparseVector>
-}
-
-// -----
-
-#BCOO = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
-
-func.func @invalid_pack_batched(%values: tensor<2x6xf64>, %coordinates: tensor<3x6x1xi32>)
-                              -> tensor<2x100xf64, #BCOO> {
-  // expected-error@+1 {{values/coordinates batched level sizes don't match statically}}
-  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
-     : tensor<2x6xf64>, tensor<3x6x1xi32> to tensor<2x100xf64, #BCOO>
-  return %0 : tensor<2x100xf64, #BCOO>
+     : tensor<6xf64>, tensor<6xi32> to tensor<2x100xf64, #CSR>
+  return %0 : tensor<2x100xf64, #CSR>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 817809d7fb8fc..41cc5e775c98c 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -13,39 +13,24 @@ func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
+#SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], posWidth=32, crdWidth=32}>
 
 // CHECK-LABEL: func @sparse_pack(
 // CHECK-SAME: %[[D:.*]]: tensor<6xf64>,
+// CHECK-SAME: %[[P:.*]]: tensor<2xi32>,
 // CHECK-SAME: %[[I:.*]]: tensor<6x1xi32>)
-//       CHECK: %[[R:.*]] = sparse_tensor.pack %[[D]], %[[I]]
+//       CHECK: %[[R:.*]] = sparse_tensor.pack %[[D]], %[[P]], %[[I]]
 //       CHECK: return %[[R]] : tensor<100xf64, #{{.*}}>
-func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x1xi32>)
+func.func @sparse_pack(%data: tensor<6xf64>, %pos: tensor<2xi32>, %index: tensor<6x1xi32>)
                             -> tensor<100xf64, #SparseVector> {
-  %0 = sparse_tensor.pack %data, %index : tensor<6xf64>, tensor<6x1xi32>
-                                       to tensor<100xf64, #SparseVector>
+  %0 = sparse_tensor.pack %data, %pos, %index : tensor<6xf64>, tensor<2xi32>, tensor<6x1xi32>
+                                             to tensor<100xf64, #SparseVector>
   return %0 : tensor<100xf64, #SparseVector>
 }
 
 // -----
 
-#BCOO = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed-hi"], crdWidth=32}>
-// CHECK-LABEL: func @sparse_pack_batched(
-// CHECK-SAME: %[[D:.*]]: tensor<2x6xf64>,
-// CHECK-SAME: %[[I:.*]]: tensor<2x6x1xi32>)
-//       CHECK: %[[R:.*]] = sparse_tensor.pack %[[D]], %[[I]] batched_lvls = 1
-//       CHECK: return %[[R]] : tensor<2x100xf64, #{{.*}}>
-func.func @sparse_pack_batched(%values: tensor<2x6xf64>, %coordinates: tensor<2x6x1xi32>)
-                            -> tensor<2x100xf64, #BCOO> {
-  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
-     : tensor<2x6xf64>, tensor<2x6x1xi32> to tensor<2x100xf64, #BCOO>
-  return %0 : tensor<2x100xf64, #BCOO>
-}
-
-// -----
-
 #SparseVector = #sparse_tensor.encoding<{lvlTypes = ["compressed"], crdWidth=32}>
-
 // CHECK-LABEL: func @sparse_unpack(
 //  CHECK-SAME: %[[T:.*]]: tensor<100xf64, #
 //       CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]]

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 3939b53bc42a7..1d948cbd604fd 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -7,34 +7,32 @@
 
 // CHECK-LABEL:   func.func @sparse_pack(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<6xf64>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<6x2xi32>)
-// CHECK-DAG:       %[[VAL_2:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:       memref.store %[[VAL_3]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<2xindex>
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 6 : index
-// CHECK-DAG:       memref.store %[[VAL_5]], %[[VAL_2]]{{\[}}%[[VAL_4]]] : memref<2xindex>
-// CHECK:           %[[VAL_6:.*]] = memref.cast %[[VAL_2]] : memref<2xindex> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
-// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_7]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK:           %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<12xi32> to memref<?xi32>
-// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK:           %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<6xf64> to memref<?xf64>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.storage_specifier.init
-// CHECK:           %[[VAL_13:.*]] = arith.constant 100 : index
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]]  lvl_sz at 0 with %[[VAL_13]]
-// CHECK:           %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_15]]
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_16]]  crd_mem_sz at 0 with %[[VAL_5]]
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<2xindex>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<6x2xi32>)
+// CHECK-DAG:       %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
+// CHECK-DAG:       %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
+// CHECK-DAG:       %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK-DAG:       %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  lvl_sz at 0 with %[[VAL_13]]
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_12]]
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_16]]
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  lvl_sz at 1 with %[[VAL_13]]
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  crd_mem_sz at 1 with %[[VAL_5]]
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  val_mem_sz with %[[VAL_5]]
-// CHECK:           return %[[VAL_6]], %[[VAL_9]], %[[VAL_11]], %[[VAL_20]]
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  crd_mem_sz at 1 with %[[VAL_16]]
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  val_mem_sz with %[[VAL_16]]
+// CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
 // CHECK:         }
-func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
+func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
                     -> tensor<100x100xf64, #COO> {
-  %0 = sparse_tensor.pack %values, %coordinates
-     : tensor<6xf64>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
+  %0 = sparse_tensor.pack %values, %pos, %coordinates
+     : tensor<6xf64>, tensor<2xindex>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
   return %0 : tensor<100x100xf64, #COO>
 }
 
@@ -68,7 +66,6 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
 // CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
 // CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
-// CHECK:       memref.dealloc %[[VAL_0]] : memref<?xindex>
 // CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index bc1b24ddea6a7..4b44436b6da54 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -31,6 +31,12 @@
   crdWidth = 32
 }>
 
+#CSR = #sparse_tensor.encoding<{
+  lvlTypes = [ "dense", "compressed" ],
+  posWidth = 32,
+  crdWidth = 32
+}>
+
 #BCOO = #sparse_tensor.encoding<{
   lvlTypes = [ "dense", "compressed-hi-nu", "singleton" ]
 }>
@@ -50,38 +56,59 @@ module {
        [  1.0,  2.0,  3.0]
     > : tensor<3xf64>
 
+    %pos = arith.constant dense<
+       [0, 3]
+    > : tensor<2xindex>
+
     %index = arith.constant dense<
        [[  1,  2],
         [  5,  6],
         [  7,  8]]
     > : tensor<3x2xindex>
 
+    %pos32 = arith.constant dense<
+       [0, 3]
+    > : tensor<2xi32>
+
     %index32 = arith.constant dense<
        [[  1,  2],
         [  5,  6],
         [  7,  8]]
     > : tensor<3x2xi32>
 
-    %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
+    %s4 = sparse_tensor.pack %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
                                           to tensor<10x10xf64, #SortedCOO>
-    %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32>
+    %s5= sparse_tensor.pack %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
                                            to tensor<10x10xf64, #SortedCOOI32>
 
+    %csr_pos32 = arith.constant dense<
+       [0, 1, 3]
+    > : tensor<3xi32>
+
+    %csr_index32 = arith.constant dense<
+       [1, 0, 1]
+    > : tensor<3xi32>
+    %csr= sparse_tensor.pack %data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+                                           to tensor<2x2xf64, #CSR>
+
     %bdata = arith.constant dense<
-       [[  1.0,  2.0,  3.0],
-        [  4.0,  5.0,  0.0]]
-    > : tensor<2x3xf64>
+       [  1.0,  2.0,  3.0,  4.0,  5.0,  0.0]
+    > : tensor<6xf64>
+
+    %bpos = arith.constant dense<
+       [0, 3, 3, 5]
+    > : tensor<4xindex>
 
     %bindex = arith.constant dense<
-      [[[  1,  2],
-        [  5,  6],
-        [  7,  8]],
-       [[  2,  3],
-        [  4,  2],
-        [ 10, 10]]]
-    > : tensor<2x3x2xindex>
-    %bs = sparse_tensor.pack %bdata, %bindex batched_lvls = 1 :
-          tensor<2x3xf64>, tensor<2x3x2xindex> to tensor<2x10x10xf64, #BCOO>
+      [[  1,  2],
+       [  5,  6],
+       [  7,  8],
+       [  2,  3],
+       [  4,  2],
+       [ 10, 10]]
+    > : tensor<6x2xindex>
+    %bs = sparse_tensor.pack %bdata, %bpos, %bindex :
+          tensor<6xf64>, tensor<4xindex>,  tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
 
     // CHECK:1
     // CHECK-NEXT:2
@@ -119,6 +146,25 @@ module {
         vector.print %v: f64
      }
 
+    // CHECK-NEXT:0
+    // CHECK-NEXT:1
+    // CHECK-NEXT:1
+    //
+    // CHECK-NEXT:1
+    // CHECK-NEXT:0
+    // CHECK-NEXT:2
+    //
+    // CHECK-NEXT:1
+    // CHECK-NEXT:1
+    // CHECK-NEXT:3
+    sparse_tensor.foreach in %csr : tensor<2x2xf64, #CSR> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+
+
     // CHECK-NEXT:1
     // CHECK-NEXT:2
     // CHECK-NEXT:3
@@ -164,6 +210,7 @@ module {
 
     %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
                                          to tensor<3xf64>, tensor<3x2xindex>, index
+
     // FIXME: This should be freed by one-shot-bufferization.
     bufferization.dealloc_tensor %bd : tensor<2x3xf64>
     bufferization.dealloc_tensor %bi : tensor<2x3x2xindex>


        


More information about the Mlir-commits mailing list