[Mlir-commits] [mlir] 11705af - [mlir][sparse] deallocate tmp coo buffer generated during stage-sparse-ops pass (#82017)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 17 12:18:01 PST 2024


Author: Peiming Liu
Date: 2024-02-17T12:17:57-08:00
New Revision: 11705afc19383dedfb06c3b708d6fe8c0729b807

URL: https://github.com/llvm/llvm-project/commit/11705afc19383dedfb06c3b708d6fe8c0729b807
DIFF: https://github.com/llvm/llvm-project/commit/11705afc19383dedfb06c3b708d6fe8c0729b807.diff

LOG: [mlir][sparse] deallocate tmp coo buffer generated during stage-sparse-ops pass. (#82017)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
    mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index ebbc522123a599..c0f31762ee071f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -1,5 +1,4 @@
-//===- SparseTensorInterfaces.h - sparse tensor operations
-//interfaces-------===//
+//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -20,7 +19,7 @@ class StageWithSortSparseOp;
 
 namespace detail {
 LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
-                                PatternRewriter &rewriter);
+                                PatternRewriter &rewriter, Value &tmpBufs);
 } // namespace detail
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 1379363ff75f42..05eed0483f2c8a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -34,9 +34,10 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
     /*desc=*/"Stage the operation, return the final result value after staging.",
     /*retTy=*/"::mlir::LogicalResult",
     /*methodName=*/"stageWithSort",
-    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter,
+                  "Value &":$tmpBuf),
     /*methodBody=*/[{
-        return detail::stageWithSortImpl($_op, rewriter);
+        return detail::stageWithSortImpl($_op, rewriter, tmpBuf);
     }]>,
   ];
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d33eb9d2877ae3..4f9988d48d7710 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -16,9 +16,14 @@ using namespace mlir::sparse_tensor;
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
 
-LogicalResult
-sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
-                                         PatternRewriter &rewriter) {
+/// Stage the operations into a sequence of simple operations as follows:
+/// op -> unsorted_coo +
+/// unsorted_coo -> sorted_coo +
+/// sorted_coo -> dstTp.
+///
+/// Sets `tmpBufs` to the intermediate buffer if one is allocated.
+LogicalResult sparse_tensor::detail::stageWithSortImpl(
+    StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
   if (!op.needsExtraSort())
     return failure();
 
@@ -44,9 +49,15 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
     rewriter.replaceOp(op, dstCOO);
   } else {
     // Need an extra conversion if the target type is not COO.
-    rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+    auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+    rewriter.setInsertionPointAfter(c);
+    // Informs the caller about the intermediate buffer we allocated. We can not
+    // create a bufferization::DeallocateTensorOp here because it would
+    // introduce cyclic dependency between the SparseTensorDialect and the
+    // BufferizationDialect. Besides, whether the buffer need to be deallocated
+    // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
+    tmpBufs = dstCOO;
   }
-  // TODO: deallocate extra COOs, we should probably delegate it to buffer
-  // deallocation pass.
+
   return success();
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 5875cd4f9fd9d1..5b4395cc31a46b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -21,8 +22,16 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
 
   LogicalResult matchAndRewrite(StageWithSortOp op,
                                 PatternRewriter &rewriter) const override {
-    return llvm::cast<StageWithSortSparseOp>(op.getOperation())
-        .stageWithSort(rewriter);
+    Location loc = op.getLoc();
+    Value tmpBuf = nullptr;
+    auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
+    LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
+    // Deallocate tmpBuf, the intermediate COO (if any) set by stageWithSort.
+    // TODO: Delegate to buffer deallocation pass in the future.
+    if (succeeded(stageResult) && tmpBuf)
+      rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
+
+    return stageResult;
   }
 };
 } // namespace

diff  --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 96a1140372bd6c..83dbc9568c7a36 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -82,10 +82,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
 // CHECK:             scf.if
 // CHECK:               tensor.insert
 // CHECK:           sparse_tensor.load
-// CHECK:           sparse_tensor.reorder_coo
+// CHECK:           %[[TMP:.*]] = sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.foreach
 // CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
+// CHECK:           bufferization.dealloc_tensor %[[TMP]]
 func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
   return %0 : tensor<?x?x?xf64, #SparseTensor>


        


More information about the Mlir-commits mailing list