[Mlir-commits] [mlir] c24547e - [mlir][sparse] avoid creating temporary unordered COO buffer when reshaping sparse tensors.
Peiming Liu
llvmlistbot at llvm.org
Wed Mar 29 18:30:01 PDT 2023
Author: Peiming Liu
Date: 2023-03-30T01:29:55Z
New Revision: c24547e969ef183971c3a02cb8cbf151eb529715
URL: https://github.com/llvm/llvm-project/commit/c24547e969ef183971c3a02cb8cbf151eb529715
DIFF: https://github.com/llvm/llvm-project/commit/c24547e969ef183971c3a02cb8cbf151eb529715.diff
LOG: [mlir][sparse] avoid creating temporary unordered COO buffer when reshaping sparse tensors.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D147192
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7827573b215a1..dc5755bcf5012 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -358,9 +358,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
Value srcTensor = op.getSrc();
auto srcTp = getRankedTensorType(srcTensor);
auto dstTp = getRankedTensorType(op.getResult());
- SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
- SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
- if (!encDst || !encSrc) {
+
+ SparseTensorType srcStt(srcTp);
+ SparseTensorType dstStt(dstTp);
+
+ const auto encSrc = srcStt.getEncoding();
+ if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
return failure();
}
@@ -382,22 +385,29 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
dstDynSizes.push_back(dstSizes[idx]);
}
}
-
- // Implement the sparse2sparse reshape as follows:
- // %tmp = bufferization.alloc_tensor : unordered COO
- // foreach srcCoords %srcTensor
- // insert reshapeCvs(srcCoords), %tmp
- // %t = sparse_tensor.cast %tmp
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
- RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
- Value cooBuffer =
+ // Only need an unordered COO buffer if input and output are not sorted
+ // in the same way.
+ Type bufferTp =
+ srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
+ ? dstTp
+ : getUnorderedCOOFromType(dstTp);
+
+ Value buffer =
rewriter
- .create<AllocTensorOp>(loc, cooTp, dstDynSizes, Value(),
+ .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
/*sizeHint=*/nnz, Attribute())
.getResult();
+ // Implement the sparse2sparse reshape as follows:
+ // foreach srcCoords %srcTensor
+ // insert reshapeCvs(srcCoords), %buffer
+ //
+ // followed by an optional
+ // %t = sparse_tensor.convert %buffer
+ // depending on whether the input/output are sorted in the same way.
ForeachOp foreachOp = rewriter.create<ForeachOp>(
- loc, srcTensor, cooBuffer,
+ loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
const Dimension dimRank = srcTp.getRank();
@@ -414,10 +424,14 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
- auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
- auto converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
- rewriter.create<DeallocTensorOp>(loc, t);
- rewriter.replaceOp(op, converted);
+
+ Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+ if (bufferTp != dstTp) {
+ Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+ rewriter.create<DeallocTensorOp>(loc, t);
+ t = converted;
+ }
+ rewriter.replaceOp(op, t);
return success();
}
};
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 6eb754db9ce3d..49eee201fc323 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -61,8 +61,8 @@
// CHECK-RWT: scf.yield %[[NT:.*]]
// CHECK-RWT: }
// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT: return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT: return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -134,8 +134,8 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-RWT scf.yield %[[RET_1]]
// CHECK-RWT: }
// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT: return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT: return %[[NT1]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] :
@@ -209,8 +209,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-RWT: scf.yield %[[NT]]
// CHECK-RWT: }
// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT: return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT: return %[[NT1]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
@@ -291,8 +291,8 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-RWT scf.yield %[[RET_1]]
// CHECK-RWT: }
// CHECK-RWT: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
-// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[NT1]]
-// CHECK-RWT: return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK-RWT-NOT: sparse_tensor.convert
+// CHECK-RWT: return %[[NT1]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
//
func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] :
More information about the Mlir-commits
mailing list