[Mlir-commits] [mlir] e8fc282 - [mlir][sparse] avoid non-perm on sparse tensor convert for new (#72459)
llvmlistbot at llvm.org
Wed Nov 15 20:47:41 PST 2023
Author: Aart Bik
Date: 2023-11-15T20:47:37-08:00
New Revision: e8fc282ff26b4d1d71a316bf036fc486b420ea19
URL: https://github.com/llvm/llvm-project/commit/e8fc282ff26b4d1d71a316bf036fc486b420ea19
DIFF: https://github.com/llvm/llvm-project/commit/e8fc282ff26b4d1d71a316bf036fc486b420ea19.diff
LOG: [mlir][sparse] avoid non-perm on sparse tensor convert for new (#72459)
This avoids seeing non-permutation dim-to-lvl maps on the convert from COO to non-COO for
higher-dimensional new operators (viz. reading in BSR).
This is step 1 of 3 toward making sparse_tensor.new work for BSR.
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 811bdc57ce14fb6..3fe0c551be57a4d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- const auto dstTp = getSparseTensorType(op.getResult());
- const auto encDst = dstTp.getEncoding();
- if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
+ auto stt = getSparseTensorType(op.getResult());
+ auto enc = stt.getEncoding();
+ if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();
// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
+ // with enveloping reinterpreted_map ops for non-permutations.
+ RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
- Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
- op, dstTp.getRankedTensorType(), cooTensor);
+ Value convert = cooTensor;
+ if (!stt.isPermutation()) { // demap coo, demap dstTp
+ auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
+ convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
+ dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
+ }
+ convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
+ if (!stt.isPermutation()) // remap to original enc
+ convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
+ rewriter.replaceOp(op, convert);
- // Release the ordered COO tensor.
+ // Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
}
};
+/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();
+
// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,
More information about the Mlir-commits
mailing list