[Mlir-commits] [mlir] 81e3079 - [mlir][sparse] Replace sparse_tensor.sort with sparse_tensor.sort_coo for sorting COO tensors.
llvmlistbot at llvm.org
Thu Jan 5 15:43:02 PST 2023
Author: bixia1
Date: 2023-01-05T15:42:57-08:00
New Revision: 81e3079d0f0cc407f4e295417c3d9c8b6203e736
URL: https://github.com/llvm/llvm-project/commit/81e3079d0f0cc407f4e295417c3d9c8b6203e736
DIFF: https://github.com/llvm/llvm-project/commit/81e3079d0f0cc407f4e295417c3d9c8b6203e736.diff
LOG: [mlir][sparse] Replace sparse_tensor.sort with sparse_tensor.sort_coo for sorting COO tensors.
Add codegen pattern for sparse_tensor.indices_buffer.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140871
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 8db2fb6ba1751..0ffcfd8190ae2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -937,6 +937,26 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
}
};
+/// Sparse codegen rule for accessing the linear indices buffer.
+class SparseToIndicesBufferConverter
+ : public OpConversionPattern<ToIndicesBufferOp> {
+public:
+ using OpAdaptor = typename ToIndicesBufferOp::Adaptor;
+ using OpConversionPattern<ToIndicesBufferOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ToIndicesBufferOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Replace the requested indices buffer access with the corresponding field.
+ // The cast_op is inserted by type converter to intermix 1:N type
+ // conversion.
+ SmallVector<Value> fields;
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ rewriter.replaceOp(op, desc.getAOSMemRef());
+
+ return success();
+ }
+};
+
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
@@ -1005,9 +1025,9 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter, SparseConvertConverter,
- SparseNumberOfEntriesConverter>(typeConverter,
- patterns.getContext());
+ SparseToIndicesBufferConverter, SparseToValuesConverter,
+ SparseConvertConverter, SparseNumberOfEntriesConverter>(
+ typeConverter, patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 348996cbc3faf..d82d1e4b16e00 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -152,7 +152,8 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
// TODO: The dim level property of the COO type relies on input tensors, the
// shape relies on the output tensor
// Helpers to setup a COO type.
-static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+static RankedTensorType
+getUnorderedCOOFromTypeWithOrdering(RankedTensorType src, AffineMap ordering) {
auto *ctx = src.getContext();
auto rank = src.getRank();
SmallVector<DimLevelType> dims;
@@ -176,12 +177,16 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
// default value.
unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0;
unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0;
- auto enc = SparseTensorEncodingAttr::get(
- ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(),
- pointerBitWidth, indexBitWidth);
+ auto enc = SparseTensorEncodingAttr::get(ctx, dims, ordering, AffineMap(),
+ pointerBitWidth, indexBitWidth);
return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
}
+static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
+ return getUnorderedCOOFromTypeWithOrdering(
+ src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()));
+}
+
/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
@@ -771,6 +776,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+ int64_t rank = dstTp.getRank();
SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
@@ -788,16 +794,21 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// the overhead types.
SmallVector<Value> dynSrcSizes;
getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
- srcTp = getUnorderedCOOFromType(srcTp);
+ srcTp =
+ getUnorderedCOOFromTypeWithOrdering(srcTp, encDst.getDimOrdering());
tmpCoo =
rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, tmpCoo,
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
ValueRange reduc) {
- // The resulting COO tensor has identity ordering.
- auto t = builder.create<InsertOp>(loc, v, reduc.front(),
- args.slice(0, srcTp.getRank()));
+ SmallVector<Value> dstIndices(srcTp.getRank(), Value());
+ for (int64_t i = 0; i < rank; i++) {
+ uint64_t dim = toStoredDim(encDst, i);
+ dstIndices[dim] = args[i];
+ }
+ auto t =
+ builder.create<InsertOp>(loc, v, reduc.front(), dstIndices);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
@@ -806,19 +817,6 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// Only need to sort if the srcTp is not already sorted (we faithfully take
// the guarantee from the sparse tensor encoding).
if (!isAllDimOrdered(srcTp)) {
- // Sort the COO tensor so that its elements are ordered via increasing
- // indices for the storage ordering of the dst tensor.
- SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
- uint64_t rank = dstTp.getRank();
- uint64_t cooStart = getCOOStart(encSrc);
- // Gather the indices-arrays in the dst tensor storage order.
- SmallVector<Value> xs(rank, Value());
- for (uint64_t i = 0; i < rank; i++) {
- uint64_t orgDim = toOrigDim(encSrc, i);
- xs[toStoredDim(encDst, orgDim)] =
- genToIndices(rewriter, loc, src, i, cooStart);
- }
-
// Retrieve NNZ.
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
nnz = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
@@ -826,9 +824,28 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// Retrieve the values-array.
Value y = genToValues(rewriter, loc, src);
-
- // Sort the COO tensor.
- rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+ SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+ // Sort the COO tensor so that its elements are ordered via increasing
+ // indices for the storage ordering of the dst tensor. Use SortCoo if the
+ // COO tensor has rank > 1 and the same dim ordering as the dst tensor.
+ if (rank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
+ MemRefType indTp =
+ get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
+ /*withLayout=*/false);
+ Value xs = rewriter.create<ToIndicesBufferOp>(loc, indTp, src);
+ rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y},
+ rewriter.getIndexAttr(rank),
+ rewriter.getIndexAttr(0));
+ } else {
+ // Gather the indices-arrays in the dst tensor storage order.
+ SmallVector<Value> xs(rank, Value());
+ for (uint64_t i = 0; i < rank; i++) {
+ uint64_t orgDim = toOrigDim(encSrc, i);
+ xs[toStoredDim(encDst, orgDim)] =
+ genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
+ }
+ rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+ }
}
// For each element in the COO tensor, insert the element to the dst tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index e42708d6dee18..be7e44a54c678 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -390,6 +390,13 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
idxDim);
}
+
+ Value getAOSMemRef() const {
+ auto enc = getSparseTensorEncoding(rType);
+ unsigned cooStart = getCOOStart(enc);
+ assert(cooStart < enc.getDimLevelType().size());
+ return getIdxMemRef(cooStart);
+ }
};
class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 9f1a9fe08fe2d..652923ea22d07 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -270,6 +270,19 @@ func.func @sparse_indices_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex
return %0 : memref<?xindex, strided<[?], offset: ?>>
}
+// CHECK-LABEL: func.func @sparse_indices_buffer_coo(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
+// CHECK: return %[[A3]] : memref<?xindex>
+func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex> {
+ %0 = sparse_tensor.indices_buffer %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
+ return %0 : memref<?xindex>
+}
+
// CHECK-LABEL: func @sparse_noe(
// CHECK-SAME: %[[A0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 436b67661a02f..4eb16a6ba3f82 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -122,11 +122,10 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-RWT: sparse_tensor.yield %[[IFR]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts
-// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
+// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
@@ -182,11 +181,10 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-RWT: sparse_tensor.yield %[[L0T2]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
-// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
-// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]]
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
+// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
More information about the Mlir-commits mailing list