[Mlir-commits] [mlir] [mlir][sparse] avoid non-perm on sparse tensor convert for new (PR #72459)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 15 16:57:53 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
This avoids seeing a non-permutation dimToLvl map on the convert from COO to non-COO for higher-dimensional new operators (viz. reading in BSR).
This is step 1 of 3 to make sparse_tensor.new work for BSR.
---
Full diff: https://github.com/llvm/llvm-project/pull/72459.diff
1 File Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+18-6)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 811bdc57ce14fb6..3fe0c551be57a4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- const auto dstTp = getSparseTensorType(op.getResult());
- const auto encDst = dstTp.getEncoding();
- if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
+ auto stt = getSparseTensorType(op.getResult());
+ auto enc = stt.getEncoding();
+ if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();
// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
+ // with enveloping reinterpreted_map ops for non-permutations.
+ RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
- Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
- op, dstTp.getRankedTensorType(), cooTensor);
+ Value convert = cooTensor;
+ if (!stt.isPermutation()) { // demap coo, demap dstTp
+ auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
+ convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
+ dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
+ }
+ convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
+ if (!stt.isPermutation()) // remap to original enc
+ convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
+ rewriter.replaceOp(op, convert);
- // Release the ordered COO tensor.
+ // Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
}
};
+/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();
+
// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,
``````````
</details>
https://github.com/llvm/llvm-project/pull/72459
More information about the Mlir-commits mailing list