[Mlir-commits] [mlir] [mlir][sparse] avoid non-perm on sparse tensor convert for new (PR #72459)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 15 16:57:53 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

This avoids exposing a non-permutation dimension-to-level map to the convert from COO to non-COO for higher-dimensional `sparse_tensor.new` operators (viz. reading in BSR).

This is step 1 of 3 to make `sparse_tensor.new` work for BSR.

---
Full diff: https://github.com/llvm/llvm-project/pull/72459.diff


1 File Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+18-6) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 811bdc57ce14fb6..3fe0c551be57a4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   LogicalResult matchAndRewrite(NewOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    const auto dstTp = getSparseTensorType(op.getResult());
-    const auto encDst = dstTp.getEncoding();
-    if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
+    auto stt = getSparseTensorType(op.getResult());
+    auto enc = stt.getEncoding();
+    if (!stt.hasEncoding() || getCOOStart(enc) == 0)
       return failure();
 
     // Implement the NewOp as follows:
     //   %orderedCoo = sparse_tensor.new %filename
     //   %t = sparse_tensor.convert %orderedCoo
+    // with enveloping reinterpreted_map ops for non-permutations.
+    RankedTensorType dstTp = stt.getRankedTensorType();
     RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
-    Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
-        op, dstTp.getRankedTensorType(), cooTensor);
+    Value convert = cooTensor;
+    if (!stt.isPermutation()) { // demap coo, demap dstTp
+      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
+      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
+      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
+    }
+    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
+    if (!stt.isPermutation()) // remap to original enc
+      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
+    rewriter.replaceOp(op, convert);
 
-    // Release the ordered COO tensor.
+    // Release the temporary ordered COO tensor.
     rewriter.setInsertionPointAfterValue(convert);
     rewriter.create<DeallocTensorOp>(loc, cooTensor);
 
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   }
 };
 
+/// Sparse rewriting rule for the out operator.
 struct OutRewriter : public OpRewritePattern<OutOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
                                     primaryTypeFunctionSuffix(eltTp)};
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     ModuleOp module = op->getParentOfType<ModuleOp>();
+
     // For each element in the source tensor, output the element.
     rewriter.create<ForeachOp>(
         loc, src, std::nullopt,

``````````

</details>


https://github.com/llvm/llvm-project/pull/72459


More information about the Mlir-commits mailing list