[Mlir-commits] [mlir] 180bf5f - [mlir][sparse] fix a bug in sparse2sparse reshape.
Peiming Liu
llvmlistbot at llvm.org
Thu Sep 8 17:32:09 PDT 2022
Author: Peiming Liu
Date: 2022-09-09T00:32:00Z
New Revision: 180bf5f9403d42586aedb374d63c72e75a7b7ce3
URL: https://github.com/llvm/llvm-project/commit/180bf5f9403d42586aedb374d63c72e75a7b7ce3
DIFF: https://github.com/llvm/llvm-project/commit/180bf5f9403d42586aedb374d63c72e75a7b7ce3.diff
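LOG: [mlir][sparse] fix a bug in sparse2sparse reshape.
The sizes of the destination COO were previously obtained by querying the source tensor (using the destination encoding), although the source and destination of a reshape have different ranks. They are now taken from the (static) destination type. The source sizes are now obtained with sizesFromSrc. genSparse2SparseReshape is templated on the reshape op and returns failure() instead of asserting when either side lacks a sparse encoding.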
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133521
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index ded1a98eed5e..f2967c33705b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -541,14 +541,18 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
/// coo->add(reshape(elem.indices), elem.value)
/// }
/// s = newSparseTensor(coo)
+template <typename ReshapeOp>
static LogicalResult
-genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
- ArrayRef<ReassociationIndices> reassociation, Value src,
- RankedTensorType dstTp, RankedTensorType srcTp) {
- Location loc = op->getLoc();
- auto encDst = getSparseTensorEncoding(dstTp);
+genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+ auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
+ auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
auto encSrc = getSparseTensorEncoding(srcTp);
- assert(encDst && encSrc);
+ auto encDst = getSparseTensorEncoding(dstTp);
+ if (!encDst || !encSrc)
+ return failure();
+
unsigned srcRank = srcTp.getRank();
unsigned dstRank = dstTp.getRank();
Type elemTp = srcTp.getElementType();
@@ -560,14 +564,16 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
- sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
+ sizesFromSrc(rewriter, sizes, loc, op.getSrc());
newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
- src);
+ adaptor.getSrc());
Value iter = genNewCall(rewriter, loc, params);
// Start a new COO for the destination tensor.
sizes.clear();
params.clear();
- sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
+ // Fills sizes array using the sizes from destination type.
+ assert(dstTp.hasStaticShape());
+ sizesFromType(rewriter, sizes, loc, dstTp);
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
Value coo = genNewCall(rewriter, loc, params);
Value dstPerm = params[2];
@@ -586,7 +592,8 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
// not need to store the value in elemPtr, as the value is still there.
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
rewriter.setInsertionPointToStart(after);
- translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+ translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
+ dstIdx, srcIdx);
genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
rewriter.create<scf::YieldOp>(loc);
// Final call to construct sparse tensor storage and free temporary resources.
@@ -756,15 +763,7 @@ class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
LogicalResult
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type dstType = op.getResult().getType();
- Type srcType = op.getSrc().getType();
- auto encDst = getSparseTensorEncoding(dstType);
- auto encSrc = getSparseTensorEncoding(srcType);
- if (encDst && encSrc)
- return genSparse2SparseReshape(
- op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
- dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
- return failure(); // handled elsewhere
+ return genSparse2SparseReshape(op, adaptor, rewriter);
}
};
More information about the Mlir-commits
mailing list