[Mlir-commits] [mlir] [mlir][sparse] deallocate tmp coo buffer generated during stage-spars… (PR #82017)
Peiming Liu
llvmlistbot at llvm.org
Fri Feb 16 12:52:28 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/82017
>From d8075feb4e44a745e262386149ee60e7c962612f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 16 Feb 2024 17:51:15 +0000
Subject: [PATCH 1/2] [mlir][sparse] deallocate tmp coo buffer generated during
stage-sparse-ops pass.
---
.../SparseTensor/IR/SparseTensorInterfaces.h | 5 ++---
.../SparseTensor/IR/SparseTensorInterfaces.td | 5 +++--
.../SparseTensor/IR/SparseTensorInterfaces.cpp | 17 +++++++++++------
.../Transforms/StageSparseOperations.cpp | 13 +++++++++++--
.../SparseTensor/convert_dense2sparse.mlir | 3 ++-
5 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index ebbc522123a599..c0f31762ee071f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -1,5 +1,4 @@
-//===- SparseTensorInterfaces.h - sparse tensor operations
-//interfaces-------===//
+//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -20,7 +19,7 @@ class StageWithSortSparseOp;
namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
- PatternRewriter &rewriter);
+ PatternRewriter &rewriter, Value &tmpBufs);
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 1379363ff75f42..05eed0483f2c8a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -34,9 +34,10 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
/*desc=*/"Stage the operation, return the final result value after staging.",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"stageWithSort",
- /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+ /*args=*/(ins "::mlir::PatternRewriter &":$rewriter,
+ "Value &":$tmpBuf),
/*methodBody=*/[{
- return detail::stageWithSortImpl($_op, rewriter);
+ return detail::stageWithSortImpl($_op, rewriter, tmpBuf);
}]>,
];
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d33eb9d2877ae3..4866971af08e7d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -16,9 +16,8 @@ using namespace mlir::sparse_tensor;
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
-LogicalResult
-sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
- PatternRewriter &rewriter) {
+LogicalResult sparse_tensor::detail::stageWithSortImpl(
+ StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
if (!op.needsExtraSort())
return failure();
@@ -44,9 +43,15 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
rewriter.replaceOp(op, dstCOO);
} else {
// Need an extra conversion if the target type is not COO.
- rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+ auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+ rewriter.setInsertionPointAfter(c);
+ // Inform the caller about the intermediate buffer we allocated. We cannot
+ // create a bufferization::DeallocTensorOp here because it would introduce
+ // a cyclic dependency between the SparseTensorDialect and the
+ // BufferizationDialect. Besides, whether the buffer needs to be deallocated
+ // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
+ tmpBufs = dstCOO;
}
- // TODO: deallocate extra COOs, we should probably delegate it to buffer
- // deallocation pass.
+
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 5875cd4f9fd9d1..992f4faafc099b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -21,8 +22,16 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
LogicalResult matchAndRewrite(StageWithSortOp op,
PatternRewriter &rewriter) const override {
- return llvm::cast<StageWithSortSparseOp>(op.getOperation())
- .stageWithSort(rewriter);
+ Location loc = op.getLoc();
+ Value tmpBuf = nullptr;
+ auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
+ LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
+ // Deallocate tmpBuf, maybe delegate to buffer deallocation pass in the
+ // future.
+ if (succeeded(stageResult) && tmpBuf)
+ rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
+
+ return stageResult;
}
};
} // namespace
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 96a1140372bd6c..83dbc9568c7a36 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -82,10 +82,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
// CHECK: scf.if
// CHECK: tensor.insert
// CHECK: sparse_tensor.load
-// CHECK: sparse_tensor.reorder_coo
+// CHECK: %[[TMP:.*]] = sparse_tensor.reorder_coo
// CHECK: sparse_tensor.foreach
// CHECK: tensor.insert
// CHECK: sparse_tensor.load
+// CHECK: bufferization.dealloc_tensor %[[TMP]]
func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
return %0 : tensor<?x?x?xf64, #SparseTensor>
>From 5d2fca28f2851c3804d4e5ce49bb402955ab1f8e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 16 Feb 2024 20:52:16 +0000
Subject: [PATCH 2/2] address comments
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp | 6 ++++++
.../SparseTensor/Transforms/StageSparseOperations.cpp | 4 ++--
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 4866971af08e7d..4f9988d48d7710 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -16,6 +16,12 @@ using namespace mlir::sparse_tensor;
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
+/// Stage the operations into a sequence of simple operations as follows:
+/// op -> unsorted_coo +
+/// unsorted_coo -> sorted_coo +
+/// sorted_coo -> dstTp.
+///
+/// Sets `tmpBufs` to the intermediate COO buffer if one is allocated.
LogicalResult sparse_tensor::detail::stageWithSortImpl(
StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
if (!op.needsExtraSort())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 992f4faafc099b..5b4395cc31a46b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -26,8 +26,8 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
Value tmpBuf = nullptr;
auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
- // Deallocate tmpBuf, maybe delegate to buffer deallocation pass in the
- // future.
+ // Deallocate tmpBuf.
+ // TODO: Delegate to buffer deallocation pass in the future.
if (succeeded(stageResult) && tmpBuf)
rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
More information about the Mlir-commits
mailing list