[Mlir-commits] [mlir] 1e0966c - [mlir][sparse] add util for ToCoordinatesBuffer for COO AoS

Aart Bik llvmlistbot at llvm.org
Thu May 11 10:43:43 PDT 2023


Author: Aart Bik
Date: 2023-05-11T10:43:31-07:00
New Revision: 1e0966cd6c80dffb084692f265377ef8fa5fbe96

URL: https://github.com/llvm/llvm-project/commit/1e0966cd6c80dffb084692f265377ef8fa5fbe96
DIFF: https://github.com/llvm/llvm-project/commit/1e0966cd6c80dffb084692f265377ef8fa5fbe96.diff

LOG: [mlir][sparse] add util for ToCoordinatesBuffer for COO AoS

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D150382

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 54a6786231e90..6fd55c7799306 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -679,6 +679,14 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                          builder.getIndexAttr(lvl));
 }
 
+Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
+                                            Value tensor) {
+  // Result: layout-free 1-D memref whose element type is the crd type.
+  const auto enc = getSparseTensorType(tensor).getEncoding();
+  const auto bufTp = get1DMemRefType(enc.getCrdType(), /*withLayout=*/false);
+  return builder.create<ToCoordinatesBufferOp>(loc, bufTp, tensor);
+}
+
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
   RankedTensorType srcTp = getRankedTensorType(tensor);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 28acf1aed7de2..e04475ea2e8f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,9 @@ Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
 Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
                        Level lvl, Level cooStart);
 
+/// Infers the coordinate memref type and generates `ToCoordinatesBufferOp`.
+Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
+
 /// Infers the result type and generates `ToValuesOp`.
 Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 4aba829d7dbdd..2a4bbb06eb507 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -895,9 +895,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       // coordinates for the storage ordering of the dst tensor.  Use SortCoo
       // if the COO tensor has the same ordering as the dst tensor.
       if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
-        MemRefType coordsTp =
-            get1DMemRefType(encSrc.getCrdType(), /*withLayout=*/false);
-        Value xs = rewriter.create<ToCoordinatesBufferOp>(loc, coordsTp, src);
+        Value xs = genToCoordinatesBuffer(rewriter, loc, src);
         rewriter.create<SortCooOp>(
             loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
             rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);


        


More information about the Mlir-commits mailing list