[Mlir-commits] [mlir] faa00c1 - [mlir][sparse] implement sparse2sparse reshaping (expand/collapse)

Aart Bik llvmlistbot at llvm.org
Mon Jul 11 14:49:14 PDT 2022


Author: Aart Bik
Date: 2022-07-11T14:49:06-07:00
New Revision: faa00c131351725d8db74bac6a06459430344455

URL: https://github.com/llvm/llvm-project/commit/faa00c131351725d8db74bac6a06459430344455
DIFF: https://github.com/llvm/llvm-project/commit/faa00c131351725d8db74bac6a06459430344455.diff

LOG: [mlir][sparse] implement sparse2sparse reshaping (expand/collapse)

A previous revision implemented expand/collapse reshaping between
dense and sparse tensors for sparse2dense and dense2sparse since those
could use the "cheap" view reshape on the already materialized
dense tensor (at either the input or output side), and do some
reshuffling from or to sparse. The dense2dense case, as always,
is handled with a "cheap" view change.

This revision implements the sparse2sparse cases. Lacking any "view"
support on sparse tensors this operation necessarily has to perform
data reshuffling on both ends.

Tracker for improving this:
https://github.com/llvm/llvm-project/issues/56477

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D129416

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/rewriting.mlir
    mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index dae378490d2f7..282af7aed2df5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -238,7 +238,7 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
 /// the following and the insertion point after this routine is inside the
 /// if-then branch behind the assignment to ind. This is to ensure that the
 /// addEltX call generated after is inside the if-then branch.
-///    if (tensor[ivs]!=0) {
+///    if (tensor[ivs] != 0)
 ///      ind = ivs
 static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
                                       Value tensor, Value ind, ValueRange ivs) {
@@ -382,6 +382,133 @@ static bool canUseDirectConversion(
   return true;
 }
 
+/// Helper method to translate indices during a reshaping operation.
+/// Emits IR that reads the source indices from `srcIdx`, maps them through
+/// `reassociation`, and stores the resulting indices into `dstIdx`:
+///   - collapse (srcRank > dstRank): every group of source indices is
+///     linearized into one destination index (sum of index * stride);
+///   - expand (otherwise): every source index is delinearized into a group
+///     of destination indices (successive divide/remainder by the stride).
+/// All dimension sizes inside a reassociation group must be static.
+/// TODO: provide as general utility to MLIR at large?
+static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
+                             ArrayRef<ReassociationIndices> reassociation,
+                             TensorType dstTp, TensorType srcTp, Value dstIdx,
+                             Value srcIdx) {
+  unsigned dstRank = dstTp.getRank();
+  unsigned srcRank = srcTp.getRank();
+  unsigned start = 0;
+  unsigned i = 0;
+  // A collapse goes from the higher rank to the lower rank. The higher-ranked
+  // side supplies the per-dimension sizes from which strides are derived.
+  bool isCollapse = srcRank > dstRank;
+  ArrayRef<int64_t> shape = isCollapse ? srcTp.getShape() : dstTp.getShape();
+  // Iterate over reassociation map.
+  for (const auto &map : llvm::enumerate(reassociation)) {
+    // Prepare strides information in dimension slice.
+    uint64_t linear = 1;
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      assert(!ShapedType::isDynamic(shape[j]));
+      linear *= shape[j];
+    }
+    // Start expansion: load the single source index that is split below.
+    Value idx = constantIndex(rewriter, loc, i++);
+    Value val;
+    if (!isCollapse)
+      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
+    // Iterate over dimension slice.
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear /= shape[j];
+      Value stride = constantIndex(rewriter, loc, linear);
+      Value jdx = constantIndex(rewriter, loc, j);
+      if (isCollapse) {
+        // Accumulate srcIdx[j] * stride; a stride of 1 needs no multiply.
+        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
+        Value mul = linear == 1
+                        ? old
+                        : rewriter.create<arith::MulIOp>(loc, old, stride);
+        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
+      } else {
+        // Store the quotient as dstIdx[j] and carry the remainder over to
+        // the next dimension; a stride of 1 stores the value unchanged.
+        Value old = val;
+        if (linear != 1)
+          val = rewriter.create<arith::DivUIOp>(loc, val, stride);
+        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
+        if (linear != 1)
+          val = rewriter.create<arith::RemUIOp>(loc, old, stride);
+      }
+    }
+    // Finalize collapse: store the linearized index for this group.
+    if (isCollapse)
+      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
+    start += map.value().size();
+  }
+  // Sanity: one reassociation group per dimension of the lower-ranked side.
+  assert((isCollapse && i == dstRank) || (!isCollapse && i == srcRank));
+}
+
+/// Generate code for a general sparse to sparse reshaping operation.
+/// Note that unlike dense reshaping (which can be done with a "cheap"
+/// change of view), sparse reshaping is currently done with actual
+/// data shuffling.
+///
+/// TODO: proportional to nnz, but still a lot of data movement
+///       https://github.com/llvm/llvm-project/issues/56477
+///
+///   iter = src->toCOO();
+///   coo = newSparseCOO()
+///   while (elem = iter->getNext()) {
+///     coo->add(reshape(elem.indices), elem.value)
+///   }
+///   s = newSparseTensor(coo)
+///
+/// Both `srcTp` and `dstTp` must carry a sparse encoding and share the same
+/// element type (asserted below). `op` is replaced by the newly constructed
+/// destination tensor and success() is returned.
+static LogicalResult
+genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
+                        ArrayRef<ReassociationIndices> reassociation, Value src,
+                        RankedTensorType dstTp, RankedTensorType srcTp) {
+  Location loc = op->getLoc();
+  auto encDst = getSparseTensorEncoding(dstTp);
+  auto encSrc = getSparseTensorEncoding(srcTp);
+  assert(encDst && encSrc);
+  unsigned srcRank = srcTp.getRank();
+  unsigned dstRank = dstTp.getRank();
+  Type elemTp = srcTp.getElementType();
+  assert(elemTp == dstTp.getElementType() &&
+         "reshape should not change element type");
+  // Start an iterator over the source tensor (in original index order).
+  // The iterator encoding reuses the source level types and bit widths but
+  // passes an empty AffineMap in place of the source's own map.
+  auto noPerm = SparseTensorEncodingAttr::get(
+      op->getContext(), encSrc.getDimLevelType(), AffineMap(),
+      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+  SmallVector<Value, 4> sizes;
+  SmallVector<Value, 8> params;
+  sizesFromPtr(rewriter, sizes, op, noPerm, srcTp, src);
+  newParams(rewriter, params, op, srcTp, noPerm, Action::kToIterator, sizes,
+            src);
+  Value iter = genNewCall(rewriter, op, params);
+  // Start a new COO for the destination tensor.
+  sizes.clear();
+  params.clear();
+  // NOTE(review): the destination sizes are requested with `src` as the
+  // tensor argument; presumably it is only consulted for dynamic dimensions,
+  // which translateIndices does not support anyway -- confirm.
+  sizesFromPtr(rewriter, sizes, op, encDst, dstTp, src);
+  newParams(rewriter, params, op, dstTp, encDst, Action::kEmptyCOO, sizes);
+  Value coo = genNewCall(rewriter, op, params);
+  // NOTE(review): params[2] is presumably the destination permutation buffer
+  // set up by newParams -- confirm against newParams' parameter layout.
+  Value dstPerm = params[2];
+  // Construct a while loop over the iterator.
+  // srcIdx/dstIdx are scratch index buffers, one slot per dimension; elemPtr
+  // is the scalar slot handed to the getNext call.
+  Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
+  Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
+  Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+  // The loop carries no values: no operands, no result types.
+  SmallVector<Value> noArgs;
+  SmallVector<Type> noTypes;
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
+  // "before" region: fetch the next element; its result is the loop condition.
+  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
+  rewriter.setInsertionPointToEnd(before);
+  Value cond = genGetNextCall(rewriter, op, iter, srcIdx, elemPtr);
+  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
+  // Translate indices from source to target and insert. Note that we do
+  // not need to store the value in elemPtr, as the value is still there.
+  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
+  rewriter.setInsertionPointToStart(after);
+  translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
+  genAddEltCall(rewriter, op, elemTp, coo, elemPtr, dstIdx, dstPerm);
+  rewriter.create<scf::YieldOp>(loc);
+  // Final call to construct sparse tensor storage and free temporary resources.
+  rewriter.setInsertionPointAfter(whileOp);
+  // Reuse the destination parameters, switching the action to kFromCOO and
+  // passing the filled COO as the final argument.
+  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
+  params[7] = coo;
+  Value dst = genNewCall(rewriter, op, params);
+  // Both the destination COO and the source iterator are released through
+  // the same delete-COO call.
+  genDelCOOCall(rewriter, op, elemTp, coo);
+  genDelCOOCall(rewriter, op, elemTp, iter);
+  rewriter.replaceOp(op, dst);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -423,6 +550,7 @@ class SparseTensorToDimSizeConverter
 
 /// Sparse conversion rule for trivial tensor casts.
 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -437,8 +565,30 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+/// Sparse conversion rule for a reshape operator (expand or collapse).
+/// Only the sparse2sparse case is lowered here; any reshape with a dense
+/// source or a dense result is rejected so another rule can handle it.
+template <typename ReshapeOp>
+class SparseReshapeConverter : public OpConversionPattern<ReshapeOp> {
+public:
+  using OpAdaptor = typename OpConversionPattern<ReshapeOp>::OpAdaptor;
+  using OpConversionPattern<ReshapeOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.getSrc().getType();
+    Type dstType = op.getResult().getType();
+    auto srcEncoding = getSparseTensorEncoding(srcType);
+    auto dstEncoding = getSparseTensorEncoding(dstType);
+    // Bail out unless both sides are sparse.
+    if (!dstEncoding || !srcEncoding)
+      return failure(); // handled elsewhere
+    // Use the already-converted source operand from the adaptor.
+    Value convertedSrc = adaptor.getOperands()[0];
+    return genSparse2SparseReshape(
+        op, rewriter, op.getReassociationIndices(), convertedSrc,
+        dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
+  }
+};
+
 /// Sparse conversion rule for the new operator.
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
+public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(NewOp op, OpAdaptor adaptor,
@@ -463,6 +613,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 /// Sparse conversion rule for the alloc operator.
 class SparseTensorAllocConverter
     : public OpConversionPattern<bufferization::AllocTensorOp> {
+public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
@@ -494,9 +645,6 @@ class SparseTensorAllocConverter
 
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
-  /// Options to control sparse code generation.
-  SparseTensorConversionOptions options;
-
 public:
   using OpConversionPattern::OpConversionPattern;
   SparseTensorConvertConverter(MLIRContext *context,
@@ -697,6 +845,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     rewriter.replaceOp(op, dst);
     return success();
   }
+
+private:
+  /// Options to control sparse code generation.
+  SparseTensorConversionOptions options;
 };
 
 /// Sparse conversion rule for the release operator.
@@ -799,6 +951,7 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
   }
 };
 
+/// Sparse conversion rule for the expand operator.
 class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -841,6 +994,7 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
   }
 };
 
+/// Sparse conversion rule for the compress operator.
 class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -873,6 +1027,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
+/// Sparse conversion rule for the output operator.
 class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -926,6 +1081,8 @@ void mlir::populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
                SparseCastConverter, SparseTensorNewConverter,
+               SparseReshapeConverter<tensor::ExpandShapeOp>,
+               SparseReshapeConverter<tensor::CollapseShapeOp>,
                SparseTensorAllocConverter, SparseTensorReleaseConverter,
                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
                SparseTensorToValuesConverter, SparseTensorLoadConverter,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f85f47a29ec94..1f157eab3c57d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -127,13 +127,11 @@ struct SparseTensorConversionPass
         });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
-                      arith::IndexCastOp, complex::ConstantOp,
-                      complex::NotEqualOp, linalg::FillOp, linalg::YieldOp,
-                      tensor::ExtractOp>();
-    target
-        .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
-                         memref::MemRefDialect, scf::SCFDialect>();
+    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+                      linalg::YieldOp, tensor::ExtractOp>();
+    target.addLegalDialect<
+        arith::ArithmeticDialect, bufferization::BufferizationDialect,
+        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a5364ef6c601..d892aa67f8c85 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1832,71 +1832,38 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
   SparsificationOptions options;
 };
 
-/// Sparse rewriting rule for expand shape operator.
-struct ExpandShapeRewriter : public OpRewritePattern<tensor::ExpandShapeOp> {
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 public:
-  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::ExpandShapeOp op,
+  LogicalResult matchAndRewrite(ReshapeOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     auto encDst = getSparseTensorEncoding(op.getResult().getType());
     auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
     // Since a pure dense expansion is very cheap (change of view), for
-    // sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the actual expansion operation itself.
+    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+    // conversion from the reshape operation itself.
+    // All other cases are handled elsewhere.
     if (encDst && encSrc) {
-      return failure(); // TODO: implement sparse2sparse
-    } else if (encSrc) {
-      RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
-      return success();
-    } else if (encDst) {
-      RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
-      auto denseTp =
-          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<tensor::ExpandShapeOp>(
-          loc, denseTp, op.getSrc(), op.getReassociation());
-      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
-      rewriter.replaceOp(op, convert);
-      return success();
-    }
-    return failure();
-  }
-};
-
-/// Sparse rewriting rule for collapse shape operator.
-struct CollapseShapeRewriter
-    : public OpRewritePattern<tensor::CollapseShapeOp> {
-public:
-  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::CollapseShapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    auto encDst = getSparseTensorEncoding(op.getResult().getType());
-    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
-    // Since a pure dense collapse is very cheap (change of view), for
-    // sparse2dense or dense2sparse, we can simply unfuse a sparse
-    // conversion from the actual collapse operation itself.
-    if (encDst && encSrc) {
-      return failure(); // TODO: implement sparse2sparse
+      return failure();
     } else if (encSrc) {
-      RankedTensorType rtp = op.getSrc().getType().cast<RankedTensorType>();
+      RankedTensorType rtp =
+          op.getSrc().getType().template cast<RankedTensorType>();
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
       op->setOperand(0, convert);
       return success();
     } else if (encDst) {
-      RankedTensorType rtp = op.getResult().getType().cast<RankedTensorType>();
+      RankedTensorType rtp =
+          op.getResult().getType().template cast<RankedTensorType>();
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<tensor::CollapseShapeOp>(
-          loc, denseTp, op.getSrc(), op.getReassociation());
+      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                                op.getReassociation());
       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
       rewriter.replaceOp(op, convert);
       return success();
@@ -1912,6 +1879,6 @@ struct CollapseShapeRewriter
 void mlir::populateSparsificationPatterns(
     RewritePatternSet &patterns, const SparsificationOptions &options) {
   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
-  patterns.add<ExpandShapeRewriter, CollapseShapeRewriter>(
-      patterns.getContext());
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
old mode 100644
new mode 100755
index 3955310fce9be..000c3560f1e0f
--- a/mlir/test/Dialect/SparseTensor/rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -40,8 +40,14 @@ func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMat
   return %0 : tensor<3x4xf64, #SparseMatrix>
 }
 
-// TODO: make this work
+//
+// Not rewritten, needs conversion.
+//
 // CHECK-LABEL:   func.func @expand_sparse2sparse(
+// CHECK-SAME:    %[[A:.*]]: tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK:         %[[E:.*]] = tensor.expand_shape %[[A]] {{.*}} : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:         return %[[E]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       }
 func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
   return %0 : tensor<3x4xf64, #SparseMatrix>
@@ -79,8 +85,14 @@ func.func @collapse_to_sparse(%arg0: tensor<3x4xf64>) -> tensor<12xf64, #SparseV
   return %0 : tensor<12xf64, #SparseVector>
 }
 
-// TODO: make this work
+//
+// Not rewritten, needs conversion.
+//
 // CHECK-LABEL:   func.func @collapse_sparse2sparse(
+// CHECK-SAME:    %[[A:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK:         %[[C:.*]] = tensor.collapse_shape %[[A]] {{.*}} : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:         return %[[C]] : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       }
 func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
   return %0 : tensor<12xf64, #SparseVector>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index c791536e15197..65eb56b9bac37 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,24 +1,81 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-
-// TODO: check lowering to an actual implementation
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
+// RUN: mlir-opt %s --sparse-tensor-conversion --cse | FileCheck %s --check-prefix=CHECK-CONV
 
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
 #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 
-// CHECK-LABEL: func.func @sparse_expand(
-// CHECK-SAME:  %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
-//      CHECK:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
-//      CHECK:  return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// roundtrip:
+//
+// CHECK-ROUND-LABEL: func.func @sparse_expand(
+// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK-ROUND:  return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// conversion: the single 1-D source index is split into two 2-D indices by
+// div/rem with the inner dimension size (10) before insertion into the COO.
+// SSA names are matched with patterns rather than hard-coded value numbers.
+//
+// CHECK-CONV-LABEL: func.func @sparse_expand(
+// CHECK-CONV-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONV-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONV-DAG:  %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CONV-DAG:  call @newSparseTensor
+// CHECK-CONV-DAG:  call @newSparseTensor
+// CHECK-CONV:      scf.while : () -> () {
+// CHECK-CONV:        call @getNextF64
+// CHECK-CONV:        scf.condition(%{{.*}})
+// CHECK-CONV:      } do {
+// CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV:        %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
+// CHECK-CONV:        memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV:        %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV:        memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<?xindex>
+// CHECK-CONV:        call @addEltF64
+// CHECK-CONV:        scf.yield
+// CHECK-CONV:      }
+// CHECK-CONV:      %[[N:.*]] = call @newSparseTensor
+// CHECK-CONV:      call @delSparseTensorCOOF64
+// CHECK-CONV:      call @delSparseTensorCOOF64
+// CHECK-CONV:      return %[[N]] : !llvm.ptr<i8>
+//
 func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
     tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
   return %0 : tensor<10x10xf64, #SparseMatrix>
 }
 
-// CHECK-LABEL: func.func @sparse_collapse(
-// CHECK-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
-//      CHECK:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
-//      CHECK:  return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// roundtrip:
+//
+// CHECK-ROUND-LABEL: func.func @sparse_collapse(
+// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK-ROUND:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK-ROUND:  return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//
+// conversion: the two 2-D source indices (i, j) are linearized into the
+// single 1-D index i * 10 + j before insertion into the COO.
+// SSA names are matched with patterns rather than hard-coded value numbers.
+//
+// CHECK-CONV-LABEL: func.func @sparse_collapse(
+// CHECK-CONV-DAG:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONV-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONV-DAG:  %[[C10:.*]] = arith.constant 10 : index
+// CHECK-CONV-DAG:  call @newSparseTensor
+// CHECK-CONV-DAG:  call @newSparseTensor
+// CHECK-CONV:      scf.while : () -> () {
+// CHECK-CONV:        call @getNextF64
+// CHECK-CONV:        scf.condition(%{{.*}})
+// CHECK-CONV:      } do {
+// CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV:        %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
+// CHECK-CONV:        %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xindex>
+// CHECK-CONV:        %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
+// CHECK-CONV:        memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<?xindex>
+// CHECK-CONV:        call @addEltF64
+// CHECK-CONV:        scf.yield
+// CHECK-CONV:      }
+// CHECK-CONV:      %[[N:.*]] = call @newSparseTensor
+// CHECK-CONV:      call @delSparseTensorCOOF64
+// CHECK-CONV:      call @delSparseTensorCOOF64
+// CHECK-CONV:      return %[[N]] : !llvm.ptr<i8>
+//
 func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :
     tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
old mode 100644
new mode 100755
index eed880866884d..57d2a931fad5d
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -32,11 +32,10 @@ module {
     return %0 : tensor<3x4xf64, #SparseMatrix>
   }
 
-// TODO: make this work
-//  func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
-//    %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
-//    return %0 : tensor<3x4xf64, #SparseMatrix>
-//  }
+  func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
+    %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
+    return %0 : tensor<3x4xf64, #SparseMatrix>
+  }
 
   func.func @collapse_dense(%arg0: tensor<3x4xf64>) -> tensor<12xf64> {
     %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64>
@@ -53,11 +52,10 @@ module {
     return %0 : tensor<12xf64, #SparseVector>
   }
 
-// TODO: make this work
-//  func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
-//    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
-//    return %0 : tensor<12xf64, #SparseVector>
-//  }
+  func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
+    return %0 : tensor<12xf64, #SparseVector>
+  }
 
 
   //
@@ -81,10 +79,12 @@ module {
     %expand0 = call @expand_dense(%v) : (tensor<12xf64>) -> tensor<3x4xf64>
     %expand1 = call @expand_from_sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64>
     %expand2 = call @expand_to_sparse(%v) : (tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix>
+    %expand3 = call @expand_sparse2sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix>
 
     %collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64>
     %collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64>
     %collapse2 = call @collapse_to_sparse(%m) : (tensor<3x4xf64>) -> tensor<12xf64, #SparseVector>
+    %collapse3 = call @collapse_sparse2sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
 
     //
     // Verify result.
@@ -92,9 +92,11 @@ module {
     // CHECK:      ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
     // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 )
     // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
     // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
     // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
     //
     %m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64>
     vector.print %m0 : vector<3x4xf64>
@@ -103,6 +105,9 @@ module {
     %a2 = sparse_tensor.values %expand2 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64>
     %m2 = vector.transfer_read %a2[%c0], %df: memref<?xf64>, vector<16xf64>
     vector.print %m2 : vector<16xf64>
+    %a3 = sparse_tensor.values %expand3 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64>
+    %m3 = vector.transfer_read %a3[%c0], %df: memref<?xf64>, vector<16xf64>
+    vector.print %m3 : vector<16xf64>
 
     %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64>
     vector.print %v0 : vector<12xf64>
@@ -111,12 +116,17 @@ module {
     %b2 = sparse_tensor.values %collapse2 : tensor<12xf64, #SparseVector> to memref<?xf64>
     %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<16xf64>
     vector.print %v2 : vector<16xf64>
+    %b3 = sparse_tensor.values %collapse3 : tensor<12xf64, #SparseVector> to memref<?xf64>
+    %v3 = vector.transfer_read %b3[%c0], %df: memref<?xf64>, vector<16xf64>
+    vector.print %v3 : vector<16xf64>
 
     // Release sparse resources.
     sparse_tensor.release %sv : tensor<12xf64, #SparseVector>
     sparse_tensor.release %sm : tensor<3x4xf64, #SparseMatrix>
     sparse_tensor.release %expand2 : tensor<3x4xf64, #SparseMatrix>
+    sparse_tensor.release %expand3 : tensor<3x4xf64, #SparseMatrix>
     sparse_tensor.release %collapse2 : tensor<12xf64, #SparseVector>
+    sparse_tensor.release %collapse3 : tensor<12xf64, #SparseVector>
 
     // Release dense resources.
     %meme1 = bufferization.to_memref %expand1 : memref<3x4xf64>


        


More information about the Mlir-commits mailing list