[Mlir-commits] [mlir] [mlir][sparse] implementing stageSparseOpPass as an interface (PR #69022)

Yinying Li llvmlistbot at llvm.org
Mon Oct 16 10:14:15 PDT 2023

@@ -854,94 +900,54 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
     // foreach in %s1 : insert d0, d1, %tmp
     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
-    // %t = convert_to_dest_tensor(%tmp)
-    //
-    // NOTE: this cannot be `const` because it will be changed when
-    // `needTmpCOO`, but that's buried in the conditional below and
-    // thus not easily extracted.
-    auto encDst = dstTp.getEncoding();
-    Value dst; // Destination tensor for inserting source tensor values.
-    bool needTmpCOO = true;
-    const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
-    Value annotatedDenseDst;
-    if (dstTp.hasEncoding()) {
-      bool allOrdered = false;
-      // When concatenating on dimension 0, and all inputs are sorted
-      // and have an identity dimToLvl, the concatenate will generate
-      // coords in lexOrder thus no need for the tmp COO buffer.
-      // TODO: When conDim != 0, as long as conDim is the first dimension
-      // in all input/output buffers, and all input/output buffers have the same
-      // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
-      // CSC matrices along column).
-      if (!allDense && conDim == 0 && dstTp.isIdentity()) {
-        for (auto i : op.getInputs()) {
-          const auto stt = getSparseTensorType(i);
-          allOrdered = stt.isAllOrdered() && stt.isIdentity();
-          if (!allOrdered)
-            break;
-        }
-      }
-      needTmpCOO = !allDense && !allOrdered;
-      const RankedTensorType tp = getBufferType(dstTp, needTmpCOO);
-      encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
-      SmallVector<Value> dynSizes;
-      getDynamicSizes(dstTp, sizes, dynSizes);
-      dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
-      if (allDense) {
-        // Create a view of the values buffer to match the unannotated dense
-        // tensor.
-        Value valuesBuffer = genToValues(rewriter, loc, dst);
-        Value dimCoords =
-            genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(),
-                      /*staticShape=*/true);
-        annotatedDenseDst = dst;
-        dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer,
-                                    dimCoords);
-      }
-    } else {
-      // TODO: Dense buffers should be allocated/deallocated via the callback
-      // in BufferizationOptions.
-      dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
-    }
+    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    SmallVector<Value> initArgs;
-    if (encDst && !allDense)
-      initArgs.push_back(dst);
+    Value iterArg = dstBuf.getSSA();
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Build a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, initArgs,
+          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
             for (Dimension d = 0; d < dimRank; d++) {
               Value crd = dcvs[d];
+              // Transform coordinates for the concatenating dim.
               if (d == conDim)
-                // Transform coordinates for the concatenating dim.
                 crd = builder.create<arith::AddIOp>(loc, crd, offset);
               // FIXME: `toStoredDim` is deprecated
-              dstLcvs[toStoredDim(encDst, d)] = crd;
+              dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
-            if (encDst && !allDense) {
-              Value cond = genIsNonzero(rewriter, loc, v);
-              scf::IfOp ifOp = builder.create<scf::IfOp>(
-                  loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+            if (!reduc.empty())
+              dstBuf.updateSSA(reduc.front());
+            if (!dstTp.isAllDense()) {
+              Value cond = genIsNonzero(builder, loc, v);
+              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
+                                                    /*else*/ true);
+              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
-              Value t =
-                  builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
-              rewriter.create<scf::YieldOp>(loc, t);
-              rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              rewriter.create<scf::YieldOp>(loc, reduc.front());
-              rewriter.setInsertionPointAfter(ifOp);
-              rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              // Exits the ifOp, update the sparse tensor SSA value.
yinying-lisa-li wrote:

Should this be "Exit" rather than "Exits"? That way it is consistent with the imperative "update" later in the same comment.


