[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 9 16:01:01 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

Note that the (dis)assemble operations still make some simplifying assumptions (e.g. a trailing 2-D COO in AoS format), but now at least the direct IR and support library paths behave exactly the same.

Generalizing the ops is still TBD.

---

Patch is 37.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71880.diff


8 Files Affected:

- (modified) mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h (+43-22) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (+11) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h (+4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (-13) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+165-57) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+2-1) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+79-37) 
- (removed) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_libgen.mlir (-165) 


``````````diff
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 460549726356370..3382e293d123746 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -301,8 +301,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       uint64_t lvlRank = getLvlRank();
       uint64_t valIdx = 0;
       // Linearize the address.
-      for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
-        valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+      for (uint64_t l = 0; l < lvlRank; l++)
+        valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
       values[valIdx] = val;
       return;
     }
@@ -472,9 +472,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
     if (isCompressedLvl(l))
       return positions[l][parentSz];
-    if (isSingletonLvl(l))
-      return parentSz; // New size is same as the parent.
-    // TODO: support levels assignment for loose/2:4?
+    if (isLooseCompressedLvl(l))
+      return positions[l][2 * parentSz - 1];
+    if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+      return parentSz; // new size same as the parent
     assert(isDenseLvl(l));
     return parentSz * getLvlSize(l);
   }
@@ -766,40 +767,59 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
+  // Note that none of the buffers can be reused because ownership
+  // of the memory passed from clients is not necessarily transferred.
+  // Therefore, all data is copied over into a new SparseTensorStorage.
+  //
+  // TODO: this needs to be generalized to all formats AND
+  //       we need a proper audit of e.g. double compressed
+  //       levels where some are not filled
+  //
   uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
   for (uint64_t l = 0; l < lvlRank; l++) {
-    if (!isUniqueLvl(l) && isCompressedLvl(l)) {
-      // A `compressed_nu` level marks the start of trailing COO start level.
-      // Since the coordinate buffer used for trailing COO are passed in as AoS
-      // scheme, and SparseTensorStorage uses a SoA scheme, we can not simply
-      // copy the value from the provided buffers.
+    if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
+      // A `(loose)compressed_nu` level marks the start of the trailing COO
+      // levels. Since the coordinate buffer used for trailing COO
+      // is passed in as AoS scheme and SparseTensorStorage uses a SoA
+      // scheme, we cannot simply copy the value from the provided buffers.
       trailCOOLen = lvlRank - l;
       break;
     }
-    assert(!isSingletonLvl(l) &&
-           "Singleton level not following a compressed_nu level");
-    if (isCompressedLvl(l)) {
+    if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
       P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
       C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
-      // Copies the lvlBuf into the vectors. The buffer can not be simply reused
-      // because the memory passed from users is not necessarily allocated on
-      // heap.
-      positions[l].assign(posPtr, posPtr + parentSz + 1);
-      coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+      if (!isLooseCompressedLvl(l)) {
+        positions[l].assign(posPtr, posPtr + parentSz + 1);
+        coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+      } else {
+        positions[l].assign(posPtr, posPtr + 2 * parentSz);
+        coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
+      }
+    } else if (isSingletonLvl(l)) {
+      assert(0 && "general singleton not supported yet");
+    } else if (is2OutOf4Lvl(l)) {
+      assert(0 && "2Out4 not supported yet");
     } else {
-      // TODO: support levels assignment for loose/2:4?
       assert(isDenseLvl(l));
     }
     parentSz = assembledSize(parentSz, l);
   }
 
+  // Handle AoS vs. SoA mismatch for COO.
   if (trailCOOLen != 0) {
     uint64_t cooStartLvl = lvlRank - trailCOOLen;
-    assert(!isUniqueLvl(cooStartLvl) && isCompressedLvl(cooStartLvl));
+    assert(!isUniqueLvl(cooStartLvl) &&
+           (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
     P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
     C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
-    positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
-    P crdLen = positions[cooStartLvl][parentSz];
+    P crdLen;
+    if (!isLooseCompressedLvl(cooStartLvl)) {
+      positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
+      crdLen = positions[cooStartLvl][parentSz];
+    } else {
+      positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
+      crdLen = positions[cooStartLvl][2 * parentSz - 1];
+    }
     for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
       coordinates[l].resize(crdLen);
       for (uint64_t n = 0; n < crdLen; n++) {
@@ -809,6 +829,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     parentSz = assembledSize(parentSz, cooStartLvl);
   }
 
+  // Copy the values buffer.
   V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
   values.assign(valPtr, valPtr + parentSz);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..8e2c2cd6dad7b19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -163,6 +163,17 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
+Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
+                                       Value elem, Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    assert(rtp.getRank() == 0);
+    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return sparse_tensor::genCast(builder, loc, elem, dstTp);
+}
+
 Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                   Value s) {
   Value load = builder.create<memref::LoadOp>(loc, mem, s);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1f53f3525203c70..d3b0889b71b514c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -142,6 +142,10 @@ class FuncCallOrInlineGenerator {
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 
+/// Add conversion from scalar to given type (possibly a 0-rank tensor).
+Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                        Type dstTp);
+
 /// Generates a pointer/index load from the sparse storage scheme. Narrower
 /// data types need to be zero extended before casting the value into the
 /// index type used for looping and indexing.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 08c38394a46343a..888f513be2e4dc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -435,19 +435,6 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   return reassociation;
 }
 
-/// Generates scalar to tensor cast.
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
-                               Type dstTp) {
-  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
-    // Scalars can only be converted to 0-ranked tensors.
-    if (rtp.getRank() != 0)
-      return nullptr;
-    elem = genCast(builder, loc, elem, rtp.getElementType());
-    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
-  }
-  return genCast(builder, loc, elem, dstTp);
-}
-
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4fe9c59d8c320a7..e629133171e15dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
   return std::nullopt;
 }
 
-/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
-static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
-                                          StringRef name, TypeRange resultType,
-                                          ValueRange operands,
-                                          EmitCInterface emitCInterface) {
-  auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
-                    emitCInterface);
-  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
-                                                   operands);
-}
-
 /// Generates call to lookup a level-size.  N.B., this only generates
 /// the raw function call, and therefore (intentionally) does not perform
 /// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ class NewCallParams final {
 };
 
 /// Generates a call to obtain the values array.
-static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
-                           ValueRange ptr) {
-  SmallString<15> name{"sparseValues",
-                       primaryTypeFunctionSuffix(tp.getElementType())};
-  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+static Value genValuesCall(OpBuilder &builder, Location loc,
+                           SparseTensorType stt, Value ptr) {
+  auto eltTp = stt.getElementType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
+  SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
+      .getResult(0);
+}
+
+/// Generates a call to obtain the positions array.
+static Value genPositionsCall(OpBuilder &builder, Location loc,
+                              SparseTensorType stt, Value ptr, Level l) {
+  Type posTp = stt.getPosType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
+  Value lvl = constantIndex(builder, loc, l);
+  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+                        EmitCInterface::On)
+      .getResult(0);
+}
+
+/// Generates a call to obtain the coordinates array.
+static Value genCoordinatesCall(OpBuilder &builder, Location loc,
+                                SparseTensorType stt, Value ptr, Level l) {
+  Type crdTp = stt.getCrdType();
+  auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+  Value lvl = constantIndex(builder, loc, l);
+  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
+  return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+                        EmitCInterface::On)
       .getResult(0);
 }
 
@@ -391,7 +405,7 @@ class SparseTensorAllocConverter
     SmallVector<Value> dimSizes;
     dimSizes.reserve(dimRank);
     unsigned operandCtr = 0;
-    for (Dimension d = 0; d < dimRank; ++d) {
+    for (Dimension d = 0; d < dimRank; d++) {
       dimSizes.push_back(
           stt.isDynamicDim(d)
               ? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
     dimSizes.reserve(dimRank);
     auto shape = op.getType().getShape();
     unsigned operandCtr = 0;
-    for (Dimension d = 0; d < dimRank; ++d) {
+    for (Dimension d = 0; d < dimRank; d++) {
       dimSizes.push_back(stt.isDynamicDim(d)
                              ? adaptor.getOperands()[operandCtr++]
                              : constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ class SparseTensorToPositionsConverter
   LogicalResult
   matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resTp = op.getType();
-    Type posTp = cast<ShapedType>(resTp).getElementType();
-    SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
-    Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
-    replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
-                          EmitCInterface::On);
+    auto stt = getSparseTensorType(op.getTensor());
+    auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
+                                 adaptor.getTensor(), op.getLevel());
+    rewriter.replaceOp(op, poss);
     return success();
   }
 };
@@ -505,29 +517,14 @@ class SparseTensorToCoordinatesConverter
   LogicalResult
   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: use `SparseTensorType::getCrdType` instead.
-    Type resType = op.getType();
-    const Type crdTp = cast<ShapedType>(resType).getElementType();
-    SmallString<19> name{"sparseCoordinates",
-                         overheadTypeFunctionSuffix(crdTp)};
-    Location loc = op->getLoc();
-    Value lvl = constantIndex(rewriter, loc, op.getLevel());
-
-    // The function returns a MemRef without a layout.
-    MemRefType callRetType = get1DMemRefType(crdTp, false);
-    SmallVector<Value> operands{adaptor.getTensor(), lvl};
-    auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
-                      operands, EmitCInterface::On);
-    Value callRet =
-        rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
-            .getResult(0);
-
+    auto stt = getSparseTensorType(op.getTensor());
+    auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
+                                   adaptor.getTensor(), op.getLevel());
     // Cast the MemRef type to the type expected by the users, though these
     // two types should be compatible at runtime.
-    if (resType != callRetType)
-      callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
-    rewriter.replaceOp(op, callRet);
-
+    if (op.getType() != crds.getType())
+      crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+    rewriter.replaceOp(op, crds);
     return success();
   }
 };
@@ -539,9 +536,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto resType = cast<ShapedType>(op.getType());
-    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
-                                         adaptor.getOperands()));
+    auto stt = getSparseTensorType(op.getTensor());
+    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+    rewriter.replaceOp(op, vals);
     return success();
   }
 };
@@ -554,13 +551,11 @@ class SparseNumberOfEntriesConverter
   LogicalResult
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     // Query values array size for the actually stored values size.
-    Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
-    auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
-    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
-    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
-                                               constantIndex(rewriter, loc, 0));
+    auto stt = getSparseTensorType(op.getTensor());
+    auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+    auto zero = constantIndex(rewriter, op.getLoc(), 0);
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
     return success();
   }
 };
@@ -701,7 +696,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
-/// Sparse conversion rule for the sparse_tensor.pack operator.
+/// Sparse conversion rule for the sparse_tensor.assemble operator.
 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
                   ConversionPatternRewriter &rewriter) const override {
     const Location loc = op->getLoc();
     const auto dstTp = getSparseTensorType(op.getResult());
-    // AssembleOps always returns a static shaped tensor result.
     assert(dstTp.hasStaticDimShape());
     SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
+    // Use a library method to transfer the external buffers from
+    // clients to the internal SparseTensorStorage. Since we cannot
+    // assume clients transfer ownership of the buffers, this method
+    // will copy all data over into a new SparseTensorStorage.
     Value dst =
         NewCallParams(rewriter, loc)
             .genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
   }
 };
 
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+    : public OpConversionPattern<DisassembleOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We simply expose the buffers to the external client. This
+    // assumes the client only reads the buffers (usually copying them
+    // to the external data structures, such as numpy arrays).
+    Location loc = op->getLoc();
+    auto stt = getSparseTensorType(op.getTensor());
+    SmallVector<Value> retVal;
+    SmallVector<Value> retLen;
+    // Get the values buffer first.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+    // Then get the positions and coordinates buffers.
+    const Level lvlRank = stt.getLvlRank();
+    Level trailCOOLen = 0;
+    for (Level l = 0; l < lvlRank; l++) {
+      if (!stt.isUniqueLvl(l) &&
+          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+        // A `(loose)compressed_nu` level marks the start of the trailing COO
+        // levels. Since the target coordinate buffer used for trailing
+        // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
+        // scheme, we cannot simply use the internal buffers.
+        trailCOOLen = lvlRank - l;
+        break;
+      }
+      if (stt.isWithPos(l)) {
+        auto poss =
+            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        retVal.push_back(poss);
+        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+      }
+      if (stt.isWithCrd(l)) {
+        auto crds =
+            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto crdLen = linalg::createOrFoldDimOp(rewriter, lo...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71880


More information about the Mlir-commits mailing list