[Mlir-commits] [mlir] [mlir][sparse] unify support of (dis)assemble between direct IR/lib path (PR #71880)

Aart Bik llvmlistbot at llvm.org
Fri Nov 10 17:49:40 PST 2023


================
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
   }
 };
 
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+    : public OpConversionPattern<DisassembleOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We simply expose the buffers to the external client. This
+    // assumes the client only reads the buffers (usually copying them
+    // to external data structures, such as numpy arrays).
+    Location loc = op->getLoc();
+    auto stt = getSparseTensorType(op.getTensor());
+    SmallVector<Value> retVal;
+    SmallVector<Value> retLen;
+    // Get the values buffer first.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+    // Then get the positions and coordinates buffers.
+    const Level lvlRank = stt.getLvlRank();
+    Level trailCOOLen = 0;
+    for (Level l = 0; l < lvlRank; l++) {
+      if (!stt.isUniqueLvl(l) &&
+          (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+        // A `(loose)compressed_nu` level marks the start of the trailing
+        // COO region. Since the target coordinate buffer used for trailing
+        // COO is passed in as an AoS scheme while SparseTensorStorage uses
+        // an SoA scheme, we cannot simply use the internal buffers.
+        trailCOOLen = lvlRank - l;
+        break;
+      }
+      if (stt.isWithPos(l)) {
+        auto poss =
+            genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        retVal.push_back(poss);
+        retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+      }
+      if (stt.isWithCrd(l)) {
+        auto crds =
+            genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+        auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        retVal.push_back(crds);
+        retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
+      }
+    }
+    // Handle AoS vs. SoA mismatch for COO.
+    if (trailCOOLen != 0) {
+      uint64_t cooStartLvl = lvlRank - trailCOOLen;
+      assert(!stt.isUniqueLvl(cooStartLvl) &&
+             (stt.isCompressedLvl(cooStartLvl) ||
+              stt.isLooseCompressedLvl(cooStartLvl)));
+      // Positions.
+      auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
+                                   cooStartLvl);
+      auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      retVal.push_back(poss);
+      retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+      // Coordinates, copied over with:
+      //    for (i = 0; i < crdLen; i++)
+      //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
----------------
aartbik wrote:

I think that is harder, right, since we are coming from two memrefs?
Anyway, we can fix this when generalizing beyond 2-D.
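
To make the AoS/SoA mismatch concrete, here is a minimal standalone C++ sketch of the coordinate copy described in the quoted comment, assuming a 2-D trailing COO region. The function name soaToAos and the vector-based signature are illustrative only and not part of the patch:

    #include <cstdint>
    #include <vector>

    // Interleave two per-level coordinate arrays (SoA) into a single AoS
    // buffer [x0, y0, x1, y1, ...], mirroring the generated loop
    //    for (i = 0; i < crdLen; i++)
    //      buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
    std::vector<uint64_t> soaToAos(const std::vector<uint64_t> &crd0,
                                   const std::vector<uint64_t> &crd1) {
      const size_t crdLen = crd0.size(); // assumes crd0.size() == crd1.size()
      std::vector<uint64_t> buf(2 * crdLen);
      for (size_t i = 0; i < crdLen; i++) {
        buf[2 * i + 0] = crd0[i]; // level-0 coordinate of element i
        buf[2 * i + 1] = crd1[i]; // level-1 coordinate of element i
      }
      return buf;
    }

The element-wise interleave is needed precisely because the coordinates come from two separate memrefs, one per level, so the internal buffers cannot simply be exposed as-is.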

https://github.com/llvm/llvm-project/pull/71880

