[Mlir-commits] [mlir] [mlir][sparse] support sparsifying batch levels (PR #83898)
Peiming Liu
llvmlistbot at llvm.org
Mon Mar 4 14:08:24 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/83898
From b1a46bd05b32fbdef1b37f0dd1f22560abefb04c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 28 Feb 2024 20:05:49 +0000
Subject: [PATCH 1/2] [mlir][sparse] support sparsifying batch levels
---
.../IR/SparseTensorStorageLayout.h | 5 +-
.../SparseTensor/IR/SparseTensorType.h | 5 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 4 +-
.../Transforms/SparseAssembler.cpp | 6 +-
.../Transforms/SparseTensorCodegen.cpp | 44 ++++--
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../Transforms/Sparsification.cpp | 23 +--
.../Transforms/Utils/CodegenUtils.cpp | 2 +-
.../Transforms/Utils/CodegenUtils.h | 2 +-
.../Transforms/Utils/LoopEmitter.h | 6 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 133 ++++++++++++------
.../Transforms/Utils/SparseTensorLevel.h | 18 ++-
.../Dialect/SparseTensor/sparse_batch.mlir | 48 +++++++
.../sparse_conv_2d_slice_based.mlir | 4 +-
14 files changed, 221 insertions(+), 81 deletions(-)
create mode 100644 mlir/test/Dialect/SparseTensor/sparse_batch.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index ce34ae43d1c181..7aa9cb6119434b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -42,7 +42,10 @@ namespace sparse_tensor {
///
/// struct sparse_tensor.storage_specifier {
/// array<rank x int> lvlSizes ; sizes/cardinalities for each level
-/// array<n x int> memSizes; ; sizes/lengths for each data memref
+/// // TODO: memSizes need to be expanded to array<[batch] x n x int> to
+/// // support different sizes for different batches. At the moment, we
+/// // assume that every batch occupies the same memory size.
+/// array<n x int> memSizes ; sizes/lengths for each data memref
/// }
/// };
///
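For reference, a rough C++ analogy of the specifier described above (the real storage_specifier is an MLIR attribute, not a C++ type; the field layout here is only illustrative):

  #include <array>
  #include <cstdint>

  // Hypothetical C++ analogy of sparse_tensor.storage_specifier for a
  // rank-3 tensor; the names mirror the comment above, nothing else is real.
  struct StorageSpecifier {
    std::array<int64_t, 3> lvlSizes; // size/cardinality for each level
    // One entry per data memref (positions, coordinates, values). Per the
    // TODO above, all batches currently share these sizes, i.e. every batch
    // is assumed to occupy the same amount of memory.
    std::array<int64_t, 3> memSizes;
  };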
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index bd2c3c1dd55159..beb1dcce9c15c0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,7 +253,10 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
- /// Returns the Level-shape.
+ /// Returns the batched level rank.
+ unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }
+
+ /// Returns the batched Level-shape.
SmallVector<Size> getBatchLvlShape() const {
auto lvlShape = getEncoding().translateShape(
getDimShape(), CrdTransDirectionKind::dim2lvl);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 244a082d04870e..6ba8b46370b038 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -374,7 +374,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
if (!getImpl())
- return LevelFormat::Dense;
+ return LevelFormat::Batch;
assert(l < getLvlRank() && "Level is out of bounds");
return getLvlTypes()[l];
}
@@ -1755,6 +1755,8 @@ LogicalResult ConcatenateOp::verify() {
LogicalResult InsertOp::verify() {
const auto stt = getSparseTensorType(getTensor());
+ if (stt.getEncoding().getBatchLvlRank() > 0)
+ return emitOpError("batched sparse tensor insertion not implemented");
if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
return emitOpError("incorrect number of coordinates");
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index cd6b9b49893731..b39a2d9c57d8b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,7 +33,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = {ShapedType::kDynamic};
+ auto shape = stt.getBatchLvlShape();
+ shape.push_back(ShapedType::kDynamic);
auto vtp = RankedTensorType::get(shape, stt.getElementType());
convTypes.push_back(vtp);
if (extraTypes)
@@ -72,7 +73,8 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Convert the external representation of the values array.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
- auto shape = {ShapedType::kDynamic};
+ auto shape = stt.getBatchLvlShape();
+ shape.push_back(ShapedType::kDynamic);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
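As a standalone illustration of the shape change in this file (a sketch with a placeholder constant, not the MLIR API): the external values/positions/coordinates tensors now keep the static batch dimensions and only the trailing storage dimension stays dynamic, so tensor<?xf32> becomes tensor<8x?xf32> for a single batch level of size 8.

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Sketch: kDynamicSize is a stand-in for ShapedType::kDynamic.
  std::vector<int64_t> externalValuesShape(std::vector<int64_t> batchLvlShape) {
    constexpr int64_t kDynamicSize = -1; // placeholder, not the real constant
    std::vector<int64_t> shape = std::move(batchLvlShape);
    shape.push_back(kDynamicSize); // trailing dynamic storage dimension
    return shape;
  }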
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4e3393195813c3..5da8a60d2d5fb0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -429,11 +429,18 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
}
/// Creates the reassociation array.
-static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
- ReassociationIndices reassociation;
- for (int i = 0, e = srcTp.getRank(); i < e; i++)
- reassociation.push_back(i);
- return reassociation;
+static SmallVector<ReassociationIndices>
+getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
+ SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
+ // Create the reassociation in the form:
+ // {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
+ for (unsigned i = 0; i < batchLvls; i++)
+ ret[i].push_back(i);
+
+ for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
+ ret.back().push_back(i);
+
+ return ret;
}
//===----------------------------------------------------------------------===//
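For reference, a standalone sketch of the grouping produced by getReassociationForFlattening (plain std::vector instead of MLIR's ReassociationIndices):

  #include <cstdio>
  #include <vector>

  // One singleton group per batch level, then one group that collapses all
  // remaining levels.
  std::vector<std::vector<int>> reassociationFor(int rank, int batchLvls) {
    std::vector<std::vector<int>> groups(batchLvls + 1);
    for (int i = 0; i < batchLvls; i++)
      groups[i].push_back(i);       // {0}, {1}, ..., {batchLvls - 1}
    for (int i = batchLvls; i < rank; i++)
      groups.back().push_back(i);   // {batchLvls, ..., rank - 1}
    return groups;
  }

  int main() {
    // A rank-4 buffer with 2 batch levels collapses as {0}, {1}, {2, 3},
    // i.e. only the trailing non-batch dimensions are flattened together.
    for (const auto &g : reassociationFor(/*rank=*/4, /*batchLvls=*/2)) {
      printf("{ ");
      for (int i : g)
        printf("%d ", i);
      printf("} ");
    }
    printf("\n");
  }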
@@ -1287,9 +1294,10 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
: op.getLevels()[fIdx];
// TODO: handle batch.
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
- if (mem.getType().getRank() > 1) {
- // Flattens the buffer to rank 1.
- auto reassoc = getReassociationForFlattening(mem.getType());
+ if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
+ // Flattens the buffer to rank (batchLvlRank + 1).
+ auto reassoc = getReassociationForFlattening(
+ mem.getType(), stt.getBatchLvlRank());
mem = rewriter.create<memref::CastOp>(
loc, fType,
rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
@@ -1325,11 +1333,17 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
// Sets up the memory size by reading the last value in position array.
LevelType lt = stt.getLvlType(lvl);
// Simply forwards the position index when this is a dense level.
- if (isDenseLT(lt)) {
+ if (lt.isa<LevelFormat::Dense>()) {
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
continue;
}
+ if (lt.isa<LevelFormat::Batch>()) {
+ // Skips batch levels, as they are not linearized.
+ // FIXME: this assumes that every batch has the same number of nse, which
+ // needs to be generalized to handle varied-size batches.
+ continue;
+ }
if (isWithPosLT(lt)) {
assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
@@ -1343,7 +1357,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setPosMemSize(rewriter, loc, lvl, memSize);
// The last value in position array is the memory size for next level.
- memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
+ // FIXME: this assumes that every batch has the same number of nse, which
+ // needs to be generalized to handle varied-size batches.
+ SmallVector<Value> batched(stt.getBatchLvlRank(),
+ constantIndex(rewriter, loc, 0));
+ batched.push_back(posBack);
+ memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
}
assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
@@ -1413,8 +1432,9 @@ struct SparseDisassembleOpConverter
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
- if (dst.getType().getRank() != 1) {
- auto reassoc = getReassociationForFlattening(dst.getType());
+ if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
+ auto reassoc =
+ getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
}
Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
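A minimal sketch (nested std::vector standing in for the batched position memref) of the uniform-nse assumption noted in the FIXMEs above: batch 0's last positions entry is taken as the shared value-memory size for every batch.

  #include <cstdint>
  #include <vector>

  // positions[b] is the positions row of batch b; its last entry counts the
  // stored entries of that batch. Under the FIXME assumption, all rows end
  // with the same count, so reading batch 0 is sufficient.
  int64_t sharedValMemSize(const std::vector<std::vector<int64_t>> &positions) {
    return positions[0].back();
  }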
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6ff21468e05764..5150615af180c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1221,7 +1221,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}
Value vals = loopEmitter.getValBuffer()[0];
- Value pos = loopEmitter.getValPosits(0);
+ SmallVector<Value> pos = loopEmitter.getValPosits(0);
// Loads the value from sparse tensor using position-index;
// loads the value from dense tensor using coords.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8f2ae60b311f7c..1fb70ed5035c03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -86,7 +86,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
case AffineExprKind::Add:
case AffineExprKind::Mul:
case AffineExprKind::Constant: {
- assert(isDenseLT(lt));
+ assert(lt.hasDenseSemantic());
if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
// We do not set dim level format for affine expression like d0 + d1 on
// either loop index at d0 or d1. We continue the recursion merely to
@@ -211,7 +211,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
"AffineMap does not have dimension-rank many results");
unsigned num = 0;
for (Level l = 0; l < lvlRank; l++) {
- if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl(l))
+ if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
num++;
}
return num;
@@ -355,8 +355,8 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
if (stt.hasEncoding()) {
// For sparse tensors we only push the last-level's position onto `args`.
const auto pos = env.emitter().getValPosits(tid);
- assert(pos);
- args.push_back(pos);
+ assert(!pos.empty());
+ args.append(pos);
} else {
// For dense tensors we push all level's coordinates onto `args`.
const Level lvlRank = stt.getLvlRank();
@@ -801,7 +801,7 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
// `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
// should be consistent with the LT indexed by <TensorId, Level>.
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
- return isCompressedLT(lt) || isSingletonLT(lt);
+ return lt.hasSparseSemantic();
});
return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}
@@ -890,15 +890,14 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
}
assert(curr == env.merger().loop(b));
Value clause;
- if (isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
+ if (lt.hasSparseSemantic()) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
crd, lvar);
} else {
- assert(isDenseLT(lt) || isUndefLT(lt));
+ assert(lt.hasDenseSemantic() || isUndefLT(lt));
clause = constantI1(builder, loc, true);
}
cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
@@ -988,7 +987,7 @@ static bool getAllTidLvlsInLatPoints(
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
- } else if (isDenseLT(lt) || isIdxReduc) {
+ } else if (lt.hasDenseSemantic() || isIdxReduc) {
callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
@@ -1010,7 +1009,8 @@ static bool getAllTidLvlsInLatPoints(
AffineExpr exp = affines[l];
// Skip simple affine expression and non-dense levels (which
// have their own filter loop).
- if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
+ LevelType lt = stt.getLvlType(l);
+ if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
continue;
// Constant affine expression are handled in genLoop.
@@ -1103,7 +1103,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
for (Level l = startLvl; l < lvlRank; l++) {
AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ if (enc.getLvlType(l).hasDenseSemantic() &&
+ isa<AffineConstantExpr>(lvlExpr))
env.emitter().locateLvlAtAffineAddress(
builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
else
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index fa570159ba41ca..89af75dea2a0f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -175,7 +175,7 @@ Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
}
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
- Value s) {
+ ValueRange s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
if (!isa<IndexType>(load.getType())) {
if (load.getType().getIntOrFloatBitWidth() < 64)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index e8f6bd1c5eaeb1..ce5831d999e9a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -149,7 +149,7 @@ Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
-Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s);
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 7bfe713cdd9f74..b5a0ac8484abdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -220,9 +220,11 @@ class LoopEmitter {
///
/// Getters.
///
- Value getValPosits(TensorId tid) const {
+ SmallVector<Value> getValPosits(TensorId tid) const {
+ SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
- return lastLvlPos;
+ batchCrds.push_back(lastLvlPos);
+ return batchCrds;
};
Value getCoord(TensorId tid, Level lvl) const {
return getCurIterator(tid, lvl).getCrd();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 8edacaa9981ef8..a456c87445eafc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -52,8 +52,11 @@ class SparseLevel : public SparseTensorLevel {
Value crdBuffer)
: SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, crdBuffer, iv);
+ Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value iv) const override {
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(iv);
+ return genIndexLoad(b, l, crdBuffer, memCrd);
}
protected:
@@ -62,26 +65,35 @@ class SparseLevel : public SparseTensorLevel {
class DenseLevel : public SparseTensorLevel {
public:
- DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
- : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
- encoded(encoded) {}
+ DenseLevel(unsigned tid, Level lvl, Value lvlSize)
+ : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
- Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
- return pos;
+ Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+ llvm_unreachable("locate dense level instead");
}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
+ Value max) const override {
+ Value posLo = MULI(p, lvlSize);
+ return {posLo, lvlSize};
+ }
+};
+
+class BatchLevel : public SparseTensorLevel {
+public:
+ BatchLevel(unsigned tid, Level lvl, Value lvlSize)
+ : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
+
+ Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+ llvm_unreachable("locate dense level instead");
+ }
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
- if (encoded) {
- Value posLo = MULI(p, lvlSize);
- return {posLo, lvlSize};
- }
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
}
-
- const bool encoded;
};
class CompressedLevel : public SparseLevel {
@@ -90,14 +102,17 @@ class CompressedLevel : public SparseLevel {
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value max) const override {
+ assert(max == nullptr &&
+ "compressed level must be the first non-unique level.");
+
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ return {pLo, pHi};
}
private:
@@ -110,12 +125,17 @@ class LooseCompressedLevel : public SparseLevel {
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value max) const override {
+ assert(max == nullptr &&
+ "loose-compressed level must be the first non-unique level.");
+ SmallVector<Value> memCrd(batchPrefix);
+
p = MULI(p, C_IDX(2));
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
return {pLo, pHi};
}
@@ -129,8 +149,8 @@ class SingletonLevel : public SparseLevel {
Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value segHi) const override {
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value segHi) const override {
if (segHi == nullptr)
return {p, ADDI(p, C_IDX(1))};
@@ -145,8 +165,8 @@ class NOutOfMLevel : public SparseLevel {
Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value max) const override {
assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
@@ -225,7 +245,12 @@ class ConcreteIterator : public SparseIterator {
return from->kind == IterKind::kTrivial;
}
- bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
+ bool isBatchIterator() const override {
+ return stl.getLT().isa<LevelFormat::Batch>();
+ }
+ bool randomAccessible() const override {
+ return stl.getLT().hasDenseSemantic();
+ };
bool iteratableByFor() const override { return kind != IterKind::kDedup; };
Value upperBound(OpBuilder &b, Location l) const override {
return stl.getSize();
@@ -277,12 +302,19 @@ class TrivialIterator : public ConcreteIterator {
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
+
+ if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+ batchCrds.resize(stl.lvl + 1, nullptr);
+
Value pos = C_IDX(0);
Value hi = nullptr;
- if (parent)
+ // If the parent iterator is a batch iterator, we also start from 0 (but
+ // on a different batch).
+ if (parent && !parent->isBatchIterator())
std::tie(pos, hi) = parent->getCurPosition();
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
+ ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
// Seek to the lowest position.
seek(posLo);
}
@@ -302,7 +334,7 @@ class TrivialIterator : public ConcreteIterator {
if (randomAccessible()) {
updateCrd(SUBI(getItPos(), posLo));
} else {
- updateCrd(stl.peekCrdAt(b, l, getItPos()));
+ updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
}
return getCrd();
};
@@ -324,6 +356,11 @@ class TrivialIterator : public ConcreteIterator {
// Seek to the linearized position.
seek(ADDI(crd, posLo));
updateCrd(crd);
+ if (isBatchIterator()) {
+ // If this is a batch iterator, also update the batch coordinate.
+ assert(batchCrds.size() > lvl);
+ batchCrds[lvl] = crd;
+ }
}
Value getItPos() const { return getCursor().front(); }
@@ -358,11 +395,14 @@ class DedupIterator : public ConcreteIterator {
Value pos = C_IDX(0);
Value hi = nullptr;
- if (parent)
+ // If the parent iterator is a batch iterator, we also start from 0 (but
+ // on a different batch).
+ if (parent && !parent->isBatchIterator())
std::tie(pos, hi) = parent->getCurPosition();
Value posLo;
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
+ ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
seek({posLo, genSegmentHigh(b, l, posLo)});
}
@@ -384,7 +424,7 @@ class DedupIterator : public ConcreteIterator {
}
Value derefImpl(OpBuilder &b, Location l) override {
- updateCrd(stl.peekCrdAt(b, l, getPos()));
+ updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
return getCrd();
};
@@ -440,6 +480,7 @@ class FilterIterator : public SparseIterator {
return wrap->getCursorValTypes(b);
}
+ bool isBatchIterator() const override { return wrap->isBatchIterator(); }
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override { return size; };
@@ -576,6 +617,7 @@ class NonEmptySubSectIterator : public SparseIterator {
ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
TraverseBuilder builder) const;
+ bool isBatchIterator() const override { return delegate->isBatchIterator(); }
bool randomAccessible() const override {
return delegate->randomAccessible();
};
@@ -689,6 +731,7 @@ class SubSectIterator : public SparseIterator {
return ret;
}
+ bool isBatchIterator() const override { return wrap->isBatchIterator(); }
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override {
@@ -783,6 +826,9 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
seek(begin->getResults());
return;
}
+ // Inherit batch coordinates from parents
+ if (p)
+ inherentBatch(*p);
// TODO: support lowering to function call.
return genInitImpl(b, l, p);
}
@@ -825,6 +871,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
}
ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+ assert(!randomAccessible());
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
@@ -861,8 +908,8 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
OpBuilder::InsertionGuard guard(b);
// If in bound, load the next coordinates and check duplication.
b.setInsertionPointToStart(ifInBound.thenBlock());
- Value headCrd = stl.peekCrdAt(b, l, pos);
- Value tailCrd = stl.peekCrdAt(b, l, ivs.front());
+ Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
+ Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
Value isDup = CMPI(eq, headCrd, tailCrd);
YIELD(isDup);
// Else, the position is out of bound, yield false.
@@ -1277,9 +1324,9 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
- return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
+ return std::make_unique<DenseLevel>(tid, lvl, sz);
case LevelFormat::Batch:
- llvm_unreachable("not implemented");
+ return std::make_unique<BatchLevel>(tid, lvl, sz);
case LevelFormat::Compressed: {
Value pos = b.create<ToPositionsOp>(l, t, lvl);
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
@@ -1307,7 +1354,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
SparseEmitStrategy strategy) {
- auto stl = std::make_unique<DenseLevel>(tid, lvl, sz, /*encoded=*/false);
+ auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
auto it = std::make_unique<TrivialIterator>(*stl);
it->setSparseEmitStrategy(strategy);
return std::make_pair(std::move(stl), std::move(it));
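For reference, a standalone sketch of how the batch prefix changes the position loads in CompressedLevel::peekRangeAt (one batch coordinate, nested std::vector in place of the memref):

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Equivalent of genIndexLoad(pos, {batch, p}) and
  // genIndexLoad(pos, {batch, p + 1}) for a single batch level: the batch
  // coordinate is prepended to the position subscript.
  std::pair<int64_t, int64_t>
  peekRange(const std::vector<std::vector<int64_t>> &posBuffer, int64_t batch,
            int64_t p) {
    int64_t pLo = posBuffer[batch][p];
    int64_t pHi = posBuffer[batch][p + 1];
    return {pLo, pHi};
  }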
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index d1e94b790bea6b..9f92eecdf75cb6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -32,7 +32,8 @@ class SparseTensorLevel {
std::to_string(lvl) + "]";
}
- virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
+ virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
/// the given position `p` that the immediate parent level is current at.
@@ -47,7 +48,8 @@ class SparseTensorLevel {
///
/// `bound` is only used when the level is `non-unique` and deduplication is
/// required. It specifies the max upper bound of the non-unique segment.
- virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l, Value p,
+ virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
+ ValueRange batchPrefix, Value p,
Value segHi = Value()) const = 0;
Level getLevel() const { return lvl; }
@@ -89,7 +91,7 @@ class SparseIterator {
SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
unsigned cursorValsCnt,
SmallVectorImpl<Value> &cursorValStorage)
- : kind(kind), tid(tid), lvl(lvl), crd(nullptr),
+ : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
SparseIterator(IterKind kind, unsigned cursorValsCnt,
@@ -119,6 +121,7 @@ class SparseIterator {
virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
Value getCrd() const { return crd; }
+ ValueRange getBatchCrds() const { return batchCrds; }
ValueRange getCursor() const {
return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
};
@@ -135,6 +138,9 @@ class SparseIterator {
// Iterator properties.
//
+ // Whether the iterator is an iterator over a batch level.
+ virtual bool isBatchIterator() const = 0;
+
// Whether the iterator support random access (i.e., support look up by
// *coordinate*). A random access iterator must also traverses a dense space.
virtual bool randomAccessible() const = 0;
@@ -243,12 +249,18 @@ class SparseIterator {
protected:
void updateCrd(Value crd) { this->crd = crd; }
+
MutableArrayRef<Value> getMutCursorVals() {
MutableArrayRef<Value> ref = cursorValsStorageRef;
return ref.take_front(cursorValsCnt);
}
+ void inherentBatch(const SparseIterator &parent) {
+ batchCrds = parent.batchCrds;
+ }
+
SparseEmitStrategy emitStrategy;
+ SmallVector<Value> batchCrds;
public:
const IterKind kind; // For LLVM-style RTTI.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_batch.mlir b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
new file mode 100644
index 00000000000000..f6d2d0d4f76699
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_batch.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#BCSR = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)}>
+
+// CHECK-LABEL: func.func @main(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x4x2xf32, #sparse{{[0-9]*}}>) -> tensor<8x4x2xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<8x4x2xf32>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x4x2xf32, #sparse{{[0-9]*}}> to memref<8x?xf32>
+// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_6]] : memref<8x4x2xf32>
+// CHECK: linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_10]] : memref<8x4x2xf32>)
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_1]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_1]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x?xindex>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : index
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_14]]] : memref<8x?xindex>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_1]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<8x?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<8x?xf32>
+// CHECK: %[[VAL_19:.*]] = arith.negf %[[VAL_18]] : f32
+// CHECK: memref.store %[[VAL_19]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]], %[[VAL_17]]] : memref<8x4x2xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<8x4x2xf32>
+// CHECK: return %[[VAL_20]] : tensor<8x4x2xf32>
+// CHECK: }
+func.func @main(%arg0: tensor<8x4x2xf32, #BCSR>) -> tensor<8x4x2xf32> {
+ %0 = tensor.empty() : tensor<8x4x2xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ ins(%arg0 : tensor<8x4x2xf32, #BCSR>)
+ outs(%0 : tensor<8x4x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %2 = arith.negf %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<8x4x2xf32>
+ return %1 : tensor<8x4x2xf32>
+}
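A C++ sketch of the loop nest that the CHECK lines above verify (nested std::vector in place of the batched memrefs; shapes hard-coded to the 8x4x2 test case):

  #include <cstdint>
  #include <vector>

  // pos/crd/val hold one row per batch, matching the memref<8x?xindex> and
  // memref<8x?xf32> buffers in the IR; the batch level becomes a plain outer
  // loop plus a leading index into every sparse buffer. `out` is assumed to
  // be zero-initialized, mirroring the linalg.fill in the generated code.
  void negateBCSR(const std::vector<std::vector<int64_t>> &pos,
                  const std::vector<std::vector<int64_t>> &crd,
                  const std::vector<std::vector<float>> &val,
                  std::vector<std::vector<std::vector<float>>> &out) {
    for (int64_t b = 0; b < 8; b++)                          // batch level
      for (int64_t i = 0; i < 4; i++)                        // dense level
        for (int64_t k = pos[b][i]; k < pos[b][i + 1]; k++)  // compressed level
          out[b][i][crd[b][k]] = -val[b][k];
  }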
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 6aba0ada947e10..6076c1fbe76f21 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -24,13 +24,13 @@
// CHECK: "subsect<trivial<compressed[0,0]>>.not_end
// CHECK: } do {
// CHECK: %[[D2:.*]] = "subsect<trivial<compressed[0,0]>>.deref"
-// CHECK: "trivial<dense[1,0]>.locate"(%{{.*}}, %[[D2]])
+// CHECK: "trivial<batch[1,0]>.locate"(%{{.*}}, %[[D2]])
// CHECK: "subsect<trivial<compressed[0,1]>>.begin"
// CHECK: scf.while {{.*}} {
// CHECK: "subsect<trivial<compressed[0,1]>>.not_end"
// CHECK: } do {
// CHECK: %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
-// CHECK: "trivial<dense[1,1]>.locate"(%{{.*}}, %[[D3]])
+// CHECK: "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
// CHECK: tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
// CHECK: arith.muli
// CHECK: arith.addi
From c1c79547a7d49de8c3a8605c055a6d4cb7bd4b39 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 4 Mar 2024 22:08:08 +0000
Subject: [PATCH 2/2] address comments
---
.../SparseTensor/Transforms/Utils/SparseTensorLevel.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index a456c87445eafc..bc27fae5d19480 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -69,7 +69,7 @@ class DenseLevel : public SparseTensorLevel {
: SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
- llvm_unreachable("locate dense level instead");
+ llvm_unreachable("locate random-accessible level instead");
}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
@@ -85,7 +85,7 @@ class BatchLevel : public SparseTensorLevel {
: SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
- llvm_unreachable("locate dense level instead");
+ llvm_unreachable("locate random-accessible level instead");
}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
@@ -547,7 +547,8 @@ class NonEmptySubSectIterator : public SparseIterator {
assert(p->lvl + 1 == lvl);
maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
}
- // We don't need an extra buffer to find subsections on dense levels.
+ // We don't need an extra buffer to find subsections on random-accessible
+ // levels.
if (randomAccessible())
return;
subSectPosBuf = allocSubSectPosBuf(b, l);
@@ -826,7 +827,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
seek(begin->getResults());
return;
}
- // Inherit batch coordinates from parents
+ // Inherit batch coordinates from parents.
if (p)
inherentBatch(*p);
// TODO: support lowering to function call.