[Mlir-commits] [mlir] fd2211d - Use heap memory for the position buffer allocated for PackOp.

Peiming Liu llvmlistbot at llvm.org
Thu Apr 20 13:26:08 PDT 2023


Author: Peiming Liu
Date: 2023-04-20T20:26:01Z
New Revision: fd2211d84a071633d007aac90d2ecdf0d990091c

URL: https://github.com/llvm/llvm-project/commit/fd2211d84a071633d007aac90d2ecdf0d990091c
DIFF: https://github.com/llvm/llvm-project/commit/fd2211d84a071633d007aac90d2ecdf0d990091c.diff

LOG: Use heap memory for the position buffer allocated for PackOp.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148818
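
In short: the PackOp lowering previously materialized the COO position
buffer either as a constant tensor (via bufferization.to_memref) or as a
stack memref.alloca, so no explicit memory management was needed. It now
heap-allocates the buffer with memref.alloc and fills it with stores, and
the UnpackOp lowering emits the matching memref.dealloc when deallocation
is requested. A minimal sketch of the resulting IR, derived from the
updated CHECK lines below (SSA names are illustrative):

    %pos = memref.alloc() : memref<2xindex>
    %c0 = arith.constant 0 : index
    memref.store %c0, %pos[%c0] : memref<2xindex>   // pos[0] = 0
    %c1 = arith.constant 1 : index
    %nse = arith.constant 6 : index
    memref.store %nse, %pos[%c1] : memref<2xindex>  // pos[1] = nse
    // ... later, at unpack time, when createSparseDeallocs is set:
    memref.dealloc %pos : memref<?xindex>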

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9cc3967b6f293..eea58f91b583c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -123,7 +123,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
   let hasVerifier = 1;
 }
 
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs 1DTensorOf<[AnyType]>:$values,
                   2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,

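The Pure trait is dropped because unpack now carries a memory side effect:
its lowering frees the tensor's position buffer (see SparseUnpackOpConverter
below). In particular, an unpack whose results are unused must no longer be
dead-code-eliminated, which is exactly what the new line in the integration
test exercises, e.g.:

    %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
                                         to tensor<3xf64>, tensor<3x2xindex>, index
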
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 55f4419df53d0..4e1e66d8bc0f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -827,7 +827,7 @@ class SparseTensorDeallocConverter
   }
 
 private:
-  bool createDeallocs;
+  const bool createDeallocs;
 };
 
 /// Sparse codegen rule for tensor rematerialization.
@@ -1343,29 +1343,23 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
             break;
           case SparseTensorFieldKind::PosMemRef: {
             // TACO-style COO starts with a PosBuffer
-            // By creating a constant value for it, we avoid the complexity of
-            // memory management.
             const auto posTp = stt.getPosType();
             if (isCompressedDLT(dlt)) {
-              RankedTensorType tensorType;
-              SmallVector<Attribute> posAttr;
-              tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
-              posAttr.push_back(IntegerAttr::get(posTp, 0));
-              for (unsigned i = 0; i < batchedCount; i++) {
+              auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
+              field = rewriter.create<memref::AllocOp>(loc, memrefType);
+              Value c0 = constantIndex(rewriter, loc, 0);
+              genStore(rewriter, loc, c0, field, c0);
+              for (unsigned i = 1; i <= batchedCount; i++) {
                 // The position memref will hold the values
                 // [0, nse, 2 * nse, ..., batchedCount * nse]
-                posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+                Value idx = constantIndex(rewriter, loc, i);
+                Value val = constantIndex(rewriter, loc, nse * i);
+                genStore(rewriter, loc, val, field, idx);
               }
-              MemRefType memrefType = MemRefType::get(
-                  tensorType.getShape(), tensorType.getElementType());
-              auto cstPtr = rewriter.create<arith::ConstantOp>(
-                  loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
-              field = rewriter.create<bufferization::ToMemrefOp>(
-                  loc, memrefType, cstPtr);
             } else {
               assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
               MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
-              field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+              field = rewriter.create<memref::AllocOp>(loc, posMemTp);
               populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
                                                field, nse, op);
             }
@@ -1430,6 +1424,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
 
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
+  SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
+                          bool createDeallocs)
+      : OpConversionPattern(typeConverter, context),
+        createDeallocs(createDeallocs) {}
+
   LogicalResult
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -1443,6 +1442,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
                                  : desc.getAOSMemRef();
     Value valuesBuf = desc.getValMemRef();
+    Value posBuf = desc.getPosMemRef(0);
+    if (createDeallocs) {
+      // Unpack ends the lifetime of the sparse tensor. While the value and
+      // coordinate arrays are unpacked and returned, the position array
+      // becomes useless and needs to be freed (if the user requests it).
+      rewriter.create<memref::DeallocOp>(loc, posBuf);
+    }
 
     // If the frontend requests a static buffer, we reallocate the
     // values/coordinates to ensure that we meet their need.
@@ -1474,6 +1480,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     rewriter.replaceOp(op, {values, coordinates, nse});
     return success();
   }
+
+private:
+  const bool createDeallocs;
 };
 
 struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1627,11 +1636,11 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool createSparseDeallocs, bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
-               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseExtractSliceConverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
+  patterns.add<SparsePackOpConverter, SparseReturnConverter,
+               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+               SparseExtractSliceConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter,
+               SparseInsertConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1641,7 +1650,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseConvertConverter, SparseNewOpConverter,
                SparseNumberOfEntriesConverter>(typeConverter,
                                                patterns.getContext());
-  patterns.add<SparseTensorDeallocConverter>(
+  patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
       typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);

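The key substitution here is memref.alloca -> memref.alloc for the position
buffer: the buffer now lives on the heap, so it can outlive the function
that packs the tensor but must be explicitly freed. SparseUnpackOpConverter
emits that free under the same createSparseDeallocs flag that
SparseTensorDeallocConverter already takes, hence the shared pattern
registration above. A before/after contrast, with illustrative names:

    %stack_pos = memref.alloca() : memref<2xindex>  // before: reclaimed on function return
    %heap_pos = memref.alloc() : memref<2xindex>    // after: heap-allocated ...
    memref.dealloc %heap_pos : memref<2xindex>      // ... and freed when the tensor is unpacked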
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 99befbeb2f1a5..4648cb3bf2983 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -7,26 +7,29 @@
 
 // CHECK-LABEL:   func.func @sparse_pack(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<6xf64>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
-// CHECK:           %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
-// CHECK:           %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
-// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK:           %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK:           %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
-// CHECK:           %[[VAL_11:.*]] = arith.constant 6 : index
-// CHECK:           %[[VAL_12:.*]] = arith.constant 100 : index
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  lvl_sz at 0 with %[[VAL_12]]
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<6x2xi32>)
+// CHECK-DAG:       %[[VAL_2:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       memref.store %[[VAL_3]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 6 : index
+// CHECK-DAG:       memref.store %[[VAL_5]], %[[VAL_2]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK:           %[[VAL_6:.*]] = memref.cast %[[VAL_2]] : memref<2xindex> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_7]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK:           %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<12xi32> to memref<?xi32>
+// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK:           %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<6xf64> to memref<?xf64>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.storage_specifier.init
+// CHECK:           %[[VAL_13:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]]  lvl_sz at 0 with %[[VAL_13]]
 // CHECK:           %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_15]]
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  crd_mem_sz at 0 with %[[VAL_11]]
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  lvl_sz at 1 with %[[VAL_12]]
-// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]]  crd_mem_sz at 1 with %[[VAL_11]]
-// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  val_mem_sz with %[[VAL_11]]
-// CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_15]]
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_16]]  crd_mem_sz at 0 with %[[VAL_5]]
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  lvl_sz at 1 with %[[VAL_13]]
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  crd_mem_sz at 1 with %[[VAL_5]]
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  val_mem_sz with %[[VAL_5]]
+// CHECK:           return %[[VAL_6]], %[[VAL_9]], %[[VAL_11]], %[[VAL_20]]
 // CHECK:         }
 func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
                     -> tensor<100x100xf64, #COO> {
@@ -39,9 +42,10 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
 // CHECK-SAME:      %[[VAL_0:.*]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
-// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_4:.*]] = arith.constant 6 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-SAME:      %[[VAL_3:.*]]
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       memref.dealloc %[[VAL_0]] : memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
 // CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
 // CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index ef27050eab32c..b3ba3529f1d0a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -98,7 +98,6 @@ module {
         vector.print %v: f64
      }
 
-
     %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                                          to tensor<3xf64>, tensor<3x2xi32>, i32
 
@@ -115,6 +114,8 @@ module {
     // CHECK-NEXT: 3
     vector.print %n : i32
 
+    %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
+                                         to tensor<3xf64>, tensor<3x2xindex>, index
     return
   }
 }
