[Mlir-commits] [mlir] fd2211d - use heap memory for position buffer allocated for PackOp.
Peiming Liu
llvmlistbot at llvm.org
Thu Apr 20 13:26:08 PDT 2023
Author: Peiming Liu
Date: 2023-04-20T20:26:01Z
New Revision: fd2211d84a071633d007aac90d2ecdf0d990091c
URL: https://github.com/llvm/llvm-project/commit/fd2211d84a071633d007aac90d2ecdf0d990091c
DIFF: https://github.com/llvm/llvm-project/commit/fd2211d84a071633d007aac90d2ecdf0d990091c.diff
LOG: use heap memory for position buffer allocated for PackOp.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D148818
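In short: the PackOp lowering used to materialize the per-batch position
buffer as an arith.constant tensor (bufferized via bufferization.to_memref),
which sidestepped memory management but baked the positions into the IR. It
now allocates the buffer on the heap with memref.alloc and fills it with
explicit stores, and the unpack lowering frees it when deallocation is
requested (which is also why UnpackOp loses its Pure trait below). A minimal
before/after sketch, distilled from the test updates in this diff (SSA names
abbreviated):

    // Before: constant position buffer, never freed.
    %pos = arith.constant dense<[0, 6]> : tensor<2xindex>
    %buf = bufferization.to_memref %pos : memref<2xindex>

    // After: heap-allocated position buffer, filled explicitly ...
    %buf = memref.alloc() : memref<2xindex>
    %c0 = arith.constant 0 : index
    memref.store %c0, %buf[%c0] : memref<2xindex>
    %c1 = arith.constant 1 : index
    %c6 = arith.constant 6 : index
    memref.store %c6, %buf[%c1] : memref<2xindex>

    // ... and released by the unpack lowering when createSparseDeallocs is set.
    memref.dealloc %buf : memref<2xindex>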
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9cc3967b6f293..eea58f91b583c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -123,7 +123,7 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
let hasVerifier = 1;
}
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs 1DTensorOf<[AnyType]>:$values,
2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 55f4419df53d0..4e1e66d8bc0f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -827,7 +827,7 @@ class SparseTensorDeallocConverter
}
private:
- bool createDeallocs;
+ const bool createDeallocs;
};
/// Sparse codegen rule for tensor rematerialization.
@@ -1343,29 +1343,23 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
break;
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
- // By creating a constant value for it, we avoid the complexity of
- // memory management.
const auto posTp = stt.getPosType();
if (isCompressedDLT(dlt)) {
- RankedTensorType tensorType;
- SmallVector<Attribute> posAttr;
- tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
- posAttr.push_back(IntegerAttr::get(posTp, 0));
- for (unsigned i = 0; i < batchedCount; i++) {
+ auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
+ field = rewriter.create<memref::AllocOp>(loc, memrefType);
+ Value c0 = constantIndex(rewriter, loc, 0);
+ genStore(rewriter, loc, c0, field, c0);
+ for (unsigned i = 1; i <= batchedCount; i++) {
// The position memref will have values as
// [0, nse, 2 * nse, ..., batchedCount * nse]
- posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+ Value idx = constantIndex(rewriter, loc, i);
+ Value val = constantIndex(rewriter, loc, nse * i);
+ genStore(rewriter, loc, val, field, idx);
}
- MemRefType memrefType = MemRefType::get(
- tensorType.getShape(), tensorType.getElementType());
- auto cstPtr = rewriter.create<arith::ConstantOp>(
- loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
- field = rewriter.create<bufferization::ToMemrefOp>(
- loc, memrefType, cstPtr);
} else {
assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
- field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+ field = rewriter.create<memref::AllocOp>(loc, posMemTp);
populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
field, nse, op);
}
@@ -1430,6 +1424,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
using OpConversionPattern::OpConversionPattern;
+ SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
+ bool createDeallocs)
+ : OpConversionPattern(typeConverter, context),
+ createDeallocs(createDeallocs) {}
+
LogicalResult
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -1443,6 +1442,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
: desc.getAOSMemRef();
Value valuesBuf = desc.getValMemRef();
+ Value posBuf = desc.getPosMemRef(0);
+ if (createDeallocs) {
+ // Unpack ends the lifetime of the sparse tensor. While the value array
+ // and coordinate array are unpacked and returned, the position array
+ // becomes useless and needs to be freed (if the user requests it).
+ rewriter.create<memref::DeallocOp>(loc, posBuf);
+ }
// If the frontend requests a static buffer, we reallocate the
// values/coordinates to ensure that we meet their needs.
@@ -1474,6 +1480,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
rewriter.replaceOp(op, {values, coordinates, nse});
return success();
}
+
+private:
+ const bool createDeallocs;
};
struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1627,11 +1636,11 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) {
- patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
- SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
- SparseCastConverter, SparseExtractSliceConverter,
- SparseTensorLoadConverter, SparseExpandConverter,
- SparseCompressConverter, SparseInsertConverter,
+ patterns.add<SparsePackOpConverter, SparseReturnConverter,
+ SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+ SparseExtractSliceConverter, SparseTensorLoadConverter,
+ SparseExpandConverter, SparseCompressConverter,
+ SparseInsertConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1641,7 +1650,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseConvertConverter, SparseNewOpConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
- patterns.add<SparseTensorDeallocConverter>(
+ patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
typeConverter, patterns.getContext(), createSparseDeallocs);
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 99befbeb2f1a5..4648cb3bf2983 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -7,26 +7,29 @@
// CHECK-LABEL: func.func @sparse_pack(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
-// CHECK: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
-// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
-// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
-// CHECK: %[[VAL_11:.*]] = arith.constant 6 : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_12]]
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>)
+// CHECK-DAG: %[[VAL_2:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: memref.store %[[VAL_3]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 6 : index
+// CHECK-DAG: memref.store %[[VAL_5]], %[[VAL_2]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK: %[[VAL_6:.*]] = memref.cast %[[VAL_2]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_7]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<12xi32> to memref<?xi32>
+// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_13:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] lvl_sz at 0 with %[[VAL_13]]
// CHECK: %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_15]]
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] crd_mem_sz at 0 with %[[VAL_11]]
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] lvl_sz at 1 with %[[VAL_12]]
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] crd_mem_sz at 1 with %[[VAL_11]]
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_11]]
-// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_15]]
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_16]] crd_mem_sz at 0 with %[[VAL_5]]
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] lvl_sz at 1 with %[[VAL_13]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] crd_mem_sz at 1 with %[[VAL_5]]
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_5]]
+// CHECK: return %[[VAL_6]], %[[VAL_9]], %[[VAL_11]], %[[VAL_20]]
// CHECK: }
func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
-> tensor<100x100xf64, #COO> {
@@ -39,9 +42,10 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-SAME: %[[VAL_3:.*]]
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG: memref.dealloc %[[VAL_0]] : memref<?xindex>
// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index ef27050eab32c..b3ba3529f1d0a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -98,7 +98,6 @@ module {
vector.print %v: f64
}
-
%d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
to tensor<3xf64>, tensor<3x2xi32>, i32
@@ -115,6 +114,8 @@ module {
// CHECK-NEXT: 3
vector.print %n : i32
+ %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
+ to tensor<3xf64>, tensor<3x2xindex>, index
return
}
}
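For reference, a minimal sketch of the pack/unpack pair these tests lower
(shapes and encoding taken from the regression test above; syntax per the
tests in this revision, so treat it as an illustration, not part of the
commit):

    %coo = sparse_tensor.pack %values, %coordinates
         : tensor<6xf64>, tensor<6x2xi32> to tensor<100x100xf64, #COO>
    %d, %i, %n = sparse_tensor.unpack %coo : tensor<100x100xf64, #COO>
         to tensor<6xf64>, tensor<6x2xi32>, i32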