[Mlir-commits] [mlir] 1e0966c - [mlir][sparse] add util for ToCoordinatesBuffer for COO AoS
Aart Bik
llvmlistbot at llvm.org
Thu May 11 10:43:43 PDT 2023
Author: Aart Bik
Date: 2023-05-11T10:43:31-07:00
New Revision: 1e0966cd6c80dffb084692f265377ef8fa5fbe96
URL: https://github.com/llvm/llvm-project/commit/1e0966cd6c80dffb084692f265377ef8fa5fbe96
DIFF: https://github.com/llvm/llvm-project/commit/1e0966cd6c80dffb084692f265377ef8fa5fbe96.diff
LOG: [mlir][sparse] add util for ToCoordinatesBuffer for COO AoS
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D150382
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 54a6786231e90..6fd55c7799306 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -679,6 +679,14 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
builder.getIndexAttr(lvl));
}
+Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
+ Value tensor) {
+ // The buffer holds coordinates, so its element type is the coordinate type
+ // of the tensor's sparse encoding; the memref is built without a layout.
+ const auto stt = getSparseTensorType(tensor);
+ const Type bufTp =
+ get1DMemRefType(stt.getEncoding().getCrdType(), /*withLayout=*/false);
+ return builder.create<ToCoordinatesBufferOp>(loc, bufTp, tensor);
+}
+
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Value tensor) {
RankedTensorType srcTp = getRankedTensorType(tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 28acf1aed7de2..e04475ea2e8f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,9 @@ Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
Level lvl, Level cooStart);
+/// Infers the result type and generates `ToCoordinatesBufferOp` on `tensor`
+/// (the helper for COO AoS coordinate storage). The inferred result type is a
+/// 1-D memref, without layout, whose element type is the coordinate type of
+/// the tensor's sparse encoding.
+Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
+
/// Infers the result type and generates `ToValuesOp`.
Value genToValues(OpBuilder &builder, Location loc, Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4aba829d7dbdd..2a4bbb06eb507 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -895,9 +895,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// coordinates for the storage ordering of the dst tensor. Use SortCoo
// if the COO tensor has the same ordering as the dst tensor.
if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
- MemRefType coordsTp =
- get1DMemRefType(encSrc.getCrdType(), /*withLayout=*/false);
- Value xs = rewriter.create<ToCoordinatesBufferOp>(loc, coordsTp, src);
+ Value xs = genToCoordinatesBuffer(rewriter, loc, src);
rewriter.create<SortCooOp>(
loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
More information about the Mlir-commits
mailing list