[Mlir-commits] [mlir] c24547e - [mlir][sparse] avoid creating a temporary unordered COO buffer when reshaping sparse tensors.

Peiming Liu llvmlistbot at llvm.org
Wed Mar 29 18:30:01 PDT 2023


Author: Peiming Liu
Date: 2023-03-30T01:29:55Z
New Revision: c24547e969ef183971c3a02cb8cbf151eb529715

URL: https://github.com/llvm/llvm-project/commit/c24547e969ef183971c3a02cb8cbf151eb529715
DIFF: https://github.com/llvm/llvm-project/commit/c24547e969ef183971c3a02cb8cbf151eb529715.diff

LOG: [mlir][sparse] avoid creating a temporary unordered COO buffer when reshaping sparse tensors.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D147192

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/sparse_reshape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7827573b215a1..dc5755bcf5012 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -358,9 +358,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     Value srcTensor = op.getSrc();
     auto srcTp = getRankedTensorType(srcTensor);
     auto dstTp = getRankedTensorType(op.getResult());
-    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
-    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
-    if (!encDst || !encSrc) {
+
+    SparseTensorType srcStt(srcTp);
+    SparseTensorType dstStt(dstTp);
+
+    const auto encSrc = srcStt.getEncoding();
+    if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
       return failure();
     }
 
@@ -382,22 +385,29 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           dstDynSizes.push_back(dstSizes[idx]);
       }
     }
-
-    // Implement the sparse2sparse reshape as follows:
-    //   %tmp = bufferization.alloc_tensor : unordered COO
-    //   foreach srcCoords %srcTensor
-    //     insert reshapeCvs(srcCoords), %tmp
-    //   %t = sparse_tensor.cast %tmp
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
-    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
-    Value cooBuffer =
+    // Only need a unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
+            ? dstTp
+            : getUnorderedCOOFromType(dstTp);
+
+    Value buffer =
         rewriter
-            .create<AllocTensorOp>(loc, cooTp, dstDynSizes, Value(),
+            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                    /*sizeHint=*/nnz, Attribute())
             .getResult();
 
+    // Implement the sparse2sparse reshape as follows:
+    //   foreach srcCoords %srcTensor
+    //     insert reshapeCvs(srcCoords), %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
-        loc, srcTensor, cooBuffer,
+        loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
           const Dimension dimRank = srcTp.getRank();
@@ -414,10 +424,14 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
           builder.create<sparse_tensor::YieldOp>(loc, t);
         });
-    auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    auto converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
-    rewriter.create<DeallocTensorOp>(loc, t);
-    rewriter.replaceOp(op, converted);
+
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 6eb754db9ce3d..49eee201fc323 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -61,8 +61,8 @@
 // CHECK-RWT:           scf.yield %[[NT:.*]]
 // CHECK-RWT:         }
 // CHECK-RWT:         %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:         %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:         return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT:     sparse_tensor.convert
+// CHECK-RWT:         return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 //
 func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -134,8 +134,8 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
 // CHECK-RWT            scf.yield %[[RET_1]]
 // CHECK-RWT:         }
 // CHECK-RWT:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:        %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:        return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT:    sparse_tensor.convert
+// CHECK-RWT:        return %[[NT1]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
 //
 func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :
@@ -209,8 +209,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-RWT:           scf.yield %[[NT]]
 // CHECK-RWT:         }
 // CHECK-RWT:         %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:         %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:         return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT:     sparse_tensor.convert
+// CHECK-RWT:         return %[[NT1]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 //
 func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
   %0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -291,8 +291,8 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
 // CHECK-RWT            scf.yield %[[RET_1]]
 // CHECK-RWT:        }
 // CHECK-RWT:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT:        %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT:        return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT:    sparse_tensor.convert
+// CHECK-RWT:        return %[[NT1]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
 //
 func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] :


        


More information about the Mlir-commits mailing list