[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)
Aart Bik
llvmlistbot at llvm.org
Fri Nov 10 17:49:40 PST 2023
================
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
}
};
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+ : public OpConversionPattern<DisassembleOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // We simply expose the buffers to the external client. This
+ // assumes the client only reads the buffers (usually copying them
+ // to the external data structures, such as numpy arrays).
+ Location loc = op->getLoc();
+ auto stt = getSparseTensorType(op.getTensor());
+ SmallVector<Value> retVal;
+ SmallVector<Value> retLen;
+ // Get the values buffer first.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+ // Then get the positions and coordinates buffers.
+ const Level lvlRank = stt.getLvlRank();
+ Level trailCOOLen = 0;
+ for (Level l = 0; l < lvlRank; l++) {
+ if (!stt.isUniqueLvl(l) &&
+ (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+ // A `(loose)compressed_nu` level marks the start of the trailing
+ // COO levels. Since the target coordinate buffer used for trailing
+ // COO is passed in using an AoS scheme and SparseTensorStorage uses
+ // an SoA scheme, we cannot simply use the internal buffers.
+ trailCOOLen = lvlRank - l;
+ break;
+ }
+ if (stt.isWithPos(l)) {
+ auto poss =
+ genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ }
+ if (stt.isWithCrd(l)) {
+ auto crds =
+ genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(crds);
+ retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
+ }
+ }
+ // Handle AoS vs. SoA mismatch for COO.
+ if (trailCOOLen != 0) {
+ uint64_t cooStartLvl = lvlRank - trailCOOLen;
+ assert(!stt.isUniqueLvl(cooStartLvl) &&
+ (stt.isCompressedLvl(cooStartLvl) ||
+ stt.isLooseCompressedLvl(cooStartLvl)));
+ // Positions.
+ auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
+ cooStartLvl);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ // Coordinates, copied over with:
+ // for (i = 0; i < crdLen; i++)
+ // buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
----------------
aartbik wrote:
I think that is harder, right, since we are coming from two memrefs?
Anyway, we can fix this when generalizing beyond 2-D.
https://github.com/llvm/llvm-project/pull/71880
More information about the Mlir-commits
mailing list