[Mlir-commits] [mlir] [mlir][sparse] support CSR/BSR conversion (PR #69800)

Peiming Liu llvmlistbot at llvm.org
Fri Oct 20 17:06:43 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69800

Supports sparse tensor conversion involving CSR/BSR by allowing non-permutation dimToLvl maps: the permutation-only assertion in getCOOFromTypeWithOrdering is removed, LoopEmitter now derives level sizes from dimension sizes via a dim2lvl coordinate translation instead of the deprecated toOrigDim, and ForeachRewriter exits one loop per level rather than one per dimension.

From 8dd2c0aa81c6587f934f4e8040fd864323ee6116 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Sat, 21 Oct 2023 00:06:03 +0000
Subject: [PATCH] [mlir][sparse] support CSR/BSR conversion

---
 .../SparseTensor/IR/SparseTensorDialect.cpp    |  5 -----
 .../SparseTensor/Transforms/LoopEmitter.cpp    | 18 +++++++++++-------
 .../Transforms/SparseTensorRewriting.cpp       |  4 ++--
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fe2eaebbfa45d5a..56214c2b41c387b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -847,11 +847,6 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
                                                            AffineMap lvlPerm,
                                                            bool ordered) {
   const SparseTensorType src(rtt);
-  // TODO: This assertion is to match the behavior from before we merged
-  // dimOrdering and higherOrdering into dimToLvl.  However, there's no
-  // in-principle reason to require this.  (wrengr has a commit in the
-  // wings to fix this.)
-  assert(src.isPermutation());
   const Level lvlRank = src.getLvlRank();
   SmallVector<DimLevelType> lvlTypes;
   lvlTypes.reserve(lvlRank);
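
(Dropping this assertion lets getCOOFromTypeWithOrdering accept non-permutation
dimToLvl maps such as the 2x2 BSR block map
(i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2),
where two dimensions expand into four levels.)
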
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 2d4f40eceba3b0d..de7ad3c91d793a0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -422,10 +422,19 @@ void LoopEmitter::initializeLoopEmit(
     // FIXME: the definition of `lvlRank` looks more like a dim-rank;
     // but the variable is used as a level everywhere below, which
     // suggests there may be some dim/lvl confusion going on here.
-    const Level lvlRank = rtp.getRank();
+    auto stt = getSparseTensorType(tensor);
+    const Level lvlRank = stt.getLvlRank();
     const auto shape = rtp.getShape();
     const auto enc = getSparseTensorEncoding(rtp);
     const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
+
+    SmallVector<Value> dimSz;
+    for (Dimension d = 0; d < stt.getDimRank(); d++)
+      dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
+
+    ValueRange lvlSzs =
+        enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
+
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
       // This should be called only once at beginning.
@@ -447,13 +456,8 @@ void LoopEmitter::initializeLoopEmit(
         assert(isDenseDLT(lvlTp));
       }
 
-      // FIXME: `toOrigDim` is deprecated.  For now this relies on the
-      // 1:1 mapping between levels and dimensions, since nowhere else
-      // in the code supports non-permutations yet either.
-      Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
-                                                    toOrigDim(enc, l));
       // Find upper bound in current dimension.
-      highs[t][l] = lvlSizes[t][l] = lvlSz;
+      highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
       if (isSparseSlices[t]) {
         sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
         sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
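
The change above replaces the per-level createOrFoldDimOp query (which assumed
a 1:1 level-to-dimension mapping) with a single dim2lvl translation of all
dimension sizes. As a rough standalone sketch (plain C++, not the MLIR API,
with the 2x2 BSR block map hard-coded), the translation computes:

  // Sketch only: level sizes derived from dimension sizes under the
  // 2x2 BSR map (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2).
  #include <array>
  #include <cstdint>
  #include <iostream>

  int main() {
    std::array<int64_t, 2> dimSizes = {8, 6}; // an 8x6 matrix
    std::array<int64_t, 4> lvlSizes = {
        dimSizes[0] / 2, // i floordiv 2 -> number of block rows
        dimSizes[1] / 2, // j floordiv 2 -> number of block columns
        2,               // i mod 2     -> rows within each block
        2                // j mod 2     -> columns within each block
    };
    for (int64_t sz : lvlSizes)
      std::cout << sz << ' '; // prints: 4 3 2 2
    std::cout << '\n';
  }
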
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c168f10dadbbfb5..5ce81f932faaf3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1069,7 +1069,6 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     Value input = op.getTensor();
     SmallVector<Value> reduc = op.getInitArgs();
     const auto stt = getSparseTensorType(input);
-    const Dimension dimRank = stt.getDimRank();
     const Level lvlRank = stt.getLvlRank();
 
     // Special-case: for each over a sparse constant uses its own rewriting
@@ -1103,6 +1102,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
       // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
+      const Dimension dimRank = stt.getDimRank();
       SmallVector<Value> dcvs = lcvs; // keep a copy
       for (Dimension d = 0; d < dimRank; d++) {
         auto l = op.getOrder()->getDimPosition(d);
@@ -1143,7 +1143,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
                                  args);
     }
 
-    for (Dimension d = 0; d < dimRank; d++) {
+    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates the reducValue
       // in place.
       loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
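
Exiting lvlRank loops rather than dimRank loops matters once dimToLvl is not a
permutation: under the 2x2 BSR map a rank-2 tensor is traversed with four loops
(one per level), so four exitCurrentLoop calls are needed, not two.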


