[Mlir-commits] [mlir] dc6427d - [mlir][sparse] implement lowering rules for sparse_tensor::unpack
Peiming Liu
llvmlistbot at llvm.org
Fri Feb 10 17:05:51 PST 2023
Author: Peiming Liu
Date: 2023-02-11T01:05:46Z
New Revision: dc6427d687c4640e33c9bdab0c888b6f7d0569be
URL: https://github.com/llvm/llvm-project/commit/dc6427d687c4640e33c9bdab0c888b6f7d0569be
DIFF: https://github.com/llvm/llvm-project/commit/dc6427d687c4640e33c9bdab0c888b6f7d0569be.diff
LOG: [mlir][sparse] implement lowering rules for sparse_tensor::unpack
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143672
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6a5a85fb8271..641b429f300a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -143,6 +143,37 @@ struct PackOpInterface
// ones when packing into a COO format.
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
+
+  /// Buffer relation reported for `opResult` of a PackOp; this implementation
+  /// always answers `BufferRelation::Unknown`.
+  // NOTE(review): the method immediately above returns the aliasing pair
+  // {result 0, BufferRelation::Equivalent}, which disagrees with `Unknown`
+  // here; confirm which of the two the bufferization analysis actually
+  // consults and keep them consistent.
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Unknown;
+  }
+};
+
+struct UnpackOpInterface
+    : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
+                                                    sparse_tensor::UnpackOp> {
+  /// UnpackOp is not treated as allocating: similar to InsertOp, a
+  /// reallocation emitted while lowering it is not considered to allocate a
+  /// new piece of memory.
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return false;
+  }
+
+  /// The sparse tensor operand is always read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  /// The sparse tensor operand is reported as never written.
+  // NOTE(review): SparseUnpackOpConverter lowers this op with memref.realloc
+  // on the operand's values/indices buffers when a static result shape is
+  // requested. memref.realloc may invalidate its source buffer, so confirm
+  // that reporting "no write" (and no aliasing, below) is sound when the
+  // sparse tensor is still used after the unpack.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  /// No result aliases the operand: conceptually, an UnpackOp is equivalent
+  /// to a series of ToIndicesOp/ToValuesOp extractions.
+  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+};
struct InsertOpInterface
@@ -285,6 +316,8 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
sparse_tensor::NumberOfEntriesOp::attachInterface<
NumberOfEntriesOpInterface>(*ctx);
+ sparse_tensor::PackOp::attachInterface<PackOpInterface>(*ctx);
+ sparse_tensor::UnpackOp::attachInterface<UnpackOpInterface>(*ctx);
sparse_tensor::ToIndicesBufferOp::attachInterface<
ToIndicesBufferOpInterface>(*ctx);
sparse_tensor::ToIndicesOp::attachInterface<ToIndicesOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index b2541c38a30b..17319b7ffa2a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1154,6 +1154,53 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
};
+struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Rewrites `sparse_tensor.unpack` on a unique COO tensor into its raw
+  /// storage: the values buffer, the flat index buffer expanded into the
+  /// shape of the requested indices tensor, and the stored-values count cast
+  /// to the requested nnz type. A statically shaped result type causes the
+  /// matching buffer to be resized with `memref.realloc` first.
+  LogicalResult
+  matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    const Location loc = op.getLoc();
+    auto srcTp = op.getTensor().getType();
+
+    // Only unique COO tensors are lowered here; their descriptor carries
+    // exactly four fields.
+    assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4);
+
+    // Resizes `buf` to a memref of the given static shape and element type.
+    auto reallocTo = [&](ArrayRef<int64_t> shape, Type elemTp,
+                         Value buf) -> Value {
+      return rewriter.create<memref::ReallocOp>(
+          loc, MemRefType::get(shape, elemTp), buf);
+    };
+
+    // Rank-1 tensors expose their indices through level 0; higher ranks keep
+    // all indices in one array-of-structs buffer.
+    Value indexBuf;
+    if (srcTp.getRank() == 1)
+      indexBuf = desc.getIdxMemRefOrView(rewriter, loc, 0);
+    else
+      indexBuf = desc.getAOSMemRef();
+    Value valueBuf = desc.getValMemRef();
+
+    // Honor a static shape requested by the frontend for the values.
+    TensorType valuesTp = op.getData().getType();
+    if (valuesTp.hasStaticShape())
+      valueBuf = reallocTo(valuesTp.getShape(), valuesTp.getElementType(),
+                           valueBuf);
+
+    // Same for the indices, whose buffer is still flat at this point.
+    TensorType coordsTp = op.getIndices().getType();
+    ArrayRef<int64_t> coordsShape = coordsTp.getShape();
+    if (coordsTp.hasStaticShape()) {
+      int64_t flatLen = coordsShape[0] * coordsShape[1];
+      indexBuf = reallocTo({flatLen}, coordsTp.getElementType(), indexBuf);
+    }
+
+    // Reinterpret the flat index buffer with the 2-D indices shape.
+    Value coordsBuf = rewriter.create<memref::ExpandShapeOp>(
+        loc, MemRefType::get(coordsShape, coordsTp.getElementType()),
+        indexBuf, ArrayRef{ReassociationIndices{0, 1}});
+
+    // Wrap the memrefs back into tensors and produce the entry count.
+    Value dataTensor =
+        rewriter.create<bufferization::ToTensorOp>(loc, valueBuf);
+    Value indicesTensor =
+        rewriter.create<bufferization::ToTensorOp>(loc, coordsBuf);
+    Value nnzVal = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                          op.getNnz().getType());
+
+    rewriter.replaceOp(op, {dataTensor, indicesTensor, nnzVal});
+    return success();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1165,15 +1212,16 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
+/// Populates `patterns` with the sparse tensor codegen conversion patterns.
+/// `enableBufferInitialization` is forwarded only to
+/// SparseTensorAllocConverter.
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool enableBufferInitialization) {
- patterns.add<SparsePackOpConverter, SparseReturnConverter,
- SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
- SparseTensorDeallocConverter, SparseExtractSliceCoverter,
- SparseTensorLoadConverter, SparseExpandConverter,
- SparseCompressConverter, SparseInsertConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToIndicesBufferConverter, SparseToValuesConverter,
- SparseConvertConverter, SparseNumberOfEntriesConverter>(
- typeConverter, patterns.getContext());
+ // Patterns constructed from only the type converter and the MLIR context.
+ // NOTE(review): "SparseExtractSliceCoverter" looks like a misspelling of
+ // "...Converter"; renaming it also requires changing its definition.
+ patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
+ SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+ SparseCastConverter, SparseTensorDeallocConverter,
+ SparseExtractSliceCoverter, SparseTensorLoadConverter,
+ SparseExpandConverter, SparseCompressConverter,
+ SparseInsertConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToIndicesBufferConverter,
+ SparseToValuesConverter, SparseConvertConverter,
+ SparseNumberOfEntriesConverter>(typeConverter,
+ patterns.getContext());
+ // The allocation converter additionally receives the buffer-initialization
+ // flag.
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 6dbf258c7e7b..eeb41fe7d00a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -37,3 +37,23 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
to tensor<100x100xf64, #COO>
return %0 : tensor<100x100xf64, #COO>
}
+
+// CHECK-LABEL: func.func @sparse_unpack(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK: %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK: %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK: %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
+// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
+// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
+// CHECK: return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK: }
+// Unpacks a 100x100 COO tensor into statically shaped values and indices
+// tensors plus the number of stored entries.
+func.func @sparse_unpack(%coo: tensor<100x100xf64, #COO>)
+    -> (tensor<6xf64>, tensor<6x2xi32>, index) {
+  %values, %coords, %count = sparse_tensor.unpack %coo
+      : tensor<100x100xf64, #COO> to tensor<6xf64>, tensor<6x2xi32>, index
+  return %values, %coords, %count : tensor<6xf64>, tensor<6x2xi32>, index
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 35d9cc6409d6..b20c5330397e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -36,6 +36,9 @@ module {
// Main driver.
//
func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f64
+ %i0 = arith.constant 0 : i32
//
// Initialize a 3-dim dense tensor.
//
@@ -95,6 +98,23 @@ module {
vector.print %v: f64
}
+
+ %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+ to tensor<3xf64>, tensor<3x2xi32>, i32
+
+
+
+ // CHECK-NEXT: ( 1, 2, 3 )
+ %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
+ vector.print %vd : vector<3xf64>
+
+ // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) )
+ %vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
+ vector.print %vi : vector<3x2xi32>
+
+ // CHECK-NEXT: 3
+ vector.print %n : i32
+
return
}
}
More information about the Mlir-commits mailing list