[Mlir-commits] [mlir] [mlir][sparse] support sparsifying batch levels (PR #83898)

Peiming Liu llvmlistbot at llvm.org
Mon Mar 4 14:08:24 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/83898

>From b1a46bd05b32fbdef1b37f0dd1f22560abefb04c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 28 Feb 2024 20:05:49 +0000
Subject: [PATCH 1/2] [mlir][sparse] support sparsifying batch levels

---
 .../IR/SparseTensorStorageLayout.h            |   5 +-
 .../SparseTensor/IR/SparseTensorType.h        |   5 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |   4 +-
 .../Transforms/SparseAssembler.cpp            |   6 +-
 .../Transforms/SparseTensorCodegen.cpp        |  44 ++++--
 .../Transforms/SparseTensorRewriting.cpp      |   2 +-
 .../Transforms/Sparsification.cpp             |  23 +--
 .../Transforms/Utils/CodegenUtils.cpp         |   2 +-
 .../Transforms/Utils/CodegenUtils.h           |   2 +-
 .../Transforms/Utils/LoopEmitter.h            |   6 +-
 .../Transforms/Utils/SparseTensorLevel.cpp    | 133 ++++++++++++------
 .../Transforms/Utils/SparseTensorLevel.h      |  18 ++-
 .../Dialect/SparseTensor/sparse_batch.mlir    |  48 +++++++
 .../sparse_conv_2d_slice_based.mlir           |   4 +-
 14 files changed, 221 insertions(+), 81 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_batch.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index ce34ae43d1c181..7aa9cb6119434b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -42,7 +42,10 @@ namespace sparse_tensor {
 ///
 ///   struct sparse_tensor.storage_specifier {
 ///     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
-///     array<n x int> memSizes;      ; sizes/lengths for each data memref
+///     // TODO: memSizes need to be expanded to array<[batch] x n x int> to
+///     // support different sizes for different batches. At the moment, we
+///     // assume that every batch occupies the same memory size.
+///     array<n x int> memSizes       ; sizes/lengths for each data memref
 ///   }
 /// };
 ///
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index bd2c3c1dd55159..beb1dcce9c15c0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,7 +253,10 @@ class SparseTensorType {
                                         CrdTransDirectionKind::dim2lvl);
   }
 
-  /// Returns the Level-shape.
+  /// Returns the batched level rank.
+  unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }
+
+  /// Returns the batched Level-shape.
   SmallVector<Size> getBatchLvlShape() const {
     auto lvlShape = getEncoding().translateShape(
         getDimShape(), CrdTransDirectionKind::dim2lvl);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 244a082d04870e..6ba8b46370b038 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -374,7 +374,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
 
 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
   if (!getImpl())
-    return LevelFormat::Dense;
+    return LevelFormat::Batch;
   assert(l < getLvlRank() && "Level is out of bounds");
   return getLvlTypes()[l];
 }
@@ -1755,6 +1755,8 @@ LogicalResult ConcatenateOp::verify() {
 
 LogicalResult InsertOp::verify() {
   const auto stt = getSparseTensorType(getTensor());
+  if (stt.getEncoding().getBatchLvlRank() > 0)
+    return emitOpError("batched sparse tensor insertion not implemented");
   if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
     return emitOpError("incorrect number of coordinates");
   return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index cd6b9b49893731..b39a2d9c57d8b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,7 +33,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     }
     // Convert the external representation of the values array.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = {ShapedType::kDynamic};
+    auto shape = stt.getBatchLvlShape();
+    shape.push_back(ShapedType::kDynamic);
     auto vtp = RankedTensorType::get(shape, stt.getElementType());
     convTypes.push_back(vtp);
     if (extraTypes)
@@ -72,7 +73,8 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     // Convert the external representation of the values array.
     auto rtp = cast<RankedTensorType>(type);
     const SparseTensorType stt(rtp);
-    auto shape = {ShapedType::kDynamic};
+    auto shape = stt.getBatchLvlShape();
+    shape.push_back(ShapedType::kDynamic);
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4e3393195813c3..5da8a60d2d5fb0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -429,11 +429,18 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
 }
 
 /// Creates the reassociation array.
-static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
-  ReassociationIndices reassociation;
-  for (int i = 0, e = srcTp.getRank(); i < e; i++)
-    reassociation.push_back(i);
-  return reassociation;
+static SmallVector<ReassociationIndices>
+getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
+  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
+  // Keeps each batch dimension separate and collapses the rest into one:
+  // {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
+  for (unsigned i = 0; i < batchLvls; i++)
+    ret[i].push_back(i);
+
+  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
+    ret.back().push_back(i);
+
+  return ret;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1287,9 +1294,10 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
                                : op.getLevels()[fIdx];
             // TODO: handle batch.
             TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
-            if (mem.getType().getRank() > 1) {
-              // Flattens the buffer to rank 1.
-              auto reassoc = getReassociationForFlattening(mem.getType());
+            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
+              // Collapses all non-batch dimensions into a single one.
+              auto reassoc = getReassociationForFlattening(
+                  mem.getType(), stt.getBatchLvlRank());
               mem = rewriter.create<memref::CastOp>(
                   loc, fType,
                   rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
@@ -1325,11 +1333,17 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
       // Sets up the memory size by reading the last value in position array.
       LevelType lt = stt.getLvlType(lvl);
       // Simply forwards the position index when this is a dense level.
-      if (isDenseLT(lt)) {
+      if (lt.isa<LevelFormat::Dense>()) {
         memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
         continue;
       }
+      if (lt.isa<LevelFormat::Batch>()) {
+        // Skips batch levels as they are not linearized.
+        // FIXME: this assumes that every batch has the same number of nse; it
+        // needs to be generalized to handle varied-size batches.
+        continue;
+      }
 
       if (isWithPosLT(lt)) {
         assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
@@ -1343,7 +1357,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
         }
         desc.setPosMemSize(rewriter, loc, lvl, memSize);
         // The last value in position array is the memory size for next level.
-        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
+        // FIXME: assumes every batch has the same nse, so only the first batch
+        // (batch coords all 0) is read; generalize to varied-size batches.
+        SmallVector<Value> batched(stt.getBatchLvlRank(),
+                                   constantIndex(rewriter, loc, 0));
+        batched.push_back(posBack);
+        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
         posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
       }
       assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
@@ -1413,8 +1432,9 @@ struct SparseDisassembleOpConverter
         retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
-      if (dst.getType().getRank() != 1) {
-        auto reassoc = getReassociationForFlattening(dst.getType());
+      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
+        auto reassoc =
+            getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
         flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
       }
       Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6ff21468e05764..5150615af180c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1221,7 +1221,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     }
 
     Value vals = loopEmitter.getValBuffer()[0];
-    Value pos = loopEmitter.getValPosits(0);
+    SmallVector<Value> pos = loopEmitter.getValPosits(0);
     // Loads the value from sparse tensor using position-index;
     // loads the value from dense tensor using coords.
     Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8f2ae60b311f7c..1fb70ed5035c03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -86,7 +86,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
   case AffineExprKind::Add:
   case AffineExprKind::Mul:
   case AffineExprKind::Constant: {
-    assert(isDenseLT(lt));
+    assert(lt.hasDenseSemantic());
     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
       // We do not set dim level format for affine expression like d0 + d1 on
       // either loop index at d0 or d1. We continue the recursion merely to
@@ -211,7 +211,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
          "AffineMap does not have dimension-rank many results");
   unsigned num = 0;
   for (Level l = 0; l < lvlRank; l++) {
-    if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl(l))
+    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
       num++;
   }
   return num;
@@ -355,8 +355,8 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
   if (stt.hasEncoding()) {
     // For sparse tensors we only push the last-level's position onto `args`.
     const auto pos = env.emitter().getValPosits(tid);
-    assert(pos);
-    args.push_back(pos);
+    assert(!pos.empty());
+    args.append(pos);
   } else {
     // For dense tensors we push all level's coordinates onto `args`.
     const Level lvlRank = stt.getLvlRank();
@@ -801,7 +801,7 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
     // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
-    return isCompressedLT(lt) || isSingletonLT(lt);
+    return lt.hasSparseSemantic();
   });
   return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
 }
@@ -890,15 +890,14 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
         }
         assert(curr == env.merger().loop(b));
         Value clause;
-        if (isCompressedLT(lt) || isSingletonLT(lt) ||
-            isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
+        if (lt.hasSparseSemantic()) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoord(tid, *lvl);
           const Value lvar = env.getLoopVar(curr);
           clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  crd, lvar);
         } else {
-          assert(isDenseLT(lt) || isUndefLT(lt));
+          assert(lt.hasDenseSemantic() || isUndefLT(lt));
           clause = constantI1(builder, loc, true);
         }
         cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
@@ -988,7 +987,7 @@ static bool getAllTidLvlsInLatPoints(
           hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
           callback(env.makeTensorLevel(tid, *lvl), nullptr);
           numloopCond++;
-        } else if (isDenseLT(lt) || isIdxReduc) {
+        } else if (lt.hasDenseSemantic() || isIdxReduc) {
           callback(env.makeTensorLevel(tid, *lvl), nullptr);
         } else {
           assert(isUndefLT(lt));
@@ -1010,7 +1009,8 @@ static bool getAllTidLvlsInLatPoints(
             AffineExpr exp = affines[l];
             // Skip simple affine expression and non-dense levels (which
             // have their own filter loop).
-            if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
+            LevelType lt = stt.getLvlType(l);
+            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
               continue;
 
             // Constant affine expression are handled in genLoop.
@@ -1103,7 +1103,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
     assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
     for (Level l = startLvl; l < lvlRank; l++) {
       AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+      if (enc.getLvlType(l).hasDenseSemantic() &&
+          isa<AffineConstantExpr>(lvlExpr))
         env.emitter().locateLvlAtAffineAddress(
             builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
       else
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index fa570159ba41ca..89af75dea2a0f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -175,7 +175,7 @@ Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
 }
 
 Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
-                                  Value s) {
+                                  ValueRange s) {
   Value load = builder.create<memref::LoadOp>(loc, mem, s);
   if (!isa<IndexType>(load.getType())) {
     if (load.getType().getIntOrFloatBitWidth() < 64)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index e8f6bd1c5eaeb1..ce5831d999e9a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -149,7 +149,7 @@ Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
 /// Generates a pointer/index load from the sparse storage scheme. Narrower
 /// data types need to be zero extended before casting the value into the
 /// index type used for looping and indexing.
-Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s);
 
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 7bfe713cdd9f74..b5a0ac8484abdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -220,9 +220,11 @@ class LoopEmitter {
   ///
   /// Getters.
   ///
-  Value getValPosits(TensorId tid) const {
+  SmallVector<Value> getValPosits(TensorId tid) const {
+    SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
     Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
-    return lastLvlPos;
+    batchCrds.push_back(lastLvlPos);
+    return batchCrds;
   };
   Value getCoord(TensorId tid, Level lvl) const {
     return getCurIterator(tid, lvl).getCrd();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 8edacaa9981ef8..a456c87445eafc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -52,8 +52,11 @@ class SparseLevel : public SparseTensorLevel {
               Value crdBuffer)
       : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
 
-  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
-    return genIndexLoad(b, l, crdBuffer, iv);
+  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                  Value iv) const override {
+    SmallVector<Value> memCrd(batchPrefix);
+    memCrd.push_back(iv);
+    return genIndexLoad(b, l, crdBuffer, memCrd);
   }
 
 protected:
@@ -62,26 +65,35 @@ class SparseLevel : public SparseTensorLevel {
 
 class DenseLevel : public SparseTensorLevel {
 public:
-  DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
-      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
-        encoded(encoded) {}
+  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
+      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
 
-  Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
-    return pos;
+  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+    llvm_unreachable("locate dense level instead");
   }
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "Dense level can not be non-unique.");
+    return {MULI(p, lvlSize), lvlSize};
+  }
+};
+
+class BatchLevel : public SparseTensorLevel {
+public:
+  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
+      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
+
+  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+    llvm_unreachable("locate batch level instead");
+  }
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
-    assert(max == nullptr && "Dense level can not be non-unique.");
-    if (encoded) {
-      Value posLo = MULI(p, lvlSize);
-      return {posLo, lvlSize};
-    }
-    // No need to linearize the position for non-annotated tensors.
+    assert(max == nullptr && "Batch level can not be non-unique.");
+    // Batch levels are not linearized: each batch starts at position 0.
     return {C_IDX(0), lvlSize};
   }
-
-  const bool encoded;
 };
 
 class CompressedLevel : public SparseLevel {
@@ -90,14 +102,17 @@ class CompressedLevel : public SparseLevel {
                   Value posBuffer, Value crdBuffer)
       : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    if (max == nullptr) {
-      Value pLo = genIndexLoad(b, l, posBuffer, p);
-      Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-      return {pLo, pHi};
-    }
-    llvm_unreachable("compressed-nu should be the first non-unique level.");
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        Value p, Value max) const override {
+    assert(max == nullptr &&
+           "compressed level must be the first non-unique level.");
+
+    SmallVector<Value> memCrd(batchPrefix);
+    memCrd.push_back(p);
+    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    memCrd.back() = ADDI(p, C_IDX(1));
+    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    return {pLo, pHi};
   }
 
 private:
@@ -110,12 +125,17 @@ class LooseCompressedLevel : public SparseLevel {
                        Value posBuffer, Value crdBuffer)
       : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "loss compressed level can not be non-unique.");
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        Value p, Value max) const override {
+    assert(max == nullptr &&
+           "loose-compressed level must be the first non-unique level.");
+    SmallVector<Value> memCrd(batchPrefix);
+
     p = MULI(p, C_IDX(2));
-    Value pLo = genIndexLoad(b, l, posBuffer, p);
-    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+    memCrd.push_back(p);
+    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    memCrd.back() = ADDI(p, C_IDX(1));
+    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
     return {pLo, pHi};
   }
 
@@ -129,8 +149,8 @@ class SingletonLevel : public SparseLevel {
                  Value crdBuffer)
       : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value segHi) const override {
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        Value p, Value segHi) const override {
     if (segHi == nullptr)
       return {p, ADDI(p, C_IDX(1))};
 
@@ -145,8 +165,8 @@ class NOutOfMLevel : public SparseLevel {
                Value crdBuffer)
       : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        Value p, Value max) const override {
     assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
@@ -225,7 +245,12 @@ class ConcreteIterator : public SparseIterator {
     return from->kind == IterKind::kTrivial;
   }
 
-  bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
+  bool isBatchIterator() const override {
+    return stl.getLT().isa<LevelFormat::Batch>();
+  }
+  bool randomAccessible() const override {
+    return stl.getLT().hasDenseSemantic();
+  };
   bool iteratableByFor() const override { return kind != IterKind::kDedup; };
   Value upperBound(OpBuilder &b, Location l) const override {
     return stl.getSize();
@@ -277,12 +302,19 @@ class TrivialIterator : public ConcreteIterator {
 
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
+    // Reserves the batchCrds slot at index `stl.lvl` for this batch level.
+    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+      batchCrds.resize(stl.lvl + 1, nullptr);
+
     Value pos = C_IDX(0);
     Value hi = nullptr;
-    if (parent)
+    // If the parent iterator is a batch iterator, we also start from 0 (but
+    // on a different batch).
+    if (parent && !parent->isBatchIterator())
       std::tie(pos, hi) = parent->getCurPosition();
 
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
+    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
     // Seek to the lowest position.
     seek(posLo);
   }
@@ -302,7 +334,7 @@ class TrivialIterator : public ConcreteIterator {
     if (randomAccessible()) {
       updateCrd(SUBI(getItPos(), posLo));
     } else {
-      updateCrd(stl.peekCrdAt(b, l, getItPos()));
+      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
     }
     return getCrd();
   };
@@ -324,6 +356,11 @@ class TrivialIterator : public ConcreteIterator {
     // Seek to the linearized position.
     seek(ADDI(crd, posLo));
     updateCrd(crd);
+    if (isBatchIterator()) {
+      // If this is a batch iterator, also update the batch coordinate.
+      assert(batchCrds.size() > lvl);
+      batchCrds[lvl] = crd;
+    }
   }
 
   Value getItPos() const { return getCursor().front(); }
@@ -358,11 +395,14 @@ class DedupIterator : public ConcreteIterator {
 
     Value pos = C_IDX(0);
     Value hi = nullptr;
-    if (parent)
+    // If the parent iterator is a batch iterator, we also start from 0 (but
+    // on a different batch).
+    if (parent && !parent->isBatchIterator())
       std::tie(pos, hi) = parent->getCurPosition();
 
     Value posLo;
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
+    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
 
     seek({posLo, genSegmentHigh(b, l, posLo)});
   }
@@ -384,7 +424,7 @@ class DedupIterator : public ConcreteIterator {
   }
 
   Value derefImpl(OpBuilder &b, Location l) override {
-    updateCrd(stl.peekCrdAt(b, l, getPos()));
+    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
     return getCrd();
   };
 
@@ -440,6 +480,7 @@ class FilterIterator : public SparseIterator {
     return wrap->getCursorValTypes(b);
   }
 
+  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return randomAccessible(); };
   Value upperBound(OpBuilder &b, Location l) const override { return size; };
@@ -576,6 +617,7 @@ class NonEmptySubSectIterator : public SparseIterator {
   ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                 TraverseBuilder builder) const;
 
+  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
   bool randomAccessible() const override {
     return delegate->randomAccessible();
   };
@@ -689,6 +731,7 @@ class SubSectIterator : public SparseIterator {
     return ret;
   }
 
+  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return randomAccessible(); };
   Value upperBound(OpBuilder &b, Location l) const override {
@@ -783,6 +826,9 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
     seek(begin->getResults());
     return;
   }
+  // Inherit the batch coordinates from the parent iterator, if any.
+  if (p)
+    inherentBatch(*p);
   // TODO: support lowering to function call.
   return genInitImpl(b, l, p);
 }
@@ -825,6 +871,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
 }
 
 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+  assert(!randomAccessible());
   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
@@ -861,8 +908,8 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
           OpBuilder::InsertionGuard guard(b);
           // If in bound, load the next coordinates and check duplication.
           b.setInsertionPointToStart(ifInBound.thenBlock());
-          Value headCrd = stl.peekCrdAt(b, l, pos);
-          Value tailCrd = stl.peekCrdAt(b, l, ivs.front());
+          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
+          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
           Value isDup = CMPI(eq, headCrd, tailCrd);
           YIELD(isDup);
           // Else, the position is out of bound, yield false.
@@ -1277,9 +1324,9 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
 
   switch (lt.getLvlFmt()) {
   case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
+    return std::make_unique<DenseLevel>(tid, lvl, sz);
   case LevelFormat::Batch:
-    llvm_unreachable("not implemented");
+    return std::make_unique<BatchLevel>(tid, lvl, sz);
   case LevelFormat::Compressed: {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
     Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
@@ -1307,7 +1354,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                                        SparseEmitStrategy strategy) {
-  auto stl = std::make_unique<DenseLevel>(tid, lvl, sz, /*encoded=*/false);
+  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
   auto it = std::make_unique<TrivialIterator>(*stl);
   it->setSparseEmitStrategy(strategy);
   return std::make_pair(std::move(stl), std::move(it));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index d1e94b790bea6b..9f92eecdf75cb6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -32,7 +32,8 @@ class SparseTensorLevel {
            std::to_string(lvl) + "]";
   }
 
-  virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
+  virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                          Value iv) const = 0;
 
   /// Peeks the lower and upper bound to *fully* traverse the level with
   /// the given position `p` that the immediate parent level is current at.
@@ -47,7 +48,8 @@ class SparseTensorLevel {
   ///
   /// `bound` is only used when the level is `non-unique` and deduplication is
   /// required. It specifies the max upper bound of the non-unique segment.
-  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l, Value p,
+  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
+                                              ValueRange batchPrefix, Value p,
                                               Value segHi = Value()) const = 0;
 
   Level getLevel() const { return lvl; }
@@ -89,7 +91,7 @@ class SparseIterator {
   SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
                  unsigned cursorValsCnt,
                  SmallVectorImpl<Value> &cursorValStorage)
-      : kind(kind), tid(tid), lvl(lvl), crd(nullptr),
+      : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
         cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
 
   SparseIterator(IterKind kind, unsigned cursorValsCnt,
@@ -119,6 +121,7 @@ class SparseIterator {
   virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
 
   Value getCrd() const { return crd; }
+  ValueRange getBatchCrds() const { return batchCrds; }
   ValueRange getCursor() const {
     return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
   };
@@ -135,6 +138,9 @@ class SparseIterator {
   // Iterator properties.
   //
 
+  // Whether the iterator is an iterator over a batch level.
+  virtual bool isBatchIterator() const = 0;
+
   // Whether the iterator support random access (i.e., support look up by
   // *coordinate*). A random access iterator must also traverses a dense space.
   virtual bool randomAccessible() const = 0;
@@ -243,12 +249,18 @@ class SparseIterator {
 
 protected:
   void updateCrd(Value crd) { this->crd = crd; }
+
   MutableArrayRef<Value> getMutCursorVals() {
     MutableArrayRef<Value> ref = cursorValsStorageRef;
     return ref.take_front(cursorValsCnt);
   }
 
+  void inherentBatch(const SparseIterator &parent) {
+    batchCrds = parent.batchCrds;
+  }
+
   SparseEmitStrategy emitStrategy;
+  SmallVector<Value> batchCrds;
 
 public:
   const IterKind kind;     // For LLVM-style RTTI.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_batch.mlir b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
new file mode 100644
index 00000000000000..f6d2d0d4f76699
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#BCSR = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)}>
+
+// CHECK-LABEL:   func.func @main(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x4x2xf32, #sparse{{[0-9]*}}>) -> tensor<8x4x2xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<8x4x2xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xf32>
+// CHECK:           %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_6]] : memref<8x4x2xf32>
+// CHECK:           linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_10]] : memref<8x4x2xf32>)
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_1]] {
+// CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_1]] {
+// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x?xindex>
+// CHECK:               %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : index
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_14]]] : memref<8x?xindex>
+// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_1]] {
+// CHECK:                 %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<8x?xindex>
+// CHECK:                 %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<8x?xf32>
+// CHECK:                 %[[VAL_19:.*]] = arith.negf %[[VAL_18]] : f32
+// CHECK:                 memref.store %[[VAL_19]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], %[[VAL_17]]] : memref<8x4x2xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<8x4x2xf32>
+// CHECK:           return %[[VAL_20]] : tensor<8x4x2xf32>
+// CHECK:         }
+func.func @main(%arg0: tensor<8x4x2xf32, #BCSR>) -> tensor<8x4x2xf32> {
+  %0 = tensor.empty() : tensor<8x4x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  ins(%arg0 : tensor<8x4x2xf32, #BCSR>)
+  outs(%0 : tensor<8x4x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.negf %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x4x2xf32>
+  return %1 : tensor<8x4x2xf32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 6aba0ada947e10..6076c1fbe76f21 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -24,13 +24,13 @@
 // CHECK:                 "subsect<trivial<compressed[0,0]>>.not_end
 // CHECK:               } do {
 // CHECK:                 %[[D2:.*]] = "subsect<trivial<compressed[0,0]>>.deref"
-// CHECK:                 "trivial<dense[1,0]>.locate"(%{{.*}}, %[[D2]])
+// CHECK:                 "trivial<batch[1,0]>.locate"(%{{.*}}, %[[D2]])
 // CHECK:                 "subsect<trivial<compressed[0,1]>>.begin"
 // CHECK:                 scf.while {{.*}} {
 // CHECK:                   "subsect<trivial<compressed[0,1]>>.not_end"
 // CHECK:                 } do {
 // CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
-// CHECK:                   "trivial<dense[1,1]>.locate"(%{{.*}}, %[[D3]])
+// CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
 // CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
 // CHECK:                   arith.muli
 // CHECK:                   arith.addi

>From c1c79547a7d49de8c3a8605c055a6d4cb7bd4b39 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 4 Mar 2024 22:08:08 +0000
Subject: [PATCH 2/2] address comments

---
 .../SparseTensor/Transforms/Utils/SparseTensorLevel.cpp  | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index a456c87445eafc..bc27fae5d19480 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -69,7 +69,7 @@ class DenseLevel : public SparseTensorLevel {
       : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
 
   Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
-    llvm_unreachable("locate dense level instead");
+    llvm_unreachable("locate random-accessible level instead");
   }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
@@ -85,7 +85,7 @@ class BatchLevel : public SparseTensorLevel {
       : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
 
   Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
-    llvm_unreachable("locate dense level instead");
+    llvm_unreachable("locate random-accessible level instead");
   }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
@@ -547,7 +547,8 @@ class NonEmptySubSectIterator : public SparseIterator {
       assert(p->lvl + 1 == lvl);
       maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
     }
-    // We don't need an extra buffer to find subsections on dense levels.
+    // We don't need an extra buffer to find subsections on random-accessible
+    // levels.
     if (randomAccessible())
       return;
     subSectPosBuf = allocSubSectPosBuf(b, l);
@@ -826,7 +827,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
     seek(begin->getResults());
     return;
   }
-  // Inherent batch coordinates from parents
+  // Inherit batch coordinates from the parent iterator.
   if (p)
     inherentBatch(*p);
   // TODO: support lowering to function call.



More information about the Mlir-commits mailing list