[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)

Aart Bik llvmlistbot at llvm.org
Thu Nov 9 16:00:34 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/71880

Note that the (dis)assemble operations still make some simplifying assumptions (e.g. trailing 2-D COO in AoS format), but now at least the direct IR path and the support library path behave exactly the same.

Generalizing the ops is still TBD.
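
For context on the AoS assumption: for a trailing 2-D COO region, clients pass a single coordinate buffer in which the coordinates of each stored entry are interleaved (AoS), whereas SparseTensorStorage keeps one coordinate array per level (SoA). A minimal standalone C++ sketch of the two layouts and the conversion between them (not part of this patch; the sample data mirrors the %index constant used in the sparse_pack.mlir test below):

  // Illustration only: AoS coordinate buffer for three stored entries at
  // (1,2), (5,6), (7,8), de-interleaved into per-level SoA arrays.
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<uint64_t> aos = {1, 2, 5, 6, 7, 8}; // what clients pass in
    std::vector<uint64_t> crd0, crd1;               // what storage keeps
    const uint64_t nse = aos.size() / 2;            // number of stored entries
    for (uint64_t n = 0; n < nse; n++) {
      crd0.push_back(aos[2 * n + 0]);
      crd1.push_back(aos[2 * n + 1]);
    }
    for (uint64_t n = 0; n < nse; n++)
      std::printf("(%llu, %llu)\n", (unsigned long long)crd0[n],
                  (unsigned long long)crd1[n]);
    return 0;
  }

The assemble path copies in this direction (AoS to SoA); the disassemble support added in this patch performs the reverse repacking.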

From b470eabac2a2aadf93da80a854adb6d810104fca Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 9 Nov 2023 15:55:49 -0800
Subject: [PATCH] [mlir][sparse] unify support of (dis)assemble between direct
 IR/lib path

Note that the (dis)assemble operations still make some simplifying
assumptions (e.g. trailing 2-D COO in AoS format), but now at least
the direct IR path and the support library path behave exactly the same.

Generalizing the ops is still TBD.
---
 .../ExecutionEngine/SparseTensor/Storage.h    |  65 +++--
 .../SparseTensor/Transforms/CodegenUtils.cpp  |  11 +
 .../SparseTensor/Transforms/CodegenUtils.h    |   4 +
 .../Transforms/SparseTensorCodegen.cpp        |  13 -
 .../Transforms/SparseTensorConversion.cpp     | 222 +++++++++++++-----
 .../Transforms/SparseTensorPasses.cpp         |   3 +-
 .../Dialect/SparseTensor/CPU/sparse_pack.mlir | 116 ++++++---
 .../SparseTensor/CPU/sparse_pack_libgen.mlir  | 165 -------------
 8 files changed, 304 insertions(+), 295 deletions(-)
 delete mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 460549726356370..3382e293d123746 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -301,8 +301,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       uint64_t lvlRank = getLvlRank();
       uint64_t valIdx = 0;
       // Linearize the address.
-      for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
-        valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+      for (uint64_t l = 0; l < lvlRank; l++)
+        valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
       values[valIdx] = val;
       return;
     }
@@ -472,9 +472,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
     if (isCompressedLvl(l))
       return positions[l][parentSz];
-    if (isSingletonLvl(l))
-      return parentSz; // New size is same as the parent.
-    // TODO: support levels assignment for loose/2:4?
+    if (isLooseCompressedLvl(l))
+      return positions[l][2 * parentSz - 1];
+    if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+      return parentSz; // new size same as the parent
     assert(isDenseLvl(l));
     return parentSz * getLvlSize(l);
   }
@@ -766,40 +767,59 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
+  // Note that none of the buffers can be reused because ownership
+  // of the memory passed from clients is not necessarily transferred.
+  // Therefore, all data is copied over into a new SparseTensorStorage.
+  //
+  // TODO: this needs to be generalized to all formats AND
+  //       we need a proper audit of e.g. double compressed
+  //       levels where some are not filled
+  //
   uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
   for (uint64_t l = 0; l < lvlRank; l++) {
-    if (!isUniqueLvl(l) && isCompressedLvl(l)) {
-      // A `compressed_nu` level marks the start of trailing COO start level.
-      // Since the coordinate buffer used for trailing COO are passed in as AoS
-      // scheme, and SparseTensorStorage uses a SoA scheme, we can not simply
-      // copy the value from the provided buffers.
+    if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
+      // A `(loose)compressed_nu` level marks the start level of a trailing
+      // COO region. Since the coordinate buffer used for trailing COO
+      // is passed in as an AoS scheme and SparseTensorStorage uses a SoA
+      // scheme, we cannot simply copy the values from the provided buffers.
       trailCOOLen = lvlRank - l;
       break;
     }
-    assert(!isSingletonLvl(l) &&
-           "Singleton level not following a compressed_nu level");
-    if (isCompressedLvl(l)) {
+    if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
       P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
       C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
-      // Copies the lvlBuf into the vectors. The buffer can not be simply reused
-      // because the memory passed from users is not necessarily allocated on
-      // heap.
-      positions[l].assign(posPtr, posPtr + parentSz + 1);
-      coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+      if (!isLooseCompressedLvl(l)) {
+        positions[l].assign(posPtr, posPtr + parentSz + 1);
+        coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+      } else {
+        positions[l].assign(posPtr, posPtr + 2 * parentSz);
+        coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
+      }
+    } else if (isSingletonLvl(l)) {
+      assert(0 && "general singleton not supported yet");
+    } else if (is2OutOf4Lvl(l)) {
+      assert(0 && "2Out4 not supported yet");
     } else {
-      // TODO: support levels assignment for loose/2:4?
       assert(isDenseLvl(l));
     }
     parentSz = assembledSize(parentSz, l);
   }
 
+  // Handle AoS vs. SoA mismatch for COO.
   if (trailCOOLen != 0) {
     uint64_t cooStartLvl = lvlRank - trailCOOLen;
-    assert(!isUniqueLvl(cooStartLvl) && isCompressedLvl(cooStartLvl));
+    assert(!isUniqueLvl(cooStartLvl) &&
+           (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
     P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
     C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
-    positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
-    P crdLen = positions[cooStartLvl][parentSz];
+    P crdLen;
+    if (!isLooseCompressedLvl(cooStartLvl)) {
+      positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
+      crdLen = positions[cooStartLvl][parentSz];
+    } else {
+      positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
+      crdLen = positions[cooStartLvl][2 * parentSz - 1];
+    }
     for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
       coordinates[l].resize(crdLen);
       for (uint64_t n = 0; n < crdLen; n++) {
@@ -809,6 +829,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     parentSz = assembledSize(parentSz, cooStartLvl);
   }
 
+  // Copy the values buffer.
   V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
   values.assign(valPtr, valPtr + parentSz);
 }
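
The body of the copy loop is elided by the hunk above; a hedged, self-contained C++ sketch of the de-interleaving it performs (the helper name and signature are illustrative, not the library's API, and the exact indexing in the library may differ slightly):

  #include <cstdint>
  #include <vector>

  // Copy the AoS coordinate buffer of a trailing COO region into per-level
  // SoA vectors. `coordinates` must already have lvlRank entries.
  template <typename C>
  void copyTrailingCooAoSToSoA(const C *aosCrd, uint64_t crdLen,
                               uint64_t cooStartLvl, uint64_t lvlRank,
                               std::vector<std::vector<C>> &coordinates) {
    const uint64_t cooRank = lvlRank - cooStartLvl;
    for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
      coordinates[l].resize(crdLen);
      for (uint64_t n = 0; n < crdLen; n++)
        // Entry n stores its cooRank coordinates contiguously in AoS order.
        coordinates[l][n] = aosCrd[n * cooRank + (l - cooStartLvl)];
    }
  }
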
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..8e2c2cd6dad7b19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -163,6 +163,17 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
+Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
+                                       Value elem, Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    assert(rtp.getRank() == 0);
+    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return sparse_tensor::genCast(builder, loc, elem, dstTp);
+}
+
 Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                   Value s) {
   Value load = builder.create<memref::LoadOp>(loc, mem, s);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1f53f3525203c70..d3b0889b71b514c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -142,6 +142,10 @@ class FuncCallOrInlineGenerator {
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 
+/// Add conversion from scalar to given type (possibly a 0-rank tensor).
+Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                        Type dstTp);
+
 /// Generates a pointer/index load from the sparse storage scheme. Narrower
 /// data types need to be zero extended before casting the value into the
 /// index type used for looping and indexing.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 08c38394a46343a..888f513be2e4dc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -435,19 +435,6 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   return reassociation;
 }
 
-/// Generates scalar to tensor cast.
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
-                               Type dstTp) {
-  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
-    // Scalars can only be converted to 0-ranked tensors.
-    if (rtp.getRank() != 0)
-      return nullptr;
-    elem = genCast(builder, loc, elem, rtp.getElementType());
-    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
-  }
-  return genCast(builder, loc, elem, dstTp);
-}
-
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4fe9c59d8c320a7..e629133171e15dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
   return std::nullopt;
 }
 
-/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
-static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
-                                          StringRef name, TypeRange resultType,
-                                          ValueRange operands,
-                                          EmitCInterface emitCInterface) {
-  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
-                    emitCInterface);
-  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
-                                                   operands);
-}
-
 /// Generates call to lookup a level-size.  N.B., this only generates
 /// the raw function call, and therefore (intentionally) does not perform
 /// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ class NewCallParams final {
 };
 
 /// Generates a call to obtain the values array.
-static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
-                           ValueRange ptr) {
-  SmallString<15> name{"sparseValues",
-                       primaryTypeFunctionSuffix(tp.getElementType())};
-  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+static Value genValuesCall(OpBuilder &builder, Location loc,
+                           SparseTensorType stt, Value ptr) {
+  auto eltTp = stt.getElementType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
+  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
+      .getResult(0);
+}
+
+/// Generates a call to obtain the positions array.
+static Value genPositionsCall(OpBuilder &builder, Location loc,
+                              SparseTensorType stt, Value ptr, Level l) {
+  Type posTp = stt.getPosType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
+  Value lvl = constantIndex(builder, loc, l);
+  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+                        EmitCInterface::On)
+      .getResult(0);
+}
+
+/// Generates a call to obtain the coordinates array.
+static Value genCoordinatesCall(OpBuilder &builder, Location loc,
+                                SparseTensorType stt, Value ptr, Level l) {
+  Type crdTp = stt.getCrdType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+  Value lvl = constantIndex(builder, loc, l);
+  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+                        EmitCInterface::On)
       .getResult(0);
 }
 
@@ -391,7 +405,7 @@ class SparseTensorAllocConverter
     SmallVector<Value> dimSizes;
     dimSizes.reserve(dimRank);
     unsigned operandCtr = 0;
-    for (Dimension d = 0; d < dimRank; ++d) {
+    for (Dimension d = 0; d < dimRank; d++) {
       dimSizes.push_back(
           stt.isDynamicDim(d)
               ? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
     dimSizes.reserve(dimRank);
     auto shape = op.getType().getShape();
     unsigned operandCtr = 0;
-    for (Dimension d = 0; d < dimRank; ++d) {
+    for (Dimension d = 0; d < dimRank; d++) {
       dimSizes.push_back(stt.isDynamicDim(d)
                              ? adaptor.getOperands()[operandCtr++]
                              : constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ class SparseTensorToPositionsConverter
   LogicalResult
   matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resTp = op.getType();
-    Type posTp = cast<ShapedType>(resTp).getElementType();
-    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
-    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
-    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
-                          EmitCInterface::On);
+    auto stt = getSparseTensorType(op.getTensor());
+    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
+                                 adaptor.getTensor(), op.getLevel());
+    rewriter.replaceOp(op, poss);
     return success();
   }
 };
@@ -505,29 +517,14 @@ class SparseTensorToCoordinatesConverter
   LogicalResult
   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: use `SparseTensorType::getCrdType` instead.
-    Type resType = op.getType();
-    const Type crdTp = cast<ShapedType>(resType).getElementType();
-    SmallString<19> name{"sparseCoordinates",
-                         overheadTypeFunctionSuffix(crdTp)};
-    Location loc = op->getLoc();
-    Value lvl = constantIndex(rewriter, loc, op.getLevel());
-
-    // The function returns a MemRef without a layout.
-    MemRefType callRetType = get1DMemRefType(crdTp, false);
-    SmallVector<Value> operands{adaptor.getTensor(), lvl};
-    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
-                      operands, EmitCInterface::On);
-    Value callRet =
-        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
-            .getResult(0);
-
+    auto stt = getSparseTensorType(op.getTensor());
+    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
+                                   adaptor.getTensor(), op.getLevel());
     // Cast the MemRef type to the type expected by the users, though these
     // two types should be compatible at runtime.
-    if (resType != callRetType)
-      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
-    rewriter.replaceOp(op, callRet);
-
+    if (op.getType() != crds.getType())
+      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+    rewriter.replaceOp(op, crds);
     return success();
   }
 };
@@ -539,9 +536,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto resType = cast<ShapedType>(op.getType());
-    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
-                                         adaptor.getOperands()));
+    auto stt = getSparseTensorType(op.getTensor());
+    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+    rewriter.replaceOp(op, vals);
     return success();
   }
 };
@@ -554,13 +551,11 @@ class SparseNumberOfEntriesConverter
   LogicalResult
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     // Query values array size for the actually stored values size.
-    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
-    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
-    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
-    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
-                                               constantIndex(rewriter, loc, 0));
+    auto stt = getSparseTensorType(op.getTensor());
+    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+    auto zero = constantIndex(rewriter, op.getLoc(), 0);
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
     return success();
   }
 };
@@ -701,7 +696,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
-/// Sparse conversion rule for the sparse_tensor.pack operator.
+/// Sparse conversion rule for the sparse_tensor.assemble operator.
 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
                   ConversionPatternRewriter &rewriter) const override {
     const Location loc = op->getLoc();
     const auto dstTp = getSparseTensorType(op.getResult());
-    // AssembleOps always returns a static shaped tensor result.
     assert(dstTp.hasStaticDimShape());
     SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
+    // Use a library method to transfer the external buffers from
+    // clients to the internal SparseTensorStorage. Since we cannot
+    // assume clients transfer ownership of the buffers, this method
+    // will copy all data over into a new SparseTensorStorage.
     Value dst =
         NewCallParams(rewriter, loc)
             .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
   }
 };
 
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+    : public OpConversionPattern<DisassembleOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We simply expose the buffers to the external client. This
+    // assumes the client only reads the buffers (usually copying them
+    // into external data structures, such as numpy arrays).
+    Location loc = op->getLoc();
+    auto stt = getSparseTensorType(op.getTensor());
+    SmallVector<Value> retVal;
+    SmallVector<Value> retLen;
+    // Get the values buffer first.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+    // Then get the positions and coordinates buffers.
+    const Level lvlRank = stt.getLvlRank();
+    Level trailCOOLen = 0;
+    for (Level l = 0; l < lvlRank; l++) {
+      if (!stt.isUniqueLvl(l) &&
+          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+        // A `(loose)compressed_nu` level marks the start level of a trailing
+        // COO region. Since the target coordinate buffer used for trailing
+        // COO is passed in as an AoS scheme and SparseTensorStorage uses a
+        // SoA scheme, we cannot simply use the internal buffers directly.
+        trailCOOLen = lvlRank - l;
+        break;
+      }
+      if (stt.isWithPos(l)) {
+        auto poss =
+            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        retVal.push_back(poss);
+        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+      }
+      if (stt.isWithCrd(l)) {
+        auto crds =
+            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        retVal.push_back(crds);
+        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
+      }
+    }
+    // Handle AoS vs. SoA mismatch for COO.
+    if (trailCOOLen != 0) {
+      uint64_t cooStartLvl = lvlRank - trailCOOLen;
+      assert(!stt.isUniqueLvl(cooStartLvl) &&
+             (stt.isCompressedLvl(cooStartLvl) ||
+              stt.isLooseCompressedLvl(cooStartLvl)));
+      // Positions.
+      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
+                                   cooStartLvl);
+      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      retVal.push_back(poss);
+      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+      // Coordinates, copied over with:
+      //    for (i = 0; i < crdLen; i++)
+      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
+      auto buf =
+          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
+                                      cooStartLvl);
+      auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
+                                      cooStartLvl + 1);
+      auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
+      auto two = constantIndex(rewriter, loc, 2);
+      auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
+      Type indexType = rewriter.getIndexType();
+      auto zero = constantZero(rewriter, loc, indexType);
+      auto one = constantOne(rewriter, loc, indexType);
+      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
+      auto idx = forOp.getInductionVar();
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
+      auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
+      SmallVector<Value> args;
+      args.push_back(idx);
+      args.push_back(zero);
+      rewriter.create<memref::StoreOp>(loc, c0, buf, args);
+      args[1] = one;
+      rewriter.create<memref::StoreOp>(loc, c1, buf, args);
+      rewriter.setInsertionPointAfter(forOp);
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      retVal.push_back(buf);
+      retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
+    }
+    // Convert the MemRefs back to tensors.
+    assert(retVal.size() + retLen.size() == op.getNumResults());
+    for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
+      auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
+      retVal[i] =
+          rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
+    }
+    // Append the actual length used in each returned buffer.
+    retVal.append(retLen.begin(), retLen.end());
+    rewriter.replaceOp(op, retVal);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -752,5 +859,6 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
            SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
            SparseTensorLoadConverter, SparseTensorInsertConverter,
            SparseTensorExpandConverter, SparseTensorCompressConverter,
-           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
+           SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
+          typeConverter, patterns.getContext());
 }
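
For exposition, the scf.for loop emitted by SparseTensorDisassembleConverter above for the trailing COO coordinates corresponds roughly to the following scalar C++ (helper name and signature are illustrative only, not generated code):

  #include <cstdint>
  #include <vector>

  // Re-interleave two SoA coordinate arrays of a trailing 2-D COO region
  // into the client's AoS output buffer of shape crdLen x 2.
  template <typename C>
  void repackSoAToAoS(const std::vector<C> &crds0, const std::vector<C> &crds1,
                      C *buf) {
    const uint64_t crdLen = crds0.size();
    for (uint64_t i = 0; i < crdLen; i++) {
      buf[2 * i + 0] = crds0[i]; // buf[i][0] = crd0[i]
      buf[2 * i + 1] = crds1[i]; // buf[i][1] = crd1[i]
    }
  }
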
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index e1cbf3482708ad0..10ebfa7922088a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -198,7 +198,8 @@ struct SparseTensorConversionPass
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
-                      linalg::YieldOp, tensor::ExtractOp>();
+                      linalg::YieldOp, tensor::ExtractOp,
+                      tensor::FromElementsOp>();
     target.addLegalDialect<
         arith::ArithDialect, bufferization::BufferizationDialect,
         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index a2f93614590f106..840e3c97ae28843 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -3,7 +3,7 @@
 //
 // Set-up that's shared across all tests in this directory. In principle, this
 // config could be moved to lit.local.cfg. However, there are downstream users that
-//  do not use these LIT config files. Hence why this is kept inline.
+// do not use these LIT config files. Hence why this is kept inline.
 //
 // DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
 // DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
@@ -17,15 +17,19 @@
 // DEFINE: %{env} =
 //--------------------------------------------------------------------------------------------------
 
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
 // REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
 // RUN: %{compile} | %{run} | FileCheck %s
 //
-// Do the same run, but now with VLA vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=4
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
-// TODO: support sparse_tensor.disassemble on libgen path.
-
 #SortedCOO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
 }>
@@ -54,11 +58,13 @@ module {
     %c0 = arith.constant 0 : index
     %f0 = arith.constant 0.0 : f64
     %i0 = arith.constant 0 : i32
+
     //
-    // Initialize a 3-dim dense tensor.
+    // Setup COO.
     //
+
     %data = arith.constant dense<
-       [  1.0,  2.0,  3.0]
+       [ 1.0,  2.0,  3.0 ]
     > : tensor<3xf64>
 
     %pos = arith.constant dense<
@@ -83,12 +89,16 @@ module {
 
     %s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
                                           to tensor<10x10xf64, #SortedCOO>
-    %s5= sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
-                                           to tensor<10x10xf64, #SortedCOOI32>
+    %s5 = sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
+                                          to tensor<10x10xf64, #SortedCOOI32>
+
+    //
+    // Setup CSR.
+    //
 
     %csr_data = arith.constant dense<
-       [  1.0,  2.0,  3.0,  4.0]
-    > : tensor<4xf64>
+       [ 1.0,  2.0,  3.0 ]
+    > : tensor<3xf64>
 
     %csr_pos32 = arith.constant dense<
        [0, 1, 3]
@@ -97,12 +107,16 @@ module {
     %csr_index32 = arith.constant dense<
        [1, 0, 1]
     > : tensor<3xi32>
-    %csr= sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+    %csr = sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
                                            to tensor<2x2xf64, #CSR>
 
+    //
+    // Setup BCOO.
+    //
+
     %bdata = arith.constant dense<
-       [  1.0,  2.0,  3.0,  4.0,  5.0,  0.0]
-    > : tensor<6xf64>
+       [ 1.0,  2.0,  3.0,  4.0,  5.0 ]
+    > : tensor<5xf64>
 
     %bpos = arith.constant dense<
        [0, 3, 3, 5]
@@ -116,10 +130,15 @@ module {
        [  4,  2],
        [ 10, 10]]
     > : tensor<6x2xindex>
+
     %bs = sparse_tensor.assemble %bdata, %bpos, %bindex :
-          tensor<6xf64>, tensor<4xindex>,  tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
+          tensor<5xf64>, tensor<4xindex>,  tensor<6x2xindex> to tensor<2x10x10xf64, #BCOO>
 
-    // CHECK:1
+    //
+    // Verify results.
+    //
+
+    // CHECK:     1
     // CHECK-NEXT:2
     // CHECK-NEXT:1
     //
@@ -135,7 +154,7 @@ module {
         vector.print %1: index
         vector.print %2: index
         vector.print %v: f64
-     }
+    }
 
     // CHECK-NEXT:1
     // CHECK-NEXT:2
@@ -153,7 +172,7 @@ module {
         vector.print %1: index
         vector.print %2: index
         vector.print %v: f64
-     }
+    }
 
     // CHECK-NEXT:0
     // CHECK-NEXT:1
@@ -171,32 +190,43 @@ module {
         vector.print %1: index
         vector.print %2: index
         vector.print %v: f64
-     }
-
-    %d_csr = tensor.empty() : tensor<4xf64>
-    %p_csr = tensor.empty() : tensor<3xi32>
-    %i_csr = tensor.empty() : tensor<3xi32>
-    %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
-                 outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
-
-    // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
-    %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
-    vector.print %vd_csr : vector<4xf64>
+    }
 
+    // CHECK-NEXT:0
+    // CHECK-NEXT:1
+    // CHECK-NEXT:2
+    // CHECK-NEXT:1
+    //
+    // CHECK-NEXT:0
+    // CHECK-NEXT:5
+    // CHECK-NEXT:6
+    // CHECK-NEXT:2
+    //
+    // CHECK-NEXT:0
+    // CHECK-NEXT:7
+    // CHECK-NEXT:8
+    // CHECK-NEXT:3
+    //
     // CHECK-NEXT:1
     // CHECK-NEXT:2
     // CHECK-NEXT:3
+    // CHECK-NEXT:4
     //
+    // CHECK-NEXT:1
     // CHECK-NEXT:4
+    // CHECK-NEXT:2
     // CHECK-NEXT:5
-    //
-    // Make sure the trailing zeros are not traversed.
-    // CHECK-NOT: 0
     sparse_tensor.foreach in %bs : tensor<2x10x10xf64, #BCOO> do {
       ^bb0(%0: index, %1: index, %2: index, %v: f64) :
+        vector.print %0: index
+        vector.print %1: index
+        vector.print %2: index
         vector.print %v: f64
-     }
+    }
+
+    //
+    // Verify disassemble operations.
+    //
 
     %od = tensor.empty() : tensor<3xf64>
     %op = tensor.empty() : tensor<2xi32>
@@ -213,6 +243,16 @@ module {
     %vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
     vector.print %vi : vector<3x2xi32>
 
+    %d_csr = tensor.empty() : tensor<4xf64>
+    %p_csr = tensor.empty() : tensor<3xi32>
+    %i_csr = tensor.empty() : tensor<3xi32>
+    %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
+                 outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
+                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
+
+    // CHECK-NEXT: ( 1, 2, 3 )
+    %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<3xf64>
+    vector.print %vd_csr : vector<3xf64>
 
     %bod = tensor.empty() : tensor<6xf64>
     %bop = tensor.empty() : tensor<4xindex>
@@ -221,15 +261,17 @@ module {
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
                     -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
 
-    // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
-    %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
-    vector.print %vbd : vector<6xf64>
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5 )
+    %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<5xf64>
+    vector.print %vbd : vector<5xf64>
+
     // CHECK-NEXT: 5
     vector.print %ld : index
 
     // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ), ( 2, 3 ), ( 4, 2 ), ( {{.*}}, {{.*}} ) )
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
+
     // CHECK-NEXT: 10
     %si = tensor.extract %li[] : tensor<i64>
     vector.print %si : i64
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir
deleted file mode 100644
index 6540c950ab675b0..000000000000000
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir
+++ /dev/null
@@ -1,165 +0,0 @@
-//--------------------------------------------------------------------------------------------------
-// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
-//
-// Set-up that's shared across all tests in this directory. In principle, this
-// config could be moved to lit.local.cfg. However, there are downstream users that
-//  do not use these LIT config files. Hence why this is kept inline.
-//
-// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
-// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
-// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
-// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
-// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
-// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
-// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
-//
-// DEFINE: %{env} =
-//--------------------------------------------------------------------------------------------------
-
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with VLA vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=true vl=4
-// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-
-// TODO: This is considered to be a short-living tests and should be merged with sparse_pack.mlir
-// after sparse_tensor.disassemble is supported on libgen path.
-
-#SortedCOO = #sparse_tensor.encoding<{
-  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
-}>
-
-#SortedCOOI32 = #sparse_tensor.encoding<{
-  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
-  posWidth = 32,
-  crdWidth = 32
-}>
-
-#CSR = #sparse_tensor.encoding<{
-  map = (d0, d1) -> (d0 : dense, d1 : compressed),
-  posWidth = 32,
-  crdWidth = 32
-}>
-
-// TODO: "loose_compressed" is not supported by libgen path.
-// #BCOO = #sparse_tensor.encoding<{
-//   map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique, high), d2 : singleton)
-//}>
-
-module {
-  //
-  // Main driver.
-  //
-  func.func @entry() {
-    %c0 = arith.constant 0 : index
-    %f0 = arith.constant 0.0 : f64
-    %i0 = arith.constant 0 : i32
-    //
-    // Initialize a 3-dim dense tensor.
-    //
-    %data = arith.constant dense<
-       [  1.0,  2.0,  3.0]
-    > : tensor<3xf64>
-
-    %pos = arith.constant dense<
-       [0, 3]
-    > : tensor<2xindex>
-
-    %index = arith.constant dense<
-       [[  1,  2],
-        [  5,  6],
-        [  7,  8]]
-    > : tensor<3x2xindex>
-
-    %pos32 = arith.constant dense<
-       [0, 3]
-    > : tensor<2xi32>
-
-    %index32 = arith.constant dense<
-       [[  1,  2],
-        [  5,  6],
-        [  7,  8]]
-    > : tensor<3x2xi32>
-
-    %s4 = sparse_tensor.assemble %data, %pos, %index : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>
-                                          to tensor<10x10xf64, #SortedCOO>
-    %s5= sparse_tensor.assemble %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
-                                           to tensor<10x10xf64, #SortedCOOI32>
-
-    %csr_data = arith.constant dense<
-       [  1.0,  2.0,  3.0,  4.0]
-    > : tensor<4xf64>
-
-    %csr_pos32 = arith.constant dense<
-       [0, 1, 3]
-    > : tensor<3xi32>
-
-    %csr_index32 = arith.constant dense<
-       [1, 0, 1]
-    > : tensor<3xi32>
-    %csr= sparse_tensor.assemble %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
-                                           to tensor<2x2xf64, #CSR>
-
-    // CHECK:1
-    // CHECK-NEXT:2
-    // CHECK-NEXT:1
-    //
-    // CHECK-NEXT:5
-    // CHECK-NEXT:6
-    // CHECK-NEXT:2
-    //
-    // CHECK-NEXT:7
-    // CHECK-NEXT:8
-    // CHECK-NEXT:3
-    sparse_tensor.foreach in %s4 : tensor<10x10xf64, #SortedCOO> do {
-      ^bb0(%1: index, %2: index, %v: f64) :
-        vector.print %1: index
-        vector.print %2: index
-        vector.print %v: f64
-     }
-
-    // CHECK-NEXT:1
-    // CHECK-NEXT:2
-    // CHECK-NEXT:1
-    //
-    // CHECK-NEXT:5
-    // CHECK-NEXT:6
-    // CHECK-NEXT:2
-    //
-    // CHECK-NEXT:7
-    // CHECK-NEXT:8
-    // CHECK-NEXT:3
-    sparse_tensor.foreach in %s5 : tensor<10x10xf64, #SortedCOOI32> do {
-      ^bb0(%1: index, %2: index, %v: f64) :
-        vector.print %1: index
-        vector.print %2: index
-        vector.print %v: f64
-     }
-
-    // CHECK-NEXT:0
-    // CHECK-NEXT:1
-    // CHECK-NEXT:1
-    //
-    // CHECK-NEXT:1
-    // CHECK-NEXT:0
-    // CHECK-NEXT:2
-    //
-    // CHECK-NEXT:1
-    // CHECK-NEXT:1
-    // CHECK-NEXT:3
-    sparse_tensor.foreach in %csr : tensor<2x2xf64, #CSR> do {
-      ^bb0(%1: index, %2: index, %v: f64) :
-        vector.print %1: index
-        vector.print %2: index
-        vector.print %v: f64
-     }
-
-
-    bufferization.dealloc_tensor %s4  : tensor<10x10xf64, #SortedCOO>
-    bufferization.dealloc_tensor %s5  : tensor<10x10xf64, #SortedCOOI32>
-    bufferization.dealloc_tensor %csr  : tensor<2x2xf64, #CSR>
-
-    return
-  }
-}


