[Mlir-commits] [mlir] 180bf5f - [mlir][sparse] fix a bug in sparse2sparse reshape.

Peiming Liu llvmlistbot at llvm.org
Thu Sep 8 17:32:09 PDT 2022


Author: Peiming Liu
Date: 2022-09-09T00:32:00Z
New Revision: 180bf5f9403d42586aedb374d63c72e75a7b7ce3

URL: https://github.com/llvm/llvm-project/commit/180bf5f9403d42586aedb374d63c72e75a7b7ce3
DIFF: https://github.com/llvm/llvm-project/commit/180bf5f9403d42586aedb374d63c72e75a7b7ce3.diff

LOG: [mlir][sparse] fix a bug in sparse2sparse reshape.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133521

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index ded1a98eed5e..f2967c33705b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -541,14 +541,18 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
 ///     coo->add(reshape(elem.indices), elem.value)
 ///   }
 ///   s = newSparseTensor(coo)
+template <typename ReshapeOp>
 static LogicalResult
-genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
-                        ArrayRef<ReassociationIndices> reassociation, Value src,
-                        RankedTensorType dstTp, RankedTensorType srcTp) {
-  Location loc = op->getLoc();
-  auto encDst = getSparseTensorEncoding(dstTp);
+genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
+                        ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
+  auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
   auto encSrc = getSparseTensorEncoding(srcTp);
-  assert(encDst && encSrc);
+  auto encDst = getSparseTensorEncoding(dstTp);
+  if (!encDst || !encSrc)
+    return failure();
+
   unsigned srcRank = srcTp.getRank();
   unsigned dstRank = dstTp.getRank();
   Type elemTp = srcTp.getElementType();
@@ -560,14 +564,16 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
   SmallVector<Value, 4> sizes;
   SmallVector<Value, 8> params;
-  sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
+  sizesFromSrc(rewriter, sizes, loc, op.getSrc());
   newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
-            src);
+            adaptor.getSrc());
   Value iter = genNewCall(rewriter, loc, params);
   // Start a new COO for the destination tensor.
   sizes.clear();
   params.clear();
-  sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
+  // Fills sizes array using the sizes from destination type.
+  assert(dstTp.hasStaticShape());
+  sizesFromType(rewriter, sizes, loc, dstTp);
   newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
   Value coo = genNewCall(rewriter, loc, params);
   Value dstPerm = params[2];
@@ -586,7 +592,8 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
   // not need to store the value in elemPtr, as the value is still there.
   Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
   rewriter.setInsertionPointToStart(after);
-  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
+                   dstIdx, srcIdx);
   genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
@@ -756,15 +763,7 @@ class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
   LogicalResult
   matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type dstType = op.getResult().getType();
-    Type srcType = op.getSrc().getType();
-    auto encDst = getSparseTensorEncoding(dstType);
-    auto encSrc = getSparseTensorEncoding(srcType);
-    if (encDst && encSrc)
-      return genSparse2SparseReshape(
-          op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
-          dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
-    return failure(); // handled elsewhere
+    return genSparse2SparseReshape(op, adaptor, rewriter);
   }
 };
 


        


More information about the Mlir-commits mailing list